/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import java.util.Collections;
import java.util.List;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.MapValues;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.Label;
import org.jpmml.converter.LabelUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.sparkml.PredictionModelConverter;
import org.jpmml.sparkml.SparkMLEncoder;

/**
 * Base converter for Spark ML {@link ClassificationModel} subclasses whose features are a
 * {@link Vector}.
 *
 * <p>Declares the {@link MiningFunction#CLASSIFICATION} mining function, checks that the schema
 * label agrees with the model's class count, and re-exposes the model's prediction column as an
 * index-valued feature for downstream converters.
 *
 * @param <T> the concrete Spark ML classification model type
 */
public abstract class ClassificationModelConverter<T extends ClassificationModel<Vector, T>>
extends PredictionModelConverter<T> {

    public ClassificationModelConverter(T model) {
        super(model);
    }

    /**
     * Returns the number of classes reported by the wrapped Spark ML model.
     *
     * @return the value of {@code numClasses()} on the wrapped model
     */
    public int getNumberOfClasses() {
        // T is already bounded by ClassificationModel, so no raw-type cast is required.
        T model = this.getModel();
        return model.numClasses();
    }

    /**
     * @return {@link MiningFunction#CLASSIFICATION}
     */
    @Override
    public MiningFunction getMiningFunction() {
        return MiningFunction.CLASSIFICATION;
    }

    /**
     * Runs the superclass schema checks, then requires the schema label to be a
     * {@link CategoricalLabel} whose size is checked against {@link #getNumberOfClasses()} via
     * {@code SchemaUtil.checkSize} (presumably rejecting a mismatch — see that helper).
     *
     * @param schema the schema to validate
     * @throws ClassCastException if the schema label is not a {@link CategoricalLabel}
     */
    @Override
    public void checkSchema(Schema schema) {
        super.checkSchema(schema);
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
        int numberOfClasses = this.getNumberOfClasses();
        // The explicit DiscreteLabel view keeps the same SchemaUtil.checkSize overload the
        // original code resolved to.
        SchemaUtil.checkSize(numberOfClasses, (DiscreteLabel) categoricalLabel);
    }

    /**
     * Registers the prediction column of the model as PMML output fields.
     *
     * <p>Two derived output fields are created:
     * <ol>
     *   <li>a non-final categorical predicted field, named via
     *       {@code FieldNameUtil.create("pmml", predictionCol)}, carrying the label's own data
     *       type;</li>
     *   <li>a continuous {@code DOUBLE} transformed-value field named {@code predictionCol}, which
     *       maps every label value to the matching entry of
     *       {@code LabelUtil.createTargetCategories(size)} (presumably the indices 0..size-1).</li>
     * </ol>
     * The second field is then registered with the encoder as an {@link IndexFeature} under
     * {@code predictionCol}.
     *
     * <p>The {@code "keep_predictionCol"} option (default {@code true}) is forwarded to
     * {@code encoder.createDerivedField}; its exact effect is defined by the encoder.
     *
     * @param label the model label; must be a {@link CategoricalLabel}
     * @param pmmlModel the PMML model the output fields are attached to
     * @param encoder the encoder receiving the derived fields and the prediction feature
     * @return always an empty list; all fields are registered through {@code encoder}
     * @throws ClassCastException if {@code label} is not a {@link CategoricalLabel}
     */
    @Override
    public List<OutputField> registerOutputFields(Label label, Model pmmlModel, SparkMLEncoder encoder) {
        T model = this.getModel();
        CategoricalLabel categoricalLabel = (CategoricalLabel) label;
        List<Integer> categories = LabelUtil.createTargetCategories(categoricalLabel.size());
        String predictionCol = model.getPredictionCol();
        Boolean keepPredictionCol = (Boolean) this.getOption("keep_predictionCol", Boolean.TRUE);

        // Intermediate (non-final) field holding the predicted label in its native data type.
        OutputField pmmlPredictedOutputField = ModelUtil.createPredictedField(
                FieldNameUtil.create("pmml", predictionCol), OpType.CATEGORICAL, categoricalLabel.getDataType())
            .setFinalResult(false);
        DerivedOutputField pmmlPredictedField = encoder.createDerivedField(pmmlModel, pmmlPredictedOutputField, keepPredictionCol);

        // Re-encode the predicted label as a double-valued category index under predictionCol.
        MapValues mapValues = PMMLUtil.createMapValues(pmmlPredictedField.getName(), categoricalLabel.getValues(), categories)
            .setDataType(DataType.DOUBLE);
        OutputField predictedOutputField = new OutputField(predictionCol, OpType.CONTINUOUS, DataType.DOUBLE)
            .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
            .setExpression(mapValues);
        DerivedOutputField predictedField = encoder.createDerivedField(pmmlModel, predictedOutputField, keepPredictionCol);

        encoder.putOnlyFeature(predictionCol, new IndexFeature(encoder, predictedField, categories));
        return Collections.emptyList();
    }
}

