/*
 * Decompiled with CFR 0.152.
 */
package sklearn.dummy;

import com.google.common.primitives.Doubles;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.ClassDictUtil;
import sklearn.Classifier;

/**
 * Converter for a fitted {@code sklearn.dummy.DummyClassifier}.
 *
 * <p>The estimator is encoded as a PMML {@link TreeModel} consisting of a single root node whose
 * predicate is always true. Every record therefore receives the same predicted class and the same
 * class probability distribution, regardless of the input features.
 */
public class DummyClassifier
extends Classifier {
    public DummyClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Returns {@code -1}; the encoded model never reads any input feature.
     *
     * <p>NOTE(review): presumably {@code -1} means "unknown / not validated" in the estimator
     * contract — confirm against the superclass.
     */
    @Override
    public int getNumberOfFeatures() {
        return -1;
    }

    /**
     * Encodes this classifier as a single-node PMML tree model.
     *
     * <p>The predicted class and its probability vector depend on the {@code strategy} attribute:
     * <ul>
     *   <li>{@code "constant"} — predicts the {@code constant} attribute with probability 1;</li>
     *   <li>{@code "most_frequent"} — predicts the class with the largest prior, with probability 1;</li>
     *   <li>{@code "prior"} — predicts the class with the largest prior and reports the class
     *       priors themselves as probabilities.</li>
     * </ul>
     *
     * @param schema schema whose label must be a {@link CategoricalLabel}
     * @return a classification {@link TreeModel} with a probability output
     * @throws IllegalArgumentException if the strategy is not one of the values above, if the
     *     constant is not among the known classes, or if the class prior list is empty
     * @throws ClassCastException if the schema label is not a {@link CategoricalLabel}
     */
    public TreeModel encodeModel(Schema schema) {
        List<?> classes = this.getClasses();
        List<? extends Number> classPrior = this.getClassPrior();
        Object constant = this.getConstant();
        String strategy = this.getStrategy();

        // NOTE(review): presumably validates that classes and classPrior have equal length.
        ClassDictUtil.checkSize(classes, classPrior);

        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        int index;
        double[] probabilities;

        switch (strategy) {
            case "constant": {
                index = classes.indexOf(constant);
                // indexOf signals "not found" with -1; fail with a clear message instead of an
                // ArrayIndexOutOfBoundsException further down.
                if (index < 0) {
                    throw new IllegalArgumentException("Constant value " + constant + " is not one of the classes " + classes);
                }
                probabilities = oneHot(classes.size(), index);
                break;
            }
            case "most_frequent": {
                index = argMax(classPrior);
                probabilities = oneHot(classes.size(), index);
                break;
            }
            case "prior": {
                index = argMax(classPrior);
                probabilities = Doubles.toArray(classPrior);
                break;
            }
            default: {
                throw new IllegalArgumentException(strategy);
            }
        }

        Node root = new Node()
            .setPredicate(new True())
            .setScore(ValueUtil.formatValue(classes.get(index)));

        for (int i = 0; i < classes.size(); ++i) {
            ScoreDistribution scoreDistribution = new ScoreDistribution(ValueUtil.formatValue(classes.get(i)), probabilities[i]);
            root.addScoreDistributions(scoreDistribution);
        }

        MiningSchemaHolder: {
        }

        TreeModel treeModel = new TreeModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), root)
            .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));

        return treeModel;
    }

    /**
     * Returns the index of the largest value, preferring the first occurrence on ties.
     *
     * <p>Values are compared with {@link Double#compare(double, double)}, the same ordering that
     * {@code Collections.max} applies to {@link Double} elements.
     *
     * @param values non-empty list of numbers
     * @return index of the first maximal element
     * @throws IllegalArgumentException if {@code values} is empty
     */
    private static int argMax(List<? extends Number> values) {
        if (values.isEmpty()) {
            throw new IllegalArgumentException("Class prior list is empty");
        }

        int result = 0;
        double max = values.get(0).doubleValue();

        for (int i = 1; i < values.size(); ++i) {
            double value = values.get(i).doubleValue();
            // Strictly greater keeps the first maximum when there are ties.
            if (Double.compare(value, max) > 0) {
                result = i;
                max = value;
            }
        }

        return result;
    }

    /**
     * Builds a probability vector of the given length with {@code 1.0} at {@code index} and
     * {@code 0.0} elsewhere.
     */
    private static double[] oneHot(int size, int index) {
        double[] result = new double[size];
        result[index] = 1d;
        return result;
    }

    /**
     * Returns the fitted {@code class_prior_} attribute as a list of numbers.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public List<? extends Number> getClassPrior() {
        // The element type of getArray's result is not statically known here. The raw cast
        // (lost during decompilation) restores the numeric view that encodeModel relies on;
        // elements are only ever read through Number methods.
        return (List)ClassDictUtil.getArray(this, "class_prior_");
    }

    /**
     * Returns the raw {@code constant} attribute, used by the {@code "constant"} strategy.
     */
    public Object getConstant() {
        return this.get("constant");
    }

    /**
     * Returns the {@code strategy} attribute.
     *
     * @throws ClassCastException if the stored value is not a {@link String}
     */
    public String getStrategy() {
        return (String)this.get("strategy");
    }
}

