/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import javax.xml.parsers.DocumentBuilder;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasProbabilityCol;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldColumnPair;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.MapValues;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Row;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DOMUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.SparkMLEncoder;

/**
 * Base converter for Spark ML classification models.
 *
 * <p>Reports {@link MiningFunction#CLASSIFICATION} as the mining function and encodes the PMML
 * {@link Output} for the model's prediction column. When the model implements
 * {@link HasProbabilityCol}, it also encodes one probability field per label value.
 *
 * @param <T> the concrete Spark prediction model type
 */
public abstract class ClassificationModelConverter<T extends PredictionModel<Vector, T> & HasFeaturesCol>
extends ModelConverter<T> {

    public ClassificationModelConverter(T model) {
        super(model);
    }

    @Override
    public MiningFunction getMiningFunction() {
        return MiningFunction.CLASSIFICATION;
    }

    /**
     * Builds the PMML output for this classifier and registers the derived features with the
     * encoder, under the Spark column names.
     *
     * <p>The returned output contains, in this order:
     * <ol>
     *   <li>{@code pmml(<predictionCol>)}: the predicted category, as a STRING field;</li>
     *   <li>{@code <predictionCol>}: that category mapped through a lookup table to its index in
     *       {@code label}, as a DOUBLE field;</li>
     *   <li>{@code <probabilityCol>(<value>)} for every label value, only when the model
     *       implements {@link HasProbabilityCol}.</li>
     * </ol>
     *
     * @param label the target label; must be a {@link CategoricalLabel}
     * @param encoder receives the features registered for the prediction column and, if
     *        present, the probability column
     * @return the assembled output
     * @throws ClassCastException if {@code label} is not a {@link CategoricalLabel}
     */
    @Override
    public Output encodeOutput(Label label, SparkMLEncoder encoder) {
        PredictionModel<?, ?> model = (PredictionModel<?, ?>) this.getTransformer();
        CategoricalLabel categoricalLabel = (CategoricalLabel) label;

        Output output = new Output();

        encodePrediction(model.getPredictionCol(), categoricalLabel, encoder, output);

        if (model instanceof HasProbabilityCol) {
            String probabilityCol = ((HasProbabilityCol) model).getProbabilityCol();

            encodeProbabilities(probabilityCol, categoricalLabel, encoder, output);
        }

        return output;
    }

    /**
     * Appends the predicted-category field and the derived predicted-index field to
     * {@code output}. Registers a single categorical feature for {@code predictionCol}, whose
     * categories are the index strings "0", "1", ...
     */
    private static void encodePrediction(String predictionCol, CategoricalLabel categoricalLabel, SparkMLEncoder encoder, Output output) {
        OutputField pmmlPredictedField = ModelUtil.createPredictedField(FieldName.create("pmml(" + predictionCol + ")"), DataType.STRING, OpType.CATEGORICAL);

        int size = categoricalLabel.size();

        // Lookup table with one row per label value: "input" = label value, "output" = its index.
        List<String> categories = new ArrayList<>(size);
        DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();
        InlineTable inlineTable = new InlineTable();
        List<String> columns = Arrays.asList("input", "output");

        for (int i = 0; i < size; ++i) {
            String value = categoricalLabel.getValue(i);
            String category = String.valueOf(i);

            categories.add(category);
            inlineTable.addRows(DOMUtil.createRow(documentBuilder, columns, Arrays.asList(value, category)));
        }

        MapValues mapValues = new MapValues()
            .addFieldColumnPairs(new FieldColumnPair(pmmlPredictedField.getName(), columns.get(0)))
            .setOutputColumn(columns.get(1))
            .setInlineTable(inlineTable);

        // NOTE(review): declared DOUBLE although the mapped values are index strings; presumably
        // this mirrors Spark's double-valued prediction column. Confirm before changing.
        OutputField predictedField = new OutputField(FieldName.create(predictionCol), DataType.DOUBLE)
            .setOpType(OpType.CATEGORICAL)
            .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
            .setExpression(mapValues);

        output.addOutputFields(pmmlPredictedField, predictedField);

        // Created in a static context, so the anonymous subclass does not capture the converter.
        Feature feature = new CategoricalFeature(encoder, predictedField.getName(), predictedField.getDataType(), categories) {

            // Lets consumers that need a continuous input read the predicted index field directly.
            @Override
            public ContinuousFeature toContinuousFeature() {
                PMMLEncoder featureEncoder = this.ensureEncoder();

                return new ContinuousFeature(featureEncoder, this.getName(), this.getDataType());
            }
        };

        encoder.putFeatures(predictionCol, Collections.singletonList(feature));
    }

    /**
     * Appends one probability output field per label value to {@code output}. Each field is
     * named {@code <probabilityCol>(<value>)}. The matching continuous features are registered,
     * in label order, under {@code probabilityCol}.
     */
    private static void encodeProbabilities(String probabilityCol, CategoricalLabel categoricalLabel, SparkMLEncoder encoder, Output output) {
        int size = categoricalLabel.size();
        List<Feature> features = new ArrayList<>(size);

        for (int i = 0; i < size; ++i) {
            String value = categoricalLabel.getValue(i);

            OutputField probabilityField = ModelUtil.createProbabilityField(FieldName.create(probabilityCol + "(" + value + ")"), value);
            output.addOutputFields(probabilityField);

            features.add(new ContinuousFeature(encoder, probabilityField.getName(), probabilityField.getDataType()));
        }

        encoder.putFeatures(probabilityCol, features);
    }
}

