/*
 * Decompiled with CFR 0.152.
 */
package sklearn.ensemble.hist_gradient_boosting;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.python.ClassDictUtil;
import sklearn.Classifier;
import sklearn.ensemble.hist_gradient_boosting.BaseLoss;
import sklearn.ensemble.hist_gradient_boosting.BinaryCrossEntropy;
import sklearn.ensemble.hist_gradient_boosting.CategoricalCrossEntropy;
import sklearn.ensemble.hist_gradient_boosting.HistGradientBoostingUtil;
import sklearn.ensemble.hist_gradient_boosting.TreePredictor;

public class HistGradientBoostingClassifier
extends Classifier {
    public HistGradientBoostingClassifier(String module, String name) {
        super(module, name);
    }

    public MiningModel encodeModel(Schema schema) {
        int numberOfTreesPerIteration = this.getNumberOfTreesPerIteration();
        List<List<TreePredictor>> predictors = this.getPredictors();
        List<? extends Number> baselinePredictions = this.getBaselinePrediction();
        BaseLoss loss = this.getLoss();
        if (predictors.size() > 0) {
            ClassDictUtil.checkSize((int)numberOfTreesPerIteration, (Collection[])new Collection[]{predictors.get(0), baselinePredictions});
        }
        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        if (numberOfTreesPerIteration == 1) {
            SchemaUtil.checkSize((int)2, (CategoricalLabel)categoricalLabel);
            if (!(loss instanceof BinaryCrossEntropy)) {
                throw new IllegalArgumentException();
            }
            MiningModel miningModel = HistGradientBoostingUtil.encodeHistGradientBoosting(predictors, baselinePredictions, 0, segmentSchema).setOutput(ModelUtil.createPredictedOutput((FieldName)FieldNameUtil.create((String)"decisionFunction", (Object[])new Object[]{categoricalLabel.getValue(1)}), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[0]));
            return MiningModelUtil.createBinaryLogisticClassification((Model)miningModel, (double)1.0, (double)0.0, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.LOGIT, (boolean)true, (Schema)schema);
        }
        if (numberOfTreesPerIteration >= 3) {
            SchemaUtil.checkSize((int)numberOfTreesPerIteration, (CategoricalLabel)categoricalLabel);
            if (!(loss instanceof CategoricalCrossEntropy)) {
                throw new IllegalArgumentException();
            }
            ArrayList<MiningModel> miningModels = new ArrayList<MiningModel>();
            int columns = categoricalLabel.size();
            for (int i = 0; i < columns; ++i) {
                MiningModel miningModel = HistGradientBoostingUtil.encodeHistGradientBoosting(predictors, baselinePredictions, i, segmentSchema).setOutput(ModelUtil.createPredictedOutput((FieldName)FieldNameUtil.create((String)"decisionFunction", (Object[])new Object[]{categoricalLabel.getValue(i)}), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[0]));
                miningModels.add(miningModel);
            }
            return MiningModelUtil.createClassification(miningModels, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.SOFTMAX, (boolean)true, (Schema)schema);
        }
        throw new IllegalArgumentException();
    }

    public List<? extends Number> getBaselinePrediction() {
        return this.getNumberArray("_baseline_prediction");
    }

    public BaseLoss getLoss() {
        return (BaseLoss)((Object)this.get("loss_", BaseLoss.class));
    }

    public Integer getNumberOfTreesPerIteration() {
        return this.getInteger("n_trees_per_iteration_");
    }

    public List<List<TreePredictor>> getPredictors() {
        return this.getList("_predictors", List.class);
    }
}

