/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.xgboost;

import java.util.Collections;
import java.util.List;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel;
import org.apache.spark.ml.PredictionModel;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.PredictionModelConverter;
import org.jpmml.sparkml.SparkMLEncoder;
import org.jpmml.sparkml.xgboost.BoosterUtil;

public class XGBoostRegressionModelConverter
extends PredictionModelConverter<XGBoostRegressionModel> {
    public XGBoostRegressionModelConverter(XGBoostRegressionModel model) {
        super((PredictionModel)model);
    }

    public MiningFunction getMiningFunction() {
        return MiningFunction.REGRESSION;
    }

    public MiningModel encodeModel(Schema schema) {
        XGBoostRegressionModel model = (XGBoostRegressionModel)this.getModel();
        Booster booster = model.nativeBooster();
        return BoosterUtil.encodeBooster(this, booster, schema);
    }

    public List<OutputField> registerOutputFields(Label label, Model pmmlModel, SparkMLEncoder encoder) {
        XGBoostRegressionModel model = (XGBoostRegressionModel)this.getModel();
        ScalarLabel scalarLabel = (ScalarLabel)label;
        String predictionCol = model.getPredictionCol();
        Boolean keepPredictionCol = (Boolean)this.getOption("keep_predictionCol", Boolean.TRUE);
        OutputField predictedOutputField = ModelUtil.createPredictedField((String)predictionCol, (OpType)OpType.CONTINUOUS, (DataType)scalarLabel.getDataType());
        DerivedOutputField predictedField = encoder.createDerivedField(pmmlModel, predictedOutputField, keepPredictionCol.booleanValue());
        encoder.putOnlyFeature(predictionCol, (Feature)new ContinuousFeature((PMMLEncoder)encoder, (Field)predictedField));
        return Collections.emptyList();
    }
}

