/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.NumericType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldUsageType;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.TypeDefinitionField;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ListFeature;
import org.jpmml.converter.PMMLMapper;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.ModelConverter;

public class FeatureMapper
extends PMMLMapper {
    private StructType schema = null;
    private Map<String, List<Feature>> columnFeatures = new LinkedHashMap<String, List<Feature>>();

    public FeatureMapper(StructType schema) {
        this.schema = schema;
    }

    public PMML encodePMML(org.dmg.pmml.Model model) {
        PMML pmml = super.encodePMML(model);
        HashSet<FieldName> names = new HashSet<FieldName>();
        MiningSchema miningSchema = model.getMiningSchema();
        List miningFields = miningSchema.getMiningFields();
        ListIterator miningFieldIt = miningFields.listIterator();
        while (miningFieldIt.hasNext()) {
            MiningField miningField = (MiningField)miningFieldIt.next();
            FieldUsageType fieldUsage = miningField.getUsageType();
            switch (fieldUsage) {
                case ACTIVE: 
                case PREDICTED: 
                case TARGET: {
                    FieldName name = miningField.getName();
                    names.add(name);
                    break;
                }
            }
        }
        DataDictionary dataDictionary = pmml.getDataDictionary();
        List dataFields = dataDictionary.getDataFields();
        ListIterator dataFieldIt = dataFields.listIterator();
        while (dataFieldIt.hasNext()) {
            DataField dataField = (DataField)dataFieldIt.next();
            FieldName name = dataField.getName();
            if (names.contains(name)) continue;
            dataFieldIt.remove();
        }
        return pmml;
    }

    public void append(FeatureConverter<?> converter) {
        Object transformer = converter.getTransformer();
        List<Feature> features = converter.encodeFeatures(this);
        if (transformer instanceof HasOutputCol) {
            HasOutputCol hasOutputCol = (HasOutputCol)transformer;
            String outputCol = hasOutputCol.getOutputCol();
            this.putFeatures(outputCol, features);
        }
    }

    public void append(ModelConverter<?> converter) {
        Object transformer = converter.getTransformer();
        List<Feature> features = converter.encodeFeatures(this);
        if (transformer instanceof HasPredictionCol) {
            HasPredictionCol hasPredictionCol = (HasPredictionCol)transformer;
            String predictionCol = hasPredictionCol.getPredictionCol();
            this.putFeatures(predictionCol, features);
        }
    }

    public Schema createSchema(Model<?> model) {
        FieldName targetField;
        List<String> targetCategories = null;
        if (model instanceof PredictionModel) {
            HasLabelCol hasLabelCol = (HasLabelCol)model;
            Feature feature = this.getOnlyFeature(hasLabelCol.getLabelCol());
            targetField = feature.getName();
            if (model instanceof ClassificationModel || model instanceof GBTClassificationModel || model instanceof MultilayerPerceptronClassificationModel) {
                if (feature instanceof ListFeature) {
                    ListFeature listFeature = (ListFeature)feature;
                    targetCategories = listFeature.getValues();
                } else {
                    ContinuousFeature continuousFeature = (ContinuousFeature)feature;
                    targetCategories = Arrays.asList("0", "1");
                    DataField dataField = this.toCategorical(targetField, targetCategories);
                    ListFeature listFeature = new ListFeature((TypeDefinitionField)dataField, targetCategories);
                    this.columnFeatures.put(hasLabelCol.getLabelCol(), Collections.singletonList(listFeature));
                }
            }
        } else if (model instanceof KMeansModel) {
            targetField = null;
        } else {
            throw new IllegalArgumentException();
        }
        ArrayList activeFields = new ArrayList(this.getDataFields().keySet());
        activeFields.remove(targetField);
        HasFeaturesCol hasFeaturesCol = (HasFeaturesCol)model;
        List<Feature> features = this.getFeatures(hasFeaturesCol.getFeaturesCol());
        if (model instanceof PredictionModel) {
            PredictionModel predictionModel = (PredictionModel)model;
            if (features.size() != predictionModel.numFeatures()) {
                throw new IllegalArgumentException("Expected " + predictionModel.numFeatures() + " features, got " + features.size() + " features");
            }
        }
        Schema result = new Schema(targetField, targetCategories, activeFields, features);
        return result;
    }

    public DataField toCategorical(FieldName name, List<String> categories) {
        DataField dataField = this.getDataField(name);
        if (dataField == null) {
            throw new IllegalArgumentException();
        }
        dataField.setOpType(OpType.CATEGORICAL);
        List values = dataField.getValues();
        if (values.size() > 0) {
            throw new IllegalArgumentException();
        }
        values.addAll(PMMLUtil.createValues(categories));
        return dataField;
    }

    public boolean hasFeatures(String column) {
        return this.columnFeatures.containsKey(column);
    }

    public Feature getOnlyFeature(String column) {
        List<Feature> features = this.getFeatures(column);
        return (Feature)Iterables.getOnlyElement(features);
    }

    public List<Feature> getFeatures(String column) {
        List<Feature> features = this.columnFeatures.get(column);
        if (features == null) {
            ContinuousFeature feature;
            FieldName name = FieldName.create((String)column);
            DataField dataField = this.getDataField(name);
            if (dataField == null) {
                dataField = this.createDataField(name);
            }
            org.dmg.pmml.DataType dataType = dataField.getDataType();
            switch (dataType) {
                case INTEGER: 
                case FLOAT: 
                case DOUBLE: {
                    feature = new ContinuousFeature((TypeDefinitionField)dataField);
                    break;
                }
                default: {
                    feature = new WildcardFeature((TypeDefinitionField)dataField);
                }
            }
            return Collections.singletonList(feature);
        }
        return features;
    }

    public List<Feature> getFeatures(String column, int[] indices) {
        List<Feature> features = this.getFeatures(column);
        ArrayList<Feature> result = new ArrayList<Feature>();
        for (int i = 0; i < indices.length; ++i) {
            int index = indices[i];
            Feature feature = features.get(index);
            result.add(feature);
        }
        return result;
    }

    public void putFeatures(String column, List<Feature> features) {
        this.checkColumn(column);
        this.columnFeatures.put(column, features);
    }

    public DataField createDataField(FieldName name) {
        org.dmg.pmml.DataType dataType;
        OpType opType;
        StructField field = this.schema.apply(name.getValue());
        DataType sparkDataType = field.dataType();
        if (sparkDataType instanceof NumericType) {
            opType = OpType.CONTINUOUS;
            dataType = sparkDataType instanceof IntegerType ? org.dmg.pmml.DataType.INTEGER : org.dmg.pmml.DataType.DOUBLE;
        } else if (sparkDataType instanceof StringType) {
            opType = OpType.CATEGORICAL;
            dataType = org.dmg.pmml.DataType.STRING;
        } else {
            throw new IllegalArgumentException();
        }
        return this.createDataField(name, opType, dataType);
    }

    public void removeDataField(FieldName name) {
        Map dataFields = this.getDataFields();
        DataField dataField = (DataField)dataFields.remove(name);
        if (dataField == null) {
            throw new IllegalArgumentException();
        }
    }

    private void checkColumn(String column) {
        Feature feature;
        List<Feature> features = this.columnFeatures.get(column);
        if (features != null && features.size() > 0 && !((feature = (Feature)Iterables.getOnlyElement(features)) instanceof WildcardFeature)) {
            throw new IllegalArgumentException(column);
        }
    }
}

