/*
 * Decompiled with CFR 0.152.
 */
package sklearn.ensemble.gradient_boosting;

import java.util.ArrayList;
import java.util.List;
import java.util.function.IntFunction;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import sklearn.Estimator;
import sklearn.HasEstimatorEnsemble;
import sklearn.HasPriorProbability;
import sklearn.SkLearnClassifier;
import sklearn.VersionUtil;
import sklearn.ensemble.gradient_boosting.GradientBoostingUtil;
import sklearn.ensemble.gradient_boosting.LossFunction;
import sklearn.tree.HasTreeOptions;
import sklearn.tree.TreeRegressor;
import sklearn2pmml.EstimatorProxy;

/**
 * Converter for scikit-learn's {@code GradientBoostingClassifier}.
 *
 * <p>Each decision function of the fitted ensemble is encoded as a boosted regression
 * {@link MiningModel}. These are then combined into a classification model: a binary
 * logistic classification when the loss reports {@code K == 1}, or a SIMPLEMAX-normalized
 * classification with one decision function per class when {@code K >= 3}.
 */
public class GradientBoostingClassifier
extends SkLearnClassifier
implements HasEstimatorEnsemble<TreeRegressor>,
HasTreeOptions {
    public GradientBoostingClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the {@code n_features} attribute when the pickled estimator defines it;
     * otherwise defers to the superclass lookup.
     */
    @Override
    public int getNumberOfFeatures() {
        if (this.containsKey("n_features")) {
            return this.getInteger("n_features");
        }
        return super.getNumberOfFeatures();
    }

    /**
     * Always reports {@link DataType#FLOAT} as the input data type.
     */
    // NOTE(review): presumably because scikit-learn trees compare features as float32 — confirm.
    @Override
    public DataType getDataType() {
        return DataType.FLOAT;
    }

    /**
     * Encodes the fitted ensemble as a PMML classification {@link MiningModel}.
     *
     * @param schema the classifier schema; its label must be a {@link CategoricalLabel}
     * @return the classification model, with predict-proba output fields attached
     * @throws IllegalArgumentException if the loss reports neither 1 nor at least 3 outputs,
     *         or if the number of estimators is not a multiple of the number of classes
     */
    public MiningModel encodeModel(Schema schema) {
        String sklearnVersion = this.getSkLearnVersion();
        LossFunction loss = this.getLoss();
        int numberOfClasses = loss.getK();
        HasPriorProbability init = this.getInit();
        Number learningRate = this.getLearningRate();

        // By default the initial raw prediction for column i is read straight off the init estimator.
        IntFunction<Number> initialPredictions = init::getPriorProbability;
        // From scikit-learn 0.21 on, the loss function derives the initial raw predictions from it instead.
        if (sklearnVersion != null && VersionUtil.compareVersion(sklearnVersion, "0.21") >= 0) {
            List<? extends Number> computedInitialPredictions = loss.computeInitialPredictions(init);
            initialPredictions = computedInitialPredictions::get;
        }

        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        MiningModel miningModel;
        if (numberOfClasses == 1) {
            SchemaUtil.checkSize(2, categoricalLabel);

            // K == 1 means a single column of trees and a single initial raw prediction,
            // so index 0 is used (the multi-class branch likewise pairs column i with index i).
            // The decision function is named after the second class value.
            MiningModel model = GradientBoostingUtil.encodeGradientBoosting(this, initialPredictions.apply(0), learningRate, segmentSchema)
                .setOutput(ModelUtil.createPredictedOutput(FieldNameUtil.create("decisionFunction", categoricalLabel.getValue(1)), OpType.CONTINUOUS, DataType.DOUBLE, loss.createTransformation()));

            miningModel = MiningModelUtil.createBinaryLogisticClassification(model, 1.0, 0.0, RegressionModel.NormalizationMethod.NONE, false, schema);
        } else if (numberOfClasses >= 3) {
            SchemaUtil.checkSize(numberOfClasses, categoricalLabel);

            List<? extends TreeRegressor> estimators = this.getEstimators();
            int columns = categoricalLabel.size();
            // The flat estimator list is read as a rows x columns matrix, one column per class.
            if (estimators.size() % columns != 0) {
                throw new IllegalArgumentException("Expected the number of estimators (" + estimators.size() + ") to be a multiple of the number of classes (" + columns + ")");
            }
            int rows = estimators.size() / columns;

            List<MiningModel> models = new ArrayList<MiningModel>(columns);
            for (int i = 0; i < columns; ++i) {
                final List<? extends TreeRegressor> columnEstimators = CMatrixUtil.getColumn(estimators, rows, columns, i);

                // Same estimator settings as this classifier, but restricted to column i's trees.
                GradientBoostingClassifierProxy estimatorProxy = new GradientBoostingClassifierProxy() {

                    @Override
                    public List<? extends TreeRegressor> getEstimators() {
                        return columnEstimators;
                    }
                };

                MiningModel model = GradientBoostingUtil.encodeGradientBoosting(estimatorProxy, initialPredictions.apply(i), learningRate, segmentSchema)
                    .setOutput(ModelUtil.createPredictedOutput(FieldNameUtil.create("decisionFunction", categoricalLabel.getValue(i)), OpType.CONTINUOUS, DataType.DOUBLE, loss.createTransformation()));

                models.add(model);
            }

            miningModel = MiningModelUtil.createClassification(models, RegressionModel.NormalizationMethod.SIMPLEMAX, false, schema);
        } else {
            throw new IllegalArgumentException("Expected the loss function to report 1 (binary) or at least 3 (multi-class) outputs, got " + numberOfClasses);
        }

        this.encodePredictProbaOutput(miningModel, DataType.DOUBLE, categoricalLabel);
        return miningModel;
    }

    /**
     * Returns the fitted loss function.
     *
     * <p>Reads the {@code loss_} attribute when present, otherwise {@code _loss}.
     * NOTE(review): presumably newer scikit-learn pickles store it as {@code _loss} — confirm.
     */
    public LossFunction getLoss() {
        if (this.containsKey("loss_")) {
            return (LossFunction)this.get("loss_", LossFunction.class);
        }
        return (LossFunction)this.get("_loss", LossFunction.class);
    }

    /**
     * Returns the fitted initial estimator ({@code init_}).
     */
    public HasPriorProbability getInit() {
        return (HasPriorProbability)this.get("init_", HasPriorProbability.class);
    }

    /**
     * Returns the shrinkage factor ({@code learning_rate}).
     */
    public Number getLearningRate() {
        return this.getNumber("learning_rate");
    }

    /**
     * Returns all fitted trees ({@code estimators_}) as one flat list.
     */
    @Override
    public List<? extends TreeRegressor> getEstimators() {
        return this.getArray("estimators_", TreeRegressor.class);
    }

    /**
     * Proxy that delegates to the enclosing classifier.
     *
     * <p>Subclasses override {@link #getEstimators()} to expose a subset of the trees,
     * namely one class's column.
     */
    private abstract class GradientBoostingClassifierProxy
    extends EstimatorProxy
    implements HasEstimatorEnsemble<TreeRegressor>,
    HasTreeOptions {
        private GradientBoostingClassifierProxy() {
        }

        @Override
        public Estimator getEstimator() {
            return GradientBoostingClassifier.this;
        }
    }
}

