/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import java.util.List;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.Label;
import org.jpmml.converter.LabelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.sparkml.ClassificationModelConverter;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.SparkMLEncoder;
import org.jpmml.sparkml.model.HasPredictionModelOptions;

/**
 * Base converter for Spark ML {@link PredictionModel} subclasses.
 *
 * <p>Resolves the model's label column into a PMML {@link Label} according to the
 * converter's mining function. Resolves the model's features column into the list of
 * {@link Feature}s registered with the encoder.
 *
 * @param <T> the Spark ML prediction model type, which must also expose a features column
 */
public abstract class PredictionModelConverter<T extends PredictionModel<Vector, T> & HasFeaturesCol>
extends ModelConverter<T>
implements HasPredictionModelOptions {

    /**
     * @param model the Spark ML prediction model to convert
     */
    public PredictionModelConverter(T model) {
        super(model);
    }

    /**
     * Builds the PMML label for the model's label column.
     *
     * <p>For {@link MiningFunction#CLASSIFICATION}:
     * <ul>
     *   <li>a {@link BooleanFeature} or {@link CategoricalFeature} label is used directly;</li>
     *   <li>a {@link ContinuousFeature} label is converted to a categorical field whose
     *       categories come from {@link LabelUtil#createTargetCategories(int)}. The number of
     *       classes defaults to 2, or is taken from the converter when it is a
     *       {@link ClassificationModelConverter}. The encoder's feature for the label column
     *       is replaced with an {@link IndexFeature} over those categories.</li>
     * </ul>
     *
     * <p>For {@link MiningFunction#REGRESSION}, the label field is converted to continuous and
     * its data type is set to {@link DataType#DOUBLE}.
     *
     * @param encoder the encoder holding the feature registered for the label column;
     *                it may be mutated (field conversion, feature replacement)
     * @return the categorical or continuous label
     * @throws IllegalArgumentException if a classification label is of an unsupported
     *                                  feature type, or the mining function is not supported
     */
    @Override
    public Label getLabel(SparkMLEncoder encoder) {
        T model = this.getModel();
        String labelCol = model.getLabelCol();
        Feature feature = encoder.getOnlyFeature(labelCol);
        MiningFunction miningFunction = this.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION: {
                // Checked before the CategoricalFeature branch; the upcast below suggests
                // BooleanFeature is a CategoricalFeature subtype, so this order matters.
                if (feature instanceof BooleanFeature) {
                    BooleanFeature booleanFeature = (BooleanFeature)feature;
                    // Explicit upcast keeps the same CategoricalLabel constructor overload.
                    return new CategoricalLabel((CategoricalFeature)booleanFeature);
                }
                if (feature instanceof CategoricalFeature) {
                    CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
                    // Runtime-checked cast: the label field is expected to be a DataField.
                    DataField dataField = (DataField)categoricalFeature.getField();
                    // Explicit upcast keeps the same CategoricalLabel constructor overload.
                    return new CategoricalLabel((Field<?>)dataField);
                }
                if (feature instanceof ContinuousFeature) {
                    ContinuousFeature continuousFeature = (ContinuousFeature)feature;
                    // Binary classification unless the converter reports its class count.
                    int numClasses = 2;
                    if (this instanceof ClassificationModelConverter) {
                        ClassificationModelConverter<?> classificationModelConverter = (ClassificationModelConverter<?>)this;
                        numClasses = classificationModelConverter.getNumberOfClasses();
                    }
                    List<Integer> categories = LabelUtil.createTargetCategories(numClasses);
                    Field<?> field = encoder.toCategorical(continuousFeature.getName(), categories);
                    // Replace the continuous feature so downstream consumers see an index feature.
                    encoder.putOnlyFeature(labelCol, new IndexFeature(encoder, field, categories));
                    return new CategoricalLabel(field.requireName(), field.requireDataType(), categories);
                }
                throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + feature);
            }
            case REGRESSION: {
                Field<?> field = encoder.toContinuous(feature.getName());
                field.setDataType(DataType.DOUBLE);
                return new ContinuousLabel(field);
            }
            default:
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
        }
    }

    /**
     * Returns the features registered with the encoder for the model's features column.
     *
     * <p>When the model reports a feature count other than {@code -1}, the returned list's
     * size is validated against it via {@link SchemaUtil#checkSize}.
     *
     * @param encoder the encoder holding the features for the features column
     * @return the model's input features
     */
    @Override
    public List<Feature> getFeatures(SparkMLEncoder encoder) {
        T model = this.getModel();
        String featuresCol = model.getFeaturesCol();
        List<Feature> features = encoder.getFeatures(featuresCol);
        int numFeatures = model.numFeatures();
        // -1 means the model does not know its feature count; skip validation then.
        if (numFeatures != -1) {
            SchemaUtil.checkSize(numFeatures, features);
        }
        return features;
    }
}

