/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.NumericType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.OpType;
import org.dmg.pmml.TypeDefinitionField;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ListFeature;
import org.jpmml.converter.PMMLMapper;
import org.jpmml.converter.Schema;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.sparkml.FeatureConverter;

/**
 * Collects the PMML features produced by a chain of Spark ML feature converters, keyed by the
 * Spark output column each converter writes, and lazily turns raw columns of the Spark input
 * schema into PMML {@link DataField}s when they are first requested.
 */
public class FeatureMapper
extends PMMLMapper {
    /** Spark schema used to resolve the type of columns that no appended converter produced. */
    private final StructType schema;

    /** Features registered by {@link #append(FeatureConverter)}, keyed by output column, in insertion order. */
    private final Map<String, List<Feature>> columnFeatures = new LinkedHashMap<>();

    /**
     * @param schema the Spark schema of the input data; consulted by {@link #createDataField(FieldName)}
     */
    public FeatureMapper(StructType schema) {
        this.schema = schema;
    }

    /**
     * Encodes the converter's features against this mapper. If the converter's transformer has an
     * output column ({@link HasOutputCol}), registers the encoded features under that column name.
     * For other transformers the features are still encoded (so any side effects on this mapper
     * happen) but are not registered.
     *
     * @param converter the converter to apply
     */
    public void append(FeatureConverter<?> converter) {
        Object transformer = converter.getTransformer();
        // Encode unconditionally: encoding may register data fields on this mapper.
        List<Feature> features = converter.encodeFeatures(this);
        if (transformer instanceof HasOutputCol) {
            HasOutputCol hasOutputCol = (HasOutputCol) transformer;
            this.columnFeatures.put(hasOutputCol.getOutputCol(), features);
        }
    }

    /**
     * Builds the PMML {@link Schema} for the given fitted model.
     *
     * <p>For a {@link PredictionModel} the target is the single feature of its label column, and
     * for classification models the target categories are taken from that feature. For a
     * {@link KMeansModel} there is no target field. All data fields known to this mapper, except
     * the target, become active fields.
     *
     * @param model the fitted Spark model
     * @return the schema describing the model's target and input features
     * @throws IllegalArgumentException if the model type is unsupported, if a classification
     *     label is not encoded as a {@link ListFeature}, or if the number of features does not
     *     match {@code PredictionModel.numFeatures()}
     */
    public Schema createSchema(Model<?> model) {
        FieldName targetField;
        List<String> targetCategories = null;
        if (model instanceof PredictionModel) {
            HasLabelCol hasLabelCol = (HasLabelCol) model;
            String labelCol = hasLabelCol.getLabelCol();
            Feature feature = this.getOnlyFeature(labelCol);
            targetField = feature.getName();
            // GBTClassificationModel is checked explicitly; presumably it does not extend
            // ClassificationModel in the targeted Spark version (TODO confirm).
            if (model instanceof ClassificationModel || model instanceof GBTClassificationModel) {
                if (!(feature instanceof ListFeature)) {
                    throw new IllegalArgumentException("Expected the label column " + labelCol
                        + " to be encoded as a " + ListFeature.class.getName()
                        + ", got " + feature.getClass().getName());
                }
                targetCategories = ((ListFeature) feature).getValues();
            }
        } else if (model instanceof KMeansModel) {
            targetField = null;
        } else {
            throw new IllegalArgumentException("Unsupported model type: "
                + (model != null ? model.getClass().getName() : null));
        }

        List<FieldName> activeFields = new ArrayList<>(this.getDataFields().keySet());
        if (targetField != null) {
            activeFields.remove(targetField);
        }

        HasFeaturesCol hasFeaturesCol = (HasFeaturesCol) model;
        List<Feature> features = this.getFeatures(hasFeaturesCol.getFeaturesCol());
        if (model instanceof PredictionModel) {
            PredictionModel<?, ?> predictionModel = (PredictionModel<?, ?>) model;
            int expected = predictionModel.numFeatures();
            if (features.size() != expected) {
                throw new IllegalArgumentException("Expected " + expected + " feature(s) in column "
                    + hasFeaturesCol.getFeaturesCol() + ", got " + features.size());
            }
        }
        return new Schema(targetField, targetCategories, activeFields, features);
    }

    /**
     * Returns the single feature associated with a column.
     *
     * @param column the Spark column name
     * @return the column's only feature
     * @throws IllegalArgumentException if the column maps to more than one feature
     *     (as thrown by Guava {@code Iterables.getOnlyElement})
     */
    public Feature getOnlyFeature(String column) {
        return Iterables.getOnlyElement(this.getFeatures(column));
    }

    /**
     * Returns the features for a column.
     *
     * <p>If a converter registered features for the column, that (internal) list is returned.
     * Otherwise the column is treated as a raw input column: its PMML data field is looked up or
     * created, and a single feature is returned. This is a {@link ContinuousFeature} for
     * integer/float/double data fields and a {@link WildcardFeature} for all other data types.
     * Features of raw columns are not cached in this mapper.
     *
     * @param column the Spark column name
     * @return the column's features, never {@code null}
     */
    public List<Feature> getFeatures(String column) {
        List<Feature> features = this.columnFeatures.get(column);
        if (features != null) {
            return features;
        }

        FieldName name = FieldName.create(column);
        DataField dataField = this.getDataField(name);
        if (dataField == null) {
            dataField = this.createDataField(name);
        }

        // Declared as the common supertype: the branches create different Feature subclasses.
        Feature feature;
        switch (dataField.getDataType()) {
            case INTEGER:
            case FLOAT:
            case DOUBLE:
                feature = new ContinuousFeature((TypeDefinitionField) dataField);
                break;
            default:
                feature = new WildcardFeature((TypeDefinitionField) dataField);
                break;
        }
        return Collections.singletonList(feature);
    }

    /**
     * Returns the features of a column at the given positions, in the order of {@code indices}.
     *
     * @param column the Spark column name
     * @param indices zero-based positions into the column's feature list
     * @return a new mutable list of the selected features
     * @throws IndexOutOfBoundsException if an index is outside the column's feature list
     */
    public List<Feature> getFeatures(String column, int[] indices) {
        List<Feature> features = this.getFeatures(column);
        List<Feature> result = new ArrayList<>(indices.length);
        for (int index : indices) {
            result.add(features.get(index));
        }
        return result;
    }

    /**
     * Creates a PMML data field for a raw Spark column, deriving its operational and data type
     * from the Spark schema.
     *
     * <p>Spark {@code IntegerType} maps to continuous {@code INTEGER}, and any other Spark
     * {@link NumericType} maps to continuous {@code DOUBLE}. {@link StringType} maps to categorical
     * {@code STRING}.
     *
     * @param name the column name
     * @return the created data field
     * @throws IllegalArgumentException if the column's Spark type is not numeric or string, or if
     *     the column is absent from the schema (as thrown by {@code StructType.apply(String)})
     */
    public DataField createDataField(FieldName name) {
        StructField field = this.schema.apply(name.getValue());
        org.apache.spark.sql.types.DataType sparkDataType = field.dataType();

        OpType opType;
        DataType dataType;
        if (sparkDataType instanceof NumericType) {
            opType = OpType.CONTINUOUS;
            dataType = (sparkDataType instanceof IntegerType) ? DataType.INTEGER : DataType.DOUBLE;
        } else if (sparkDataType instanceof StringType) {
            opType = OpType.CATEGORICAL;
            dataType = DataType.STRING;
        } else {
            throw new IllegalArgumentException("Unsupported Spark data type " + sparkDataType
                + " for column " + name.getValue());
        }
        return this.createDataField(name, opType, dataType);
    }
}

