/*
 * Decompiled with CFR 0.152.
 */
package sklearn.dummy;

import com.google.common.primitives.Doubles;
import java.util.Collection;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ScoreDistributionManager;
import org.jpmml.python.ClassDictUtil;
import sklearn.HasPriorProbability;
import sklearn.HasSparseOutput;
import sklearn.SkLearnClassifier;

/**
 * Converter for scikit-learn's {@code sklearn.dummy.DummyClassifier}.
 *
 * <p>The fitted estimator is encoded as a single-node PMML {@link TreeModel}: the root node's
 * predicate is always true and it carries a fixed score distribution. Only the deterministic
 * strategies {@code "constant"}, {@code "most_frequent"} and {@code "prior"} are supported; any
 * other strategy is rejected with an {@link IllegalArgumentException}.
 */
public class DummyClassifier
extends SkLearnClassifier
implements HasPriorProbability,
HasSparseOutput {
    public DummyClassifier() {
        this("sklearn.dummy", "DummyClassifier");
    }

    public DummyClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the prior probability of the class at the given position.
     *
     * @param index zero-based position into {@code classes_} / {@code class_prior_}
     * @return the {@code class_prior_} entry at {@code index}
     * @throws IllegalArgumentException if the strategy is anything other than {@code "prior"}
     */
    @Override
    public Number getPriorProbability(int index) {
        List<?> classes = this.getClasses();
        List<? extends Number> classPrior = this.getClassPrior();
        String strategy = this.getStrategy();

        ClassDictUtil.checkSize(new Collection<?>[]{classes, classPrior});

        // Only the "prior" strategy exposes class_prior_ as its predicted probabilities
        if ("prior".equals(strategy)) {
            return classPrior.get(index);
        }

        throw new IllegalArgumentException(strategy);
    }

    /**
     * Encodes this estimator as a constant single-node classification tree.
     *
     * <ul>
     *   <li>{@code "constant"}: the winning class is {@code constant}; its probability is 1.</li>
     *   <li>{@code "most_frequent"}: the winning class is the argmax of {@code class_prior_};
     *       its probability is 1.</li>
     *   <li>{@code "prior"}: the winning class is the argmax of {@code class_prior_}; the
     *       probabilities are {@code class_prior_} itself.</li>
     * </ul>
     *
     * @param schema schema whose label is expected to be a {@link CategoricalLabel}
     * @return the encoded tree model, including predict-proba output fields
     * @throws IllegalArgumentException if the strategy is unsupported, or if the
     *     {@code "constant"} value is not one of the known classes
     */
    public TreeModel encodeModel(Schema schema) {
        List<?> classes = this.getClasses();
        List<? extends Number> classPrior = this.getClassPrior();
        Object constant = this.getConstant();
        String strategy = this.getStrategy();

        ClassDictUtil.checkSize(new Collection<?>[]{classes, classPrior});

        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        int maxIndex;
        List<? extends Number> probabilities;

        switch (strategy) {
            case "constant": {
                maxIndex = classes.indexOf(constant);
                if (maxIndex < 0) {
                    // Lookup uses equals(); a null/absent constant or a boxed-type mismatch
                    // (e.g. Integer vs Long) also ends up here, so report both sides
                    throw new IllegalArgumentException("Constant value " + constant + " is not among the known classes " + classes);
                }
                probabilities = DummyClassifier.createProbabilities(classes, maxIndex);
                break;
            }
            case "most_frequent": {
                maxIndex = ScoreDistributionManager.indexOfMax(classPrior);
                probabilities = DummyClassifier.createProbabilities(classes, maxIndex);
                break;
            }
            case "prior": {
                maxIndex = ScoreDistributionManager.indexOfMax(classPrior);
                probabilities = classPrior;
                break;
            }
            default: {
                throw new IllegalArgumentException(strategy);
            }
        }

        // Explicit casts are kept as in the original so overload resolution is unchanged
        ClassifierNode root = new ClassifierNode(categoricalLabel.getValue(maxIndex), (Predicate)True.INSTANCE);

        ScoreDistributionManager scoreDistributionManager = new ScoreDistributionManager();
        scoreDistributionManager.addScoreDistributions((PMMLObject)root, categoricalLabel.getValues(), null, probabilities);

        TreeModel treeModel = new TreeModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema((Label)categoricalLabel), (Node)root);

        this.encodePredictProbaOutput((Model)treeModel, DataType.DOUBLE, (DiscreteLabel)categoricalLabel);

        return treeModel;
    }

    /**
     * @return the fitted {@code class_prior_} attribute
     */
    public List<? extends Number> getClassPrior() {
        return this.getNumberArray("class_prior_");
    }

    /**
     * @return the optional {@code constant} parameter; may be absent
     */
    public Object getConstant() {
        return this.getOptionalScalar("constant");
    }

    @Override
    public Boolean getSparseOutput() {
        return this.getBoolean("sparse_output_");
    }

    /**
     * @return the {@code strategy} parameter
     */
    public String getStrategy() {
        return this.getString("strategy");
    }

    /**
     * Builds a one-hot probability vector: 1.0 at {@code index}, 0.0 everywhere else.
     *
     * @param classes the class list; only its size is used
     * @param index position that receives probability 1.0
     * @return a fixed-size list view over the backing {@code double[]}
     */
    private static List<Double> createProbabilities(List<?> classes, int index) {
        double[] values = new double[classes.size()];
        values[index] = 1.0;
        return Doubles.asList(values);
    }
}

