/*
 * Decompiled with CFR 0.152.
 */
package sklearn.ensemble.gradient_boosting;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.MatrixUtil;
import sklearn.Classifier;
import sklearn.ensemble.gradient_boosting.GradientBoostingUtil;
import sklearn.ensemble.gradient_boosting.HasPriorProbability;
import sklearn.ensemble.gradient_boosting.LossFunction;
import sklearn.tree.DecisionTreeRegressor;

public class GradientBoostingClassifier
extends Classifier {
    public GradientBoostingClassifier(String module, String name) {
        super(module, name);
    }

    @Override
    public int getNumberOfFeatures() {
        return ValueUtil.asInt((Number)((Number)this.get("n_features")));
    }

    @Override
    public boolean requiresContinuousInput() {
        return false;
    }

    @Override
    public DataType getDataType() {
        return DataType.FLOAT;
    }

    public MiningModel encodeModel(Schema schema) {
        LossFunction loss = this.getLoss();
        int numberOfClasses = loss.getK();
        HasPriorProbability init = this.getInit();
        Number learningRate = this.getLearningRate();
        List<DecisionTreeRegressor> estimators = this.getEstimators();
        List targetCategories = schema.getTargetCategories();
        Schema segmentSchema = schema.toAnonymousSchema();
        if (numberOfClasses == 1) {
            if (targetCategories.size() != 2) {
                throw new IllegalArgumentException();
            }
            double coefficient = loss.getCoefficient();
            MiningModel miningModel = GradientBoostingClassifier.encodeCategoryRegressor((String)targetCategories.get(1), estimators, init.getPriorProbability(0), learningRate, null, segmentSchema);
            return MiningModelUtil.createBinaryLogisticClassification((Schema)schema, (Model)miningModel, (double)coefficient, (boolean)true);
        }
        if (numberOfClasses >= 2) {
            if (targetCategories.size() != numberOfClasses) {
                throw new IllegalArgumentException();
            }
            ArrayList<MiningModel> miningModels = new ArrayList<MiningModel>();
            for (int i = 0; i < targetCategories.size(); ++i) {
                MiningModel miningModel = GradientBoostingClassifier.encodeCategoryRegressor((String)targetCategories.get(i), MatrixUtil.getColumn(estimators, estimators.size() / numberOfClasses, numberOfClasses, i), init.getPriorProbability(i), learningRate, loss.getFunction(), segmentSchema);
                miningModels.add(miningModel);
            }
            return MiningModelUtil.createClassification((Schema)schema, miningModels, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.SIMPLEMAX, (boolean)true);
        }
        throw new IllegalArgumentException();
    }

    @Override
    public Set<DefineFunction> encodeDefineFunctions() {
        LossFunction loss = this.getLoss();
        DefineFunction defineFunction = loss.encodeFunction();
        if (defineFunction != null) {
            return Collections.singleton(defineFunction);
        }
        return super.encodeDefineFunctions();
    }

    public LossFunction getLoss() {
        Object loss = this.get("loss_");
        try {
            if (loss == null) {
                throw new NullPointerException();
            }
            return (LossFunction)((Object)loss);
        }
        catch (RuntimeException re) {
            throw new IllegalArgumentException("The loss function object (" + ClassDictUtil.formatClass(loss) + ") is not a LossFunction or is not a supported LossFunction subclass", re);
        }
    }

    public HasPriorProbability getInit() {
        Object init = this.get("init_");
        try {
            if (init == null) {
                throw new NullPointerException();
            }
            return (HasPriorProbability)init;
        }
        catch (RuntimeException re) {
            throw new IllegalArgumentException("The estimator object (" + ClassDictUtil.formatClass(init) + ") is not a BaseEstimator or is not a supported BaseEstimator subclass", re);
        }
    }

    public Number getLearningRate() {
        return (Number)this.get("learning_rate");
    }

    public List<DecisionTreeRegressor> getEstimators() {
        return ClassDictUtil.getArray(this, "estimators_");
    }

    private static MiningModel encodeCategoryRegressor(String targetCategory, List<DecisionTreeRegressor> estimators, Number priorProbability, Number learningRate, String outputTransformation, Schema schema) {
        OutputField decisionFunction = new OutputField(FieldName.create((String)("decisionFunction_" + targetCategory)), DataType.DOUBLE).setOpType(OpType.CONTINUOUS).setResultFeature(ResultFeature.PREDICTED_VALUE).setFinalResult(Boolean.valueOf(false));
        Output output = new Output().addOutputFields(new OutputField[]{decisionFunction});
        if (outputTransformation != null) {
            OutputField transformedDecisionField = new OutputField(FieldName.create((String)(outputTransformation + "DecisionFunction_" + targetCategory)), DataType.DOUBLE).setOpType(OpType.CONTINUOUS).setResultFeature(ResultFeature.TRANSFORMED_VALUE).setFinalResult(Boolean.valueOf(false)).setExpression((Expression)PMMLUtil.createApply((String)outputTransformation, (Expression[])new Expression[]{new FieldRef(decisionFunction.getName())}));
            output.addOutputFields(new OutputField[]{transformedDecisionField});
        }
        MiningModel miningModel = GradientBoostingUtil.encodeGradientBoosting(estimators, priorProbability, learningRate, schema).setOutput(output);
        return miningModel;
    }
}

