/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.IntegralType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.TypeDefinitionField;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ListFeature;
import org.jpmml.converter.PMMLMapper;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.sparkml.BooleanFeature;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.ModelConverter;

/**
 * Maps Spark ML column names to the PMML features that represent them while a
 * pipeline is being converted.
 *
 * <p>Columns produced by converted transformers/models are registered via
 * {@link #putFeatures(String, List)} (directly or through the {@code append}
 * methods). Any other column is treated as a raw input column: it is looked up
 * in the Spark input {@link StructType} and a PMML {@link DataField} is created
 * for it on first access.
 */
public class FeatureMapper
extends PMMLMapper {
    /** Spark input schema; used to derive PMML data fields for raw input columns. */
    private final StructType schema;
    /** Features registered per column name, kept in registration order. */
    private final Map<String, List<Feature>> columnFeatures = new LinkedHashMap<String, List<Feature>>();

    /**
     * @param schema the Spark schema of the dataset the pipeline is applied to
     */
    public FeatureMapper(StructType schema) {
        this.schema = schema;
    }

    /**
     * Encodes the features of a feature transformer and, when the transformer
     * has an output column, registers them under that column.
     *
     * <p>Encoding always happens (it may register data fields as a side effect),
     * even if the transformer has no output column.
     *
     * @param featureConverter converter wrapping the Spark transformer
     * @throws IllegalArgumentException if the output column already holds
     *     features that may not be replaced
     */
    public void append(FeatureConverter<?> featureConverter) {
        Object transformer = featureConverter.getTransformer();
        List<Feature> features = featureConverter.encodeFeatures(this);
        if (transformer instanceof HasOutputCol) {
            HasOutputCol hasOutputCol = (HasOutputCol)transformer;
            this.putFeatures(hasOutputCol.getOutputCol(), features);
        }
    }

    /**
     * Encodes the output features of a model and, when the model has a
     * prediction column, registers them under that column.
     *
     * @param modelConverter converter wrapping the Spark model
     * @throws IllegalArgumentException if the prediction column already holds
     *     features that may not be replaced
     */
    public void append(ModelConverter<?> modelConverter) {
        Model<?> model = (Model<?>)modelConverter.getTransformer();
        List<Feature> features = modelConverter.encodeFeatures(this);
        if (model instanceof HasPredictionCol) {
            HasPredictionCol hasPredictionCol = (HasPredictionCol)model;
            this.putFeatures(hasPredictionCol.getPredictionCol(), features);
        }
    }

    /**
     * Builds the converter {@link Schema} for a model.
     *
     * <p>For a {@link PredictionModel} the label column is resolved to the
     * target field:
     * <ul>
     *   <li>classification with a {@link ListFeature} label: its values become
     *       the target categories;</li>
     *   <li>classification with a {@link ContinuousFeature} label: the target
     *       data field is turned categorical with categories {@code "0"} and
     *       {@code "1"}, and the label column is re-registered as a
     *       {@link ListFeature};</li>
     *   <li>regression: the target data field is made continuous with data
     *       type {@code DOUBLE}.</li>
     * </ul>
     * A {@link KMeansModel} has no target field.
     *
     * <p>Active fields are all currently known data fields except the target.
     *
     * @param modelConverter converter wrapping the Spark model
     * @return the schema (target field, target categories, active fields, features)
     * @throws IllegalArgumentException if the model type or label feature type is
     *     unsupported, or the feature count does not match the model
     */
    public Schema createSchema(ModelConverter<?> modelConverter) {
        Model<?> model = (Model<?>)modelConverter.getTransformer();
        FieldName targetField;
        List<String> targetCategories = null;
        if (model instanceof PredictionModel) {
            HasLabelCol hasLabelCol = (HasLabelCol)model;
            String labelCol = hasLabelCol.getLabelCol();
            Feature feature = this.getOnlyFeature(labelCol);
            targetField = feature.getName();
            MiningFunction miningFunction = modelConverter.getMiningFunction();
            switch (miningFunction) {
                case CLASSIFICATION: {
                    if (feature instanceof ListFeature) {
                        targetCategories = ((ListFeature)feature).getValues();
                    } else if (feature instanceof ContinuousFeature) {
                        // A numeric label is treated as a binary 0/1 target
                        targetCategories = Arrays.asList("0", "1");
                        DataField dataField = this.toCategorical(targetField, targetCategories);
                        ListFeature listFeature = new ListFeature((TypeDefinitionField)dataField, targetCategories);
                        // Written directly (not via putFeatures) because the label column
                        // legitimately already holds the continuous feature being replaced
                        this.columnFeatures.put(labelCol, Collections.<Feature>singletonList(listFeature));
                    } else {
                        throw new IllegalArgumentException("Expected a categorical or continuous label column " + labelCol + ", got " + feature.getClass().getName());
                    }
                    break;
                }
                case REGRESSION: {
                    DataField dataField = this.toContinuous(targetField);
                    dataField.setDataType(DataType.DOUBLE);
                    break;
                }
                default: {
                    break;
                }
            }
        } else if (model instanceof KMeansModel) {
            // Clustering: no target field
            targetField = null;
        } else {
            throw new IllegalArgumentException("Expected a prediction model or a k-means model, got " + (model != null ? model.getClass().getName() : null));
        }

        Map<FieldName, DataField> dataFields = this.getDataFields();
        List<FieldName> activeFields = new ArrayList<FieldName>(dataFields.keySet());
        if (targetField != null) {
            activeFields.remove(targetField);
        }

        HasFeaturesCol hasFeaturesCol = (HasFeaturesCol)model;
        List<Feature> features = this.getFeatures(hasFeaturesCol.getFeaturesCol());
        if (model instanceof PredictionModel) {
            int numFeatures = ((PredictionModel<?, ?>)model).numFeatures();
            // -1 is skipped: it is Spark's sentinel for an unknown feature count
            if (numFeatures != -1 && features.size() != numFeatures) {
                throw new IllegalArgumentException("Expected " + numFeatures + " features, got " + features.size() + " features");
            }
        }
        return new Schema(targetField, targetCategories, activeFields, features);
    }

    /**
     * Marks an existing data field as continuous.
     *
     * @param name the data field name
     * @return the (mutated) data field
     * @throws IllegalArgumentException if no data field with that name exists
     */
    public DataField toContinuous(FieldName name) {
        DataField dataField = this.getDataField(name);
        if (dataField == null) {
            throw new IllegalArgumentException("Field " + name + " is not defined");
        }
        dataField.setOpType(OpType.CONTINUOUS);
        return dataField;
    }

    /**
     * Marks an existing data field as categorical and declares its valid values.
     *
     * <p>The field is validated before it is modified, so a rejected call leaves
     * the data field unchanged.
     *
     * @param name the data field name
     * @param categories the category values to declare on the field
     * @return the (mutated) data field
     * @throws IllegalArgumentException if no data field with that name exists,
     *     or if the field already declares values
     */
    public DataField toCategorical(FieldName name, List<String> categories) {
        DataField dataField = this.getDataField(name);
        if (dataField == null) {
            throw new IllegalArgumentException("Field " + name + " is not defined");
        }
        if (!dataField.getValues().isEmpty()) {
            throw new IllegalArgumentException("Field " + name + " already declares values");
        }
        dataField.setOpType(OpType.CATEGORICAL);
        dataField.getValues().addAll(PMMLUtil.createValues(categories));
        return dataField;
    }

    /**
     * @param column the Spark column name
     * @return {@code true} if features have been registered for the column
     */
    public boolean hasFeatures(String column) {
        return this.columnFeatures.containsKey(column);
    }

    /**
     * Returns the single feature of a column.
     *
     * @param column the Spark column name
     * @return the column's only feature
     * @throws IllegalArgumentException if the column maps to more than one feature
     * @throws java.util.NoSuchElementException if the column maps to no feature
     */
    public Feature getOnlyFeature(String column) {
        return Iterables.getOnlyElement(this.getFeatures(column));
    }

    /**
     * Returns the features of a column.
     *
     * <p>If no features were registered for the column, it is treated as a raw
     * input column: its data field is looked up (or created from the Spark
     * schema), and a single feature matching its data type is returned. That
     * feature is not registered in this mapper.
     *
     * @param column the Spark column name
     * @return the registered features, or a singleton list for a raw input column
     * @throws IllegalArgumentException if the data type is not supported
     */
    public List<Feature> getFeatures(String column) {
        List<Feature> features = this.columnFeatures.get(column);
        if (features != null) {
            return features;
        }
        FieldName name = FieldName.create(column);
        DataField dataField = this.getDataField(name);
        if (dataField == null) {
            dataField = this.createDataField(name);
        }
        DataType dataType = dataField.getDataType();
        Feature feature;
        switch (dataType) {
            case STRING: {
                feature = new WildcardFeature((TypeDefinitionField)dataField);
                break;
            }
            case INTEGER:
            case DOUBLE: {
                feature = new ContinuousFeature((TypeDefinitionField)dataField);
                break;
            }
            case BOOLEAN: {
                feature = new BooleanFeature((TypeDefinitionField)dataField);
                break;
            }
            default: {
                throw new IllegalArgumentException("Unsupported data type " + dataType + " of column " + column);
            }
        }
        return Collections.singletonList(feature);
    }

    /**
     * Returns a subset of a column's features, in the order given by {@code indices}.
     *
     * @param column the Spark column name
     * @param indices positions into the column's feature list
     * @return the selected features
     * @throws IndexOutOfBoundsException if an index is out of range
     */
    public List<Feature> getFeatures(String column, int[] indices) {
        List<Feature> features = this.getFeatures(column);
        List<Feature> result = new ArrayList<Feature>(indices.length);
        for (int index : indices) {
            result.add(features.get(index));
        }
        return result;
    }

    /**
     * Registers the features of a column.
     *
     * @param column the Spark column name
     * @param features the features to register
     * @throws IllegalArgumentException if the column already holds features
     *     other than a single {@link WildcardFeature}
     */
    public void putFeatures(String column, List<Feature> features) {
        this.checkColumn(column);
        this.columnFeatures.put(column, features);
    }

    /**
     * Creates a PMML data field for a column of the Spark input schema.
     *
     * <p>Spark string and boolean columns become categorical fields; integral
     * and double columns become continuous fields.
     *
     * @param name the field name, which must be a column of the Spark schema
     * @return the created data field
     * @throws IllegalArgumentException if the column's Spark type is not supported
     *     (the Spark schema lookup may also reject an unknown column)
     */
    public DataField createDataField(FieldName name) {
        StructField field = this.schema.apply(name.getValue());
        org.apache.spark.sql.types.DataType sparkDataType = field.dataType();
        if (sparkDataType instanceof StringType) {
            return this.createDataField(name, OpType.CATEGORICAL, DataType.STRING);
        }
        if (sparkDataType instanceof IntegralType) {
            return this.createDataField(name, OpType.CONTINUOUS, DataType.INTEGER);
        }
        if (sparkDataType instanceof DoubleType) {
            return this.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
        }
        if (sparkDataType instanceof BooleanType) {
            return this.createDataField(name, OpType.CATEGORICAL, DataType.BOOLEAN);
        }
        throw new IllegalArgumentException("Expected string, integral, double or boolean type, got " + sparkDataType.typeName() + " type");
    }

    /**
     * Removes a data field.
     *
     * @param name the data field name
     * @throws IllegalArgumentException if no data field with that name exists
     */
    public void removeDataField(FieldName name) {
        Map<FieldName, DataField> dataFields = this.getDataFields();
        DataField dataField = dataFields.remove(name);
        if (dataField == null) {
            throw new IllegalArgumentException("Field " + name + " is not defined");
        }
    }

    /**
     * Verifies that a column may be (re)assigned features. This is allowed only
     * when the column has no features yet, or holds exactly one
     * {@link WildcardFeature}.
     */
    private void checkColumn(String column) {
        List<Feature> features = this.columnFeatures.get(column);
        if (features == null || features.isEmpty()) {
            return;
        }
        if (features.size() != 1 || !(features.get(0) instanceof WildcardFeature)) {
            throw new IllegalArgumentException("Column " + column + " is already mapped to features");
        }
    }
}

