/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.TypeDefinitionField;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.SparkMLEncoder;
import org.jpmml.sparkml.TransformerConverter;

/**
 * Base class for converters that translate a fitted Spark ML {@link Model} into a PMML model.
 *
 * <p>Subclasses supply the mining function ({@link #getMiningFunction()}) and the actual
 * model encoding ({@link #encodeModel(Schema)}). This class builds the {@link Schema}
 * (target label plus input features) from the {@link SparkMLEncoder}, and splices any
 * converter-specific output fields into the resulting PMML model.
 *
 * @param <T> the Spark ML model type; it must expose a prediction column
 */
public abstract class ModelConverter<T extends Model<T> & HasPredictionCol>
extends TransformerConverter<T> {
    public ModelConverter(T model) {
        super(model);
    }

    /**
     * Returns the PMML mining function this converter produces.
     * {@link #encodeSchema(SparkMLEncoder)} supports only
     * {@code CLASSIFICATION} and {@code REGRESSION}.
     */
    public abstract MiningFunction getMiningFunction();

    /**
     * Encodes the wrapped Spark model as a PMML model.
     *
     * @param schema the label and features produced by {@link #encodeSchema(SparkMLEncoder)}
     * @return the encoded PMML model
     */
    public abstract org.dmg.pmml.Model encodeModel(Schema schema);

    /**
     * Builds the PMML schema (target label and input features) for the wrapped model.
     *
     * <p>The label is resolved only when the model has a label column ({@link HasLabelCol});
     * otherwise it stays {@code null}. This method also checks two counts against the fitted
     * model: the number of target categories for classification models, and the number of
     * features for prediction models.
     *
     * @param encoder the encoder holding the features registered so far
     * @return the schema for {@link #encodeModel(Schema)}
     * @throws IllegalArgumentException if the label cannot be encoded, the mining function is
     *         unsupported, or the category/feature counts disagree with the fitted model
     */
    public Schema encodeSchema(SparkMLEncoder encoder) {
        T model = this.getTransformer();

        // Declared as the common supertype: the regression path yields a ContinuousLabel.
        Label label = null;

        if (model instanceof HasLabelCol) {
            String labelCol = ((HasLabelCol)model).getLabelCol();

            label = encodeLabel(model, labelCol, encoder);
        }

        if (model instanceof ClassificationModel) {
            ClassificationModel<?, ?> classificationModel = (ClassificationModel<?, ?>)model;

            // Guard the cast: a classifier must carry a categorical target.
            if (!(label instanceof CategoricalLabel)) {
                throw new IllegalArgumentException("Expected a categorical label, got " + label);
            }

            CategoricalLabel categoricalLabel = (CategoricalLabel)label;

            int numClasses = classificationModel.numClasses();
            if (numClasses != categoricalLabel.size()) {
                throw new IllegalArgumentException("Expected " + numClasses + " target categories, got " + categoricalLabel.size() + " target categories");
            }
        }

        String featuresCol = ((HasFeaturesCol)model).getFeaturesCol();

        List<Feature> features = encoder.getFeatures(featuresCol);

        if (model instanceof PredictionModel) {
            PredictionModel<?, ?> predictionModel = (PredictionModel<?, ?>)model;

            int numFeatures = predictionModel.numFeatures();

            // A value of -1 means the fitted model does not know its feature count; skip the check.
            if (numFeatures != -1 && features.size() != numFeatures) {
                throw new IllegalArgumentException("Expected " + numFeatures + " features, got " + features.size() + " features");
            }
        }

        return new Schema(label, features);
    }

    /**
     * Resolves the target label for the given label column, according to the mining function.
     *
     * <p>For classification, a categorical feature is used as-is. A continuous feature is
     * re-registered as categorical, with categories {@code "0" .. "numClasses-1"}; numClasses
     * comes from the model when it is a {@link ClassificationModel}, and otherwise defaults
     * to 2. For regression, the field is made continuous and its data type forced to
     * {@link DataType#DOUBLE}.
     */
    private Label encodeLabel(Model<?> model, String labelCol, SparkMLEncoder encoder) {
        Feature feature = encoder.getOnlyFeature(labelCol);

        MiningFunction miningFunction = this.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION: {
                if (feature instanceof CategoricalFeature) {
                    CategoricalFeature categoricalFeature = (CategoricalFeature)feature;

                    DataField dataField = encoder.getDataField(categoricalFeature.getName());

                    return new CategoricalLabel(dataField);
                }

                if (feature instanceof ContinuousFeature) {
                    ContinuousFeature continuousFeature = (ContinuousFeature)feature;

                    int numClasses = 2;
                    if (model instanceof ClassificationModel) {
                        ClassificationModel<?, ?> classificationModel = (ClassificationModel<?, ?>)model;

                        numClasses = classificationModel.numClasses();
                    }

                    // Spark encodes class labels as 0.0, 1.0, ...; expose them as string categories.
                    List<String> categories = new ArrayList<>(numClasses);
                    for (int i = 0; i < numClasses; i++) {
                        categories.add(String.valueOf(i));
                    }

                    TypeDefinitionField field = encoder.toCategorical(continuousFeature.getName(), categories);

                    // Replace the continuous feature so downstream consumers see the categorical view.
                    encoder.putOnlyFeature(labelCol, new CategoricalFeature(encoder, field, categories));

                    return new CategoricalLabel(field.getName(), field.getDataType(), categories);
                }

                throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + feature);
            }
            case REGRESSION: {
                TypeDefinitionField field = encoder.toContinuous(feature.getName());
                field.setDataType(DataType.DOUBLE);

                return new ContinuousLabel(field.getName(), field.getDataType());
            }
            default: {
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
            }
        }
    }

    /**
     * Hook for subclasses to contribute extra output fields to the encoded model.
     *
     * @param label the schema label (may be {@code null})
     * @param encoder the active encoder
     * @return the output fields to prepend, or {@code null}/empty for none (the default)
     */
    public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
        return null;
    }

    /**
     * Encodes the wrapped model and attaches any output fields from
     * {@link #registerOutputFields(Label, SparkMLEncoder)}.
     *
     * <p>The extra output fields are inserted at the front of the {@link Output} of the
     * "last" model, as chosen by {@link #getLastModel(org.dmg.pmml.Model)}. An
     * {@link Output} element is created if that model has none.
     *
     * @param encoder the active encoder
     * @return the encoded PMML model
     */
    public org.dmg.pmml.Model registerModel(SparkMLEncoder encoder) {
        Schema schema = this.encodeSchema(encoder);

        Label label = schema.getLabel();

        org.dmg.pmml.Model model = this.encodeModel(schema);

        List<OutputField> sparkOutputFields = this.registerOutputFields(label, encoder);
        if (sparkOutputFields != null && !sparkOutputFields.isEmpty()) {
            org.dmg.pmml.Model lastModel = this.getLastModel(model);

            Output output = lastModel.getOutput();
            if (output == null) {
                output = new Output();

                lastModel.setOutput(output);
            }

            List<OutputField> outputFields = output.getOutputFields();
            outputFields.addAll(0, sparkOutputFields);
        }

        return model;
    }

    /**
     * Returns the model whose outputs represent the final result.
     *
     * <p>For a {@link MiningModel} using {@code MODEL_CHAIN} with at least one segment, this is
     * the model of the last segment; it is not searched recursively. In every other case it is
     * {@code model} itself.
     *
     * @param model the top-level encoded model
     * @return the model to attach output fields to
     */
    protected org.dmg.pmml.Model getLastModel(org.dmg.pmml.Model model) {
        if (model instanceof MiningModel) {
            MiningModel miningModel = (MiningModel)model;

            Segmentation segmentation = miningModel.getSegmentation();

            // Null-safe comparison instead of a switch, which would NPE on a null method.
            if (segmentation != null && segmentation.getMultipleModelMethod() == Segmentation.MultipleModelMethod.MODEL_CHAIN) {
                List<Segment> segments = segmentation.getSegments();

                if (!segments.isEmpty()) {
                    Segment lastSegment = segments.get(segments.size() - 1);

                    return lastSegment.getModel();
                }
            }
        }

        return model;
    }
}

