/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import com.google.common.primitives.Doubles;
import java.util.List;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.AbstractTransformation;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.sparkml.ClassificationModelConverter;
import org.jpmml.sparkml.model.TreeModelUtil;

public class GBTClassificationModelConverter
extends ClassificationModelConverter<GBTClassificationModel> {
    public GBTClassificationModelConverter(GBTClassificationModel model) {
        super(model);
    }

    public MiningModel encodeModel(Schema schema) {
        GBTClassificationModel model = (GBTClassificationModel)this.getTransformer();
        Schema segmentSchema = new Schema((Label)new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());
        List<TreeModel> treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(model, segmentSchema);
        AbstractTransformation binarizedGbtValue = new AbstractTransformation(){

            public FieldName getName(FieldName name) {
                return 1.withPrefix((FieldName)name, (String)"binarized");
            }

            public Expression createExpression(FieldRef fieldRef) {
                return PMMLUtil.createApply((String)"if", (Expression[])new Expression[]{PMMLUtil.createApply((String)"greaterThan", (Expression[])new Expression[]{fieldRef, PMMLUtil.createConstant((Object)0.0)}), PMMLUtil.createConstant((Object)-1.0), PMMLUtil.createConstant((Object)1.0)});
            }
        };
        MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Schema)segmentSchema)).setSegmentation(MiningModelUtil.createSegmentation((Segmentation.MultipleModelMethod)Segmentation.MultipleModelMethod.WEIGHTED_SUM, treeModels, (List)Doubles.asList((double[])model.treeWeights()))).setOutput(ModelUtil.createPredictedOutput((FieldName)FieldName.create((String)"gbtValue"), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[]{binarizedGbtValue}));
        return MiningModelUtil.createBinaryLogisticClassification((Model)miningModel, (double)-1000.0, (double)0.0, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.LOGIT, (boolean)false, (Schema)schema);
    }
}

