/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import java.util.List;
import java.util.Objects;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.Label;
import org.jpmml.converter.LabelUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.sparkml.SparkMLEncoder;
import org.jpmml.sparkml.TransformerConverter;

/**
 * Base class for converters that translate a fitted Spark ML {@link Model} into a PMML model.
 *
 * <p>Subclasses supply the mining function, the input features and the actual PMML model encoding;
 * this class derives the label from the Spark label column, assembles and validates the
 * {@link Schema}, and attaches any extra output fields to the encoded model.
 *
 * @param <T> the Spark ML model type handled by this converter
 */
public abstract class ModelConverter<T extends Model<T>>
extends TransformerConverter<T> {

    public ModelConverter(T model) {
        super(model);
    }

    /**
     * Returns the mining function of the converted model.
     * {@link #getLabel(SparkMLEncoder)} supports only classification and regression.
     */
    public abstract MiningFunction getMiningFunction();

    /**
     * Returns the features that the converted model consumes.
     *
     * @param encoder the encoder holding the features registered so far
     */
    public abstract List<Feature> getFeatures(SparkMLEncoder encoder);

    /**
     * Encodes the PMML model for the given schema.
     *
     * @param schema the label and features produced by {@link #encodeSchema(SparkMLEncoder)}
     */
    public abstract org.dmg.pmml.Model encodeModel(Schema schema);

    /**
     * Builds the model schema from {@link #getLabel(SparkMLEncoder)} and
     * {@link #getFeatures(SparkMLEncoder)}.
     *
     * @throws IllegalArgumentException if the label column also appears among the feature columns
     */
    public Schema encodeSchema(SparkMLEncoder encoder) {
        Label label = getLabel(encoder);
        List<Feature> features = getFeatures(encoder);

        Schema result = new Schema(encoder, label, features);
        checkSchema(result);

        return result;
    }

    /**
     * Derives the PMML label from the model's label column.
     *
     * <p>Returns {@code null} when the model does not implement {@link HasLabelCol}. For
     * classification models the label's category count is checked against
     * {@link ClassificationModel#numClasses()}.
     *
     * @throws IllegalArgumentException if the mining function is neither classification nor
     *     regression, or if a classification label column is not categorical-like
     */
    public Label getLabel(SparkMLEncoder encoder) {
        T model = getTransformer();

        // Typed as Label (not CategoricalLabel): the regression branch yields a ContinuousLabel.
        Label label = null;

        if (model instanceof HasLabelCol) {
            HasLabelCol hasLabelCol = (HasLabelCol)model;

            String labelCol = hasLabelCol.getLabelCol();
            Feature feature = encoder.getOnlyFeature(labelCol);

            MiningFunction miningFunction = getMiningFunction();
            switch (miningFunction) {
                case CLASSIFICATION:
                    label = encodeCategoricalLabel(model, labelCol, feature, encoder);
                    break;
                case REGRESSION:
                    label = encodeContinuousLabel(feature, encoder);
                    break;
                default:
                    throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
            }
        }

        if (model instanceof ClassificationModel) {
            ClassificationModel<?, ?> classificationModel = (ClassificationModel<?, ?>)model;

            SchemaUtil.checkSize(classificationModel.numClasses(), (CategoricalLabel)label);
        }

        return label;
    }

    /**
     * Builds a categorical label from a boolean, categorical or continuous label feature.
     *
     * <p>A continuous feature is converted in place: its field is re-declared as categorical
     * with {@code numClasses} target categories, and the encoder's feature for
     * {@code labelCol} is replaced by an {@link IndexFeature}.
     */
    private static CategoricalLabel encodeCategoricalLabel(Model<?> model, String labelCol, Feature feature, SparkMLEncoder encoder) {
        // Keep this test order: the boolean check must run before the general categorical one.
        if (feature instanceof BooleanFeature) {
            BooleanFeature booleanFeature = (BooleanFeature)feature;

            return new CategoricalLabel(booleanFeature.getName(), booleanFeature.getDataType(), booleanFeature.getValues());
        }

        if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
            DataField dataField = (DataField)categoricalFeature.getField();

            return new CategoricalLabel((Field)dataField);
        }

        if (feature instanceof ContinuousFeature) {
            ContinuousFeature continuousFeature = (ContinuousFeature)feature;

            // Fall back to two classes when the model cannot report its class count.
            int numClasses = 2;
            if (model instanceof ClassificationModel) {
                numClasses = ((ClassificationModel<?, ?>)model).numClasses();
            }

            List<Integer> categories = LabelUtil.createTargetCategories(numClasses);

            Field field = encoder.toCategorical(continuousFeature.getName(), categories);
            encoder.putOnlyFeature(labelCol, new IndexFeature(encoder, field, categories));

            return new CategoricalLabel(field.getName(), field.getDataType(), categories);
        }

        throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + feature);
    }

    /**
     * Builds a continuous label, forcing the label field to be continuous with data type
     * {@link DataType#DOUBLE}.
     */
    private static ContinuousLabel encodeContinuousLabel(Feature feature, SparkMLEncoder encoder) {
        Field field = encoder.toContinuous(feature.getName());
        field.setDataType(DataType.DOUBLE);

        return new ContinuousLabel(field.getName(), field.getDataType());
    }

    /**
     * Hook for subclasses to contribute extra output fields to the encoded model.
     *
     * @param label the schema label
     * @param model the PMML model returned by {@link #encodeModel(Schema)}
     * @param encoder the active encoder
     * @return the output fields to append, or {@code null} (the default) / an empty list for none
     */
    public List<OutputField> registerOutputFields(Label label, org.dmg.pmml.Model model, SparkMLEncoder encoder) {
        return null;
    }

    /**
     * Encodes the schema and the PMML model, then appends the fields returned by
     * {@link #registerOutputFields} to the {@link Output} of the model selected by
     * {@link MiningModelUtil#getFinalModel}.
     *
     * @return the encoded PMML model (the top-level one, not the final model)
     */
    public org.dmg.pmml.Model registerModel(SparkMLEncoder encoder) {
        Schema schema = encodeSchema(encoder);
        Label label = schema.getLabel();

        org.dmg.pmml.Model model = encodeModel(schema);

        List<OutputField> sparkOutputFields = registerOutputFields(label, model, encoder);
        if (sparkOutputFields != null && !sparkOutputFields.isEmpty()) {
            org.dmg.pmml.Model finalModel = MiningModelUtil.getFinalModel(model);

            Output output = ModelUtil.ensureOutput(finalModel);
            output.getOutputFields().addAll(sparkOutputFields);
        }

        return model;
    }

    /**
     * Rejects schemas whose label column is also used as a feature column.
     * Schemas without a label are accepted as-is.
     *
     * @throws IllegalArgumentException if a feature has the same name as the label
     */
    private static void checkSchema(Schema schema) {
        Label label = schema.getLabel();
        if (label == null) {
            return;
        }

        List<? extends Feature> features = schema.getFeatures();
        for (Feature feature : features) {
            if (Objects.equals(label.getName(), feature.getName())) {
                throw new IllegalArgumentException("Label column '" + label.getName() + "' is contained in the list of feature columns");
            }
        }
    }
}

