/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import java.util.List;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.tree.TreeEnsembleModel;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FeatureType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.Model;
import org.dmg.pmml.MultipleModelMethodType;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.Segmentation;
import org.dmg.pmml.TreeModel;
import org.jpmml.converter.MiningModelUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.model.TreeModelUtil;

/**
 * Converts a Spark ML {@link GBTClassificationModel} into a PMML model.
 *
 * <p>The trees of the ensemble are encoded against an anonymous copy of the input schema. Each
 * tree's predictions are scaled by the matching entry of {@code treeWeights()}, and the scaled
 * trees are combined by a {@code SUM} segmentation inside a regression-type {@link MiningModel}.
 * That mining model is finally passed to
 * {@link MiningModelUtil#createBinaryLogisticClassification} to obtain the returned classifier.
 */
public class GBTClassificationModelConverter extends ModelConverter<GBTClassificationModel> {

    /**
     * @param model the Spark GBT classification model to convert
     */
    public GBTClassificationModelConverter(GBTClassificationModel model) {
        super(model);
    }

    /**
     * Encodes the wrapped GBT classification model as PMML.
     *
     * <p>The intermediate regression model exposes two output fields: {@code gbtValue} (the
     * predicted value of the summed trees) and {@code binarizedGbtValue} (derived from it).
     *
     * @param schema the schema of the model's label and features
     * @return the model produced by {@code createBinaryLogisticClassification(schema, ..., 1000.0, false)}
     */
    public MiningModel encodeModel(Schema schema) {
        GBTClassificationModel gbt = getTransformer();

        // The tree segments are encoded against an anonymous schema; the outer
        // classification wrapper below receives the original schema.
        Schema anonymousSchema = schema.toAnonymousSchema();

        List<TreeModel> trees = TreeModelUtil.encodeDecisionTreeEnsemble(gbt, anonymousSchema);
        applyTreeWeights(trees, gbt.treeWeights());

        OutputField marginField = ModelUtil.createPredictedField(FieldName.create("gbtValue"));
        OutputField binarizedField = createBinarizedField(marginField.getName());

        Output output = new Output()
            .addOutputFields(marginField, binarizedField);

        Segmentation sumOfTrees = MiningModelUtil.createSegmentation(MultipleModelMethodType.SUM, trees);

        MiningModel ensemble = new MiningModel(MiningFunctionType.REGRESSION, ModelUtil.createMiningSchema(anonymousSchema))
            .setSegmentation(sumOfTrees)
            .setOutput(output);

        // NOTE(review): the meaning of the 1000.0 coefficient and the boolean flag is defined by
        // MiningModelUtil.createBinaryLogisticClassification, which is not visible here - confirm there.
        return MiningModelUtil.createBinaryLogisticClassification(schema, ensemble, 1000d, false);
    }

    /**
     * Scales the predictions of the i-th tree by the i-th weight, for every weight supplied.
     * Assumes {@code trees} holds at least as many entries as {@code treeWeights}.
     */
    private static void applyTreeWeights(List<TreeModel> trees, double[] treeWeights) {
        int position = 0;
        for (double weight : treeWeights) {
            TreeModelUtil.scalePredictions(trees.get(position++), weight);
        }
    }

    /**
     * Builds the {@code binarizedGbtValue} output field: a continuous double that evaluates to
     * {@code -1.0} when the referenced field is greater than {@code 0.0}, and {@code 1.0} otherwise.
     */
    private static OutputField createBinarizedField(FieldName marginName) {
        Expression isPositive = PMMLUtil.createApply("greaterThan", new FieldRef(marginName), PMMLUtil.createConstant(0d));
        Expression signFlip = PMMLUtil.createApply("if", isPositive, PMMLUtil.createConstant(-1d), PMMLUtil.createConstant(1d));

        return new OutputField(FieldName.create("binarizedGbtValue"))
            .setFeature(FeatureType.TRANSFORMED_VALUE)
            .setDataType(DataType.DOUBLE)
            .setOpType(OpType.CONTINUOUS)
            .setExpression(signFlip);
    }
}

