/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml.ensemble;

import com.google.common.collect.Iterables;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import sklearn.Classifier;
import sklearn.linear_model.LinearClassifier;
import sklearn.preprocessing.MultiOneHotEncoder;
import sklearn2pmml.ensemble.GBDTUtil;

public class GBDTLRClassifier
extends Classifier {
    public GBDTLRClassifier(String module, String name) {
        super(module, name);
    }

    @Override
    public List<?> getClasses() {
        Classifier gbdt = this.getGBDT();
        return gbdt.getClasses();
    }

    @Override
    public boolean hasProbabilityDistribution() {
        LinearClassifier lr = this.getLR();
        return lr.hasProbabilityDistribution();
    }

    public MiningModel encodeModel(Schema schema) {
        Classifier gbdt = this.getGBDT();
        MultiOneHotEncoder ohe = this.getOHE();
        LinearClassifier lr = this.getLR();
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        SchemaUtil.checkSize((int)2, (DiscreteLabel)categoricalLabel);
        List<? extends Number> coef = lr.getCoef();
        List<? extends Number> intercept = lr.getIntercept();
        Schema segmentSchema = schema.toAnonymousSchema();
        MiningModel model = GBDTUtil.encodeModel(gbdt, ohe, coef, (Number)Iterables.getOnlyElement(intercept), segmentSchema).setOutput(ModelUtil.createPredictedOutput((String)"decisionFunction", (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[0]));
        MiningModel miningModel = MiningModelUtil.createBinaryLogisticClassification((Model)model, (double)1.0, (double)0.0, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.LOGIT, (boolean)false, (Schema)schema);
        if (lr.hasProbabilityDistribution()) {
            this.encodePredictProbaOutput((Model)miningModel, DataType.DOUBLE, categoricalLabel);
        }
        return miningModel;
    }

    public Classifier getGBDT() {
        return (Classifier)this.get("gbdt_", Classifier.class);
    }

    public LinearClassifier getLR() {
        return (LinearClassifier)this.get("lr_", LinearClassifier.class);
    }

    public MultiOneHotEncoder getOHE() {
        return (MultiOneHotEncoder)this.get("ohe_", MultiOneHotEncoder.class);
    }
}

