/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Output;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.SparkMLEncoder;
import org.jpmml.sparkml.TransformerConverter;

/**
 * Base converter for Spark ML {@link Model} instances that expose a prediction column.
 *
 * <p>Subclasses supply the PMML mining function ({@link #getMiningFunction()}) and the PMML model
 * element ({@link #encodeModel(Schema)}). This class builds the {@link Schema} (target label plus
 * input features) from the features already registered with the encoder, and attaches the
 * optional {@link Output} element produced by {@link #encodeOutput(Label, SparkMLEncoder)}.
 *
 * @param <T> the Spark ML model type
 */
public abstract class ModelConverter<T extends Model<T> & HasPredictionCol>
extends TransformerConverter<T> {

    public ModelConverter(T model) {
        super(model);
    }

    /**
     * Returns the PMML mining function of the converted model.
     *
     * <p>Only consulted when the model has a label column; it decides whether that column becomes a
     * categorical (classification) or continuous (regression) target.
     */
    public abstract MiningFunction getMiningFunction();

    /**
     * Encodes the PMML model element.
     *
     * @param schema the target label and input features, as built by {@link #encodeSchema(SparkMLEncoder)}
     * @return the PMML model element
     */
    public abstract org.dmg.pmml.Model encodeModel(Schema schema);

    /**
     * Builds the schema of this model from the features previously registered with the encoder.
     *
     * <p>If the model declares a label column, the feature registered for it becomes a categorical
     * label (classification) or a double-typed continuous label (regression). A continuous label
     * feature used for classification is re-encoded as categorical, using the class indices
     * {@code "0"} .. {@code numClasses - 1} as categories. It is then re-registered under the label
     * column.
     *
     * @param encoder the encoder holding the features produced by upstream pipeline stages
     * @return the schema; its label is {@code null} when the model has no label column
     * @throws IllegalArgumentException if the mining function is not supported, the label feature
     *     has an unexpected type, a classification model has no categorical label, or the number of
     *     target categories or features does not match what the model reports
     */
    public Schema encodeSchema(SparkMLEncoder encoder) {
        T model = this.getTransformer();

        Label label = null;
        if (model instanceof HasLabelCol) {
            label = encodeLabel(model, encoder);
        }

        if (model instanceof ClassificationModel) {
            checkNumClasses((ClassificationModel<?, ?>) model, label);
        }

        HasFeaturesCol hasFeaturesCol = (HasFeaturesCol) model;
        String featuresCol = hasFeaturesCol.getFeaturesCol();
        List<Feature> features = encoder.getFeatures(featuresCol);

        if (model instanceof PredictionModel) {
            checkNumFeatures((PredictionModel<?, ?>) model, features);
        }

        return new Schema(label, features);
    }

    /**
     * Hook for subclasses to encode the PMML {@link Output} element of the model.
     *
     * @param label the schema label (may be {@code null})
     * @param encoder the encoder
     * @return the output element; this default implementation returns {@code null}
     */
    public Output encodeOutput(Label label, SparkMLEncoder encoder) {
        return null;
    }

    /**
     * Encodes the schema, the output element and the PMML model element, in that order.
     *
     * @param encoder the encoder holding the upstream features
     * @return the encoded PMML model with its output element set
     */
    public org.dmg.pmml.Model registerModel(SparkMLEncoder encoder) {
        Schema schema = this.encodeSchema(encoder);
        Label label = schema.getLabel();
        Output output = this.encodeOutput(label, encoder);
        org.dmg.pmml.Model model = this.encodeModel(schema).setOutput(output);
        return model;
    }

    /**
     * Converts the feature registered for the model's label column into a PMML target label,
     * according to {@link #getMiningFunction()}. The model must implement {@link HasLabelCol}.
     */
    private Label encodeLabel(T model, SparkMLEncoder encoder) {
        HasLabelCol hasLabelCol = (HasLabelCol) model;
        String labelCol = hasLabelCol.getLabelCol();
        Feature feature = encoder.getOnlyFeature(labelCol);

        MiningFunction miningFunction = this.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION: {
                DataField dataField = encodeCategoricalTarget(model, labelCol, feature, encoder);
                return new CategoricalLabel(dataField);
            }
            case REGRESSION: {
                DataField dataField = encoder.toContinuous(feature.getName());
                dataField.setDataType(DataType.DOUBLE);
                return new ContinuousLabel(dataField);
            }
            default: {
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
            }
        }
    }

    /**
     * Resolves the categorical DataField backing a classification target. A continuous label
     * feature is converted to categorical and re-registered under {@code labelCol}.
     */
    private DataField encodeCategoricalTarget(T model, String labelCol, Feature feature, SparkMLEncoder encoder) {
        if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
            return encoder.getDataField(categoricalFeature.getName());
        }

        if (feature instanceof ContinuousFeature) {
            ContinuousFeature continuousFeature = (ContinuousFeature) feature;

            // Falls back to a binary target when the model cannot report its class count
            int numClasses = 2;
            if (model instanceof ClassificationModel) {
                ClassificationModel<?, ?> classificationModel = (ClassificationModel<?, ?>) model;
                numClasses = classificationModel.numClasses();
            }

            List<String> categories = new ArrayList<String>();
            for (int i = 0; i < numClasses; i++) {
                categories.add(String.valueOf(i));
            }

            DataField dataField = encoder.toCategorical(continuousFeature.getName(), categories);

            // Replace the continuous label feature so later lookups see the categorical one
            CategoricalFeature categoricalFeature = new CategoricalFeature((PMMLEncoder) encoder, dataField);
            encoder.putFeatures(labelCol, Collections.singletonList(categoricalFeature));

            return dataField;
        }

        throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + feature);
    }

    /**
     * Verifies that a classification model has a categorical label whose category count matches
     * the model's class count.
     */
    private static void checkNumClasses(ClassificationModel<?, ?> classificationModel, Label label) {
        // Reject a missing or continuous label explicitly, instead of failing with an NPE or ClassCastException
        if (!(label instanceof CategoricalLabel)) {
            throw new IllegalArgumentException("Expected a categorical label, got " + label);
        }

        CategoricalLabel categoricalLabel = (CategoricalLabel) label;

        int numClasses = classificationModel.numClasses();
        if (numClasses != categoricalLabel.size()) {
            throw new IllegalArgumentException("Expected " + numClasses + " target categories, got " + categoricalLabel.size() + " target categories");
        }
    }

    /**
     * Verifies that the number of input features matches the model's reported feature count.
     * A reported count of -1 (presumably "unknown", per Spark's convention) skips the check.
     */
    private static void checkNumFeatures(PredictionModel<?, ?> predictionModel, List<Feature> features) {
        int numFeatures = predictionModel.numFeatures();
        if (numFeatures != -1 && features.size() != numFeatures) {
            throw new IllegalArgumentException("Expected " + numFeatures + " features, got " + features.size() + " features");
        }
    }
}

