/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.IntegralType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.association.Item;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.sparkml.ConverterFactory;

public class SparkMLEncoder
extends ModelEncoder {
    private StructType schema = null;
    private ConverterFactory converterFactory = null;
    private Map<String, List<Feature>> columnFeatures = new LinkedHashMap<String, List<Feature>>();

    public SparkMLEncoder(StructType schema, ConverterFactory converterFactory) {
        this.setSchema(schema);
        this.setConverterFactory(converterFactory);
    }

    public PMML encodePMML(Model model) {
        PMML pmml = super.encodePMML(model);
        AbstractVisitor visitor = new AbstractVisitor(){

            public VisitorAction visit(Item item) {
                item.setField(null);
                return super.visit(item);
            }
        };
        visitor.applyTo((Visitable)pmml);
        return pmml;
    }

    public boolean hasFeatures(String column) {
        return this.columnFeatures.containsKey(column);
    }

    public Feature getOnlyFeature(String column) {
        List<Feature> features = this.getFeatures(column);
        return (Feature)Iterables.getOnlyElement(features);
    }

    public List<Feature> getFeatures(String column) {
        List<Feature> features = this.columnFeatures.get(column);
        if (features == null) {
            WildcardFeature feature;
            FieldName name = FieldName.create((String)column);
            DataField dataField = this.getDataField(name);
            if (dataField == null) {
                dataField = this.createDataField(name);
            }
            DataType dataType = dataField.getDataType();
            switch (dataType) {
                case STRING: {
                    feature = new WildcardFeature((PMMLEncoder)this, dataField);
                    break;
                }
                case INTEGER: 
                case DOUBLE: {
                    feature = new ContinuousFeature((PMMLEncoder)this, (Field)dataField);
                    break;
                }
                case BOOLEAN: {
                    feature = new BooleanFeature((PMMLEncoder)this, (Field)dataField);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Data type " + dataType + " is not supported");
                }
            }
            return Collections.singletonList(feature);
        }
        return features;
    }

    public List<Feature> getFeatures(String column, int[] indices) {
        List<Feature> features = this.getFeatures(column);
        ArrayList<Feature> result = new ArrayList<Feature>();
        for (int i = 0; i < indices.length; ++i) {
            int index = indices[i];
            Feature feature = features.get(index);
            result.add(feature);
        }
        return result;
    }

    public void putOnlyFeature(String column, Feature feature) {
        this.putFeatures(column, Collections.singletonList(feature));
    }

    public void putFeatures(String column, List<Feature> features) {
        List<Feature> existingFeatures = this.columnFeatures.get(column);
        if (existingFeatures != null && existingFeatures.size() > 0) {
            SchemaUtil.checkSize((int)existingFeatures.size(), features);
            for (int i = 0; i < existingFeatures.size(); ++i) {
                Feature existingFeature = existingFeatures.get(i);
                Feature feature = features.get(i);
                if (feature.getName().equals((Object)existingFeature.getName())) continue;
                throw new IllegalArgumentException("Expected feature column '" + existingFeature.getName() + "', got feature column '" + feature.getName() + "'");
            }
        }
        this.columnFeatures.put(column, features);
    }

    public DataField createDataField(FieldName name) {
        StructType schema = this.getSchema();
        StructField field = schema.apply(name.getValue());
        org.apache.spark.sql.types.DataType sparkDataType = field.dataType();
        if (sparkDataType instanceof StringType) {
            return this.createDataField(name, OpType.CATEGORICAL, DataType.STRING);
        }
        if (sparkDataType instanceof IntegralType) {
            return this.createDataField(name, OpType.CONTINUOUS, DataType.INTEGER);
        }
        if (sparkDataType instanceof DoubleType) {
            return this.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
        }
        if (sparkDataType instanceof BooleanType) {
            return this.createDataField(name, OpType.CATEGORICAL, DataType.BOOLEAN);
        }
        throw new IllegalArgumentException("Expected string, integral, double or boolean data type, got " + sparkDataType.typeName() + " data type");
    }

    public StructType getSchema() {
        return this.schema;
    }

    private void setSchema(StructType schema) {
        this.schema = Objects.requireNonNull(schema);
    }

    public ConverterFactory getConverterFactory() {
        return this.converterFactory;
    }

    private void setConverterFactory(ConverterFactory converterFactory) {
        this.converterFactory = Objects.requireNonNull(converterFactory);
    }
}

