/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasProbabilityCol;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MapValues;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.Label;
import org.jpmml.converter.LabelUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.SparkMLEncoder;

/**
 * Base converter for Spark ML classification models.
 *
 * <p>Handles a {@link PredictionModel} that operates on a {@link Vector} features column. It
 * reports the {@link MiningFunction#CLASSIFICATION} mining function. It also registers the PMML
 * output fields that stand in for Spark's prediction column and, when the model has one, its
 * probability column. Building the PMML model itself is left to subclasses.
 *
 * @param <T> the Spark ML prediction model type being converted
 */
public abstract class ClassificationModelConverter<T extends PredictionModel<Vector, T> & HasFeaturesCol>
extends ModelConverter<T> {

    /**
     * @param model the Spark ML model to convert
     */
    public ClassificationModelConverter(T model) {
        super(model);
    }

    /**
     * @return always {@link MiningFunction#CLASSIFICATION}
     */
    @Override
    public MiningFunction getMiningFunction() {
        return MiningFunction.CLASSIFICATION;
    }

    /**
     * Registers the output fields that reproduce Spark's prediction (and probability) columns.
     *
     * <p>The prediction column is emulated in two steps:
     * <ol>
     *   <li>A non-final, categorical predicted field named {@code pmml(<predictionCol>)}, typed
     *       with the label's data type.</li>
     *   <li>A continuous {@code double} field named {@code <predictionCol>}. It maps each label
     *       value to the matching entry of {@link LabelUtil#createTargetCategories(int)}.
     *       Presumably these entries are the 0-based category indices that Spark predicts;
     *       TODO confirm.</li>
     * </ol>
     * Both fields are created through {@code encoder.createDerivedField}, which is passed the
     * {@code keep_predictionCol} option (default {@code true}). An {@link IndexFeature} over the
     * second field is then registered under the prediction column name.
     *
     * <p>If the model also implements {@link HasProbabilityCol}, one probability output field per
     * label value is created, named {@code <probabilityCol>(<value>)}. These fields are registered
     * as {@link ContinuousFeature}s under the probability column name.
     *
     * @param label the target label; must be a {@link CategoricalLabel}
     * @param pmmlModel the PMML model passed to {@code createDerivedField} for the prediction fields
     * @param encoder the encoder that receives the derived fields and features
     * @return a mutable list of the probability output fields, in label-value order. It is empty
     *     when the model has no probability column. The prediction fields are not included.
     * @throws ClassCastException if {@code label} is not a {@link CategoricalLabel}
     */
    @Override
    public List<OutputField> registerOutputFields(Label label, Model pmmlModel, SparkMLEncoder encoder) {
        T model = getTransformer();

        CategoricalLabel categoricalLabel = (CategoricalLabel)label;
        List<Integer> categories = LabelUtil.createTargetCategories(categoricalLabel.size());

        String predictionCol = model.getPredictionCol();
        Boolean keepPredictionCol = (Boolean)getOption("keep_predictionCol", Boolean.TRUE);

        // Step 1: the raw PMML prediction, expressed in the label's own values (not a final result).
        OutputField pmmlPredictedOutputField = ModelUtil.createPredictedField(FieldName.create("pmml(" + predictionCol + ")"), OpType.CATEGORICAL, categoricalLabel.getDataType())
            .setFinalResult(false);
        DerivedOutputField pmmlPredictedField = encoder.createDerivedField(pmmlModel, pmmlPredictedOutputField, keepPredictionCol);

        // Step 2: map label values to the numeric (double) categories that Spark's prediction column holds.
        MapValues mapValues = PMMLUtil.createMapValues(pmmlPredictedField.getName(), categoricalLabel.getValues(), categories)
            .setDataType(DataType.DOUBLE);
        OutputField predictedOutputField = new OutputField(FieldName.create(predictionCol), OpType.CONTINUOUS, DataType.DOUBLE)
            .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
            .setExpression(mapValues);
        DerivedOutputField predictedField = encoder.createDerivedField(pmmlModel, predictedOutputField, keepPredictionCol);

        encoder.putOnlyFeature(predictionCol, new IndexFeature(encoder, predictedField, categories));

        if (model instanceof HasProbabilityCol) {
            String probabilityCol = ((HasProbabilityCol)model).getProbabilityCol();

            return registerProbabilityFields(categoricalLabel, probabilityCol, encoder);
        }

        return new ArrayList<>();
    }

    /**
     * Creates one probability output field per label value and registers them as continuous features.
     *
     * @param categoricalLabel the target label whose values are enumerated, in index order
     * @param probabilityCol the Spark probability column name; used as the field-name prefix and feature key
     * @param encoder the encoder that receives the features via {@code putFeatures}
     * @return a mutable list of the created probability output fields, in label-value order
     */
    private static List<OutputField> registerProbabilityFields(CategoricalLabel categoricalLabel, String probabilityCol, SparkMLEncoder encoder) {
        int size = categoricalLabel.size();

        List<OutputField> result = new ArrayList<>(size);
        List<Feature> features = new ArrayList<>(size);

        for (int i = 0; i < size; i++) {
            Object value = categoricalLabel.getValue(i);

            OutputField probabilityField = ModelUtil.createProbabilityField(FieldName.create(probabilityCol + "(" + value + ")"), DataType.DOUBLE, value);

            result.add(probabilityField);
            features.add(new ContinuousFeature(encoder, probabilityField));
        }

        encoder.putFeatures(probabilityCol, features);

        return result;
    }
}

