/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import com.google.common.primitives.Doubles;
import java.util.List;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.sparkml.ClassificationModelConverter;
import org.jpmml.sparkml.model.HasFeatureImportances;
import org.jpmml.sparkml.model.HasTreeOptions;
import org.jpmml.sparkml.model.TreeModelUtil;

public class GBTClassificationModelConverter
extends ClassificationModelConverter<GBTClassificationModel>
implements HasFeatureImportances,
HasTreeOptions {
    public GBTClassificationModelConverter(GBTClassificationModel model) {
        super(model);
    }

    @Override
    public Vector getFeatureImportances() {
        GBTClassificationModel model = (GBTClassificationModel)this.getTransformer();
        return model.featureImportances();
    }

    public MiningModel encodeModel(Schema schema) {
        String lossType;
        GBTClassificationModel model = (GBTClassificationModel)this.getTransformer();
        switch (lossType = model.getLossType()) {
            case "logistic": {
                break;
            }
            default: {
                throw new IllegalArgumentException("Loss function " + lossType + " is not supported");
            }
        }
        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);
        List<TreeModel> treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(this, segmentSchema);
        MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label)segmentSchema.getLabel())).setSegmentation(MiningModelUtil.createSegmentation((Segmentation.MultipleModelMethod)Segmentation.MultipleModelMethod.WEIGHTED_SUM, treeModels, (List)Doubles.asList((double[])model.treeWeights()))).setOutput(ModelUtil.createPredictedOutput((String)"gbtValue", (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[0]));
        return MiningModelUtil.createBinaryLogisticClassification((Model)miningModel, (double)2.0, (double)0.0, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.LOGIT, (boolean)false, (Schema)schema);
    }
}

