/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import com.google.common.collect.Iterables;
import java.io.OutputStream;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.xml.bind.JAXBException;
import org.apache.commons.io.output.ByteArrayOutputStream;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.feature.Binarizer;
import org.apache.spark.ml.feature.Bucketizer;
import org.apache.spark.ml.feature.ChiSqSelectorModel;
import org.apache.spark.ml.feature.ColumnPruner;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.MinMaxScalerModel;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.ml.feature.PCAModel;
import org.apache.spark.ml.feature.RFormulaModel;
import org.apache.spark.ml.feature.StandardScalerModel;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorAttributeRewriter;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.feature.VectorSlicer;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.GBTRegressionModel;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.model.MetroJAXBUtil;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.FeatureMapper;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.RegressionModelConverter;
import org.jpmml.sparkml.TransformerConverter;
import org.jpmml.sparkml.feature.BinarizerConverter;
import org.jpmml.sparkml.feature.BucketizerConverter;
import org.jpmml.sparkml.feature.ChiSqSelectorModelConverter;
import org.jpmml.sparkml.feature.ColumnPrunerConverter;
import org.jpmml.sparkml.feature.IndexToStringConverter;
import org.jpmml.sparkml.feature.MinMaxScalerModelConverter;
import org.jpmml.sparkml.feature.OneHotEncoderConverter;
import org.jpmml.sparkml.feature.PCAModelConverter;
import org.jpmml.sparkml.feature.RFormulaModelConverter;
import org.jpmml.sparkml.feature.StandardScalerModelConverter;
import org.jpmml.sparkml.feature.StringIndexerModelConverter;
import org.jpmml.sparkml.feature.VectorAssemblerConverter;
import org.jpmml.sparkml.feature.VectorAttributeRewriterConverter;
import org.jpmml.sparkml.feature.VectorIndexerModelConverter;
import org.jpmml.sparkml.feature.VectorSlicerConverter;
import org.jpmml.sparkml.model.DecisionTreeClassificationModelConverter;
import org.jpmml.sparkml.model.DecisionTreeRegressionModelConverter;
import org.jpmml.sparkml.model.GBTClassificationModelConverter;
import org.jpmml.sparkml.model.GBTRegressionModelConverter;
import org.jpmml.sparkml.model.KMeansModelConverter;
import org.jpmml.sparkml.model.LinearRegressionModelConverter;
import org.jpmml.sparkml.model.LogisticRegressionModelConverter;
import org.jpmml.sparkml.model.MultilayerPerceptronClassificationModelConverter;
import org.jpmml.sparkml.model.RandomForestClassificationModelConverter;
import org.jpmml.sparkml.model.RandomForestRegressionModelConverter;

public class ConverterUtil {
    private static final Map<Class<? extends Transformer>, Class<? extends TransformerConverter>> converters = new LinkedHashMap<Class<? extends Transformer>, Class<? extends TransformerConverter>>();

    private ConverterUtil() {
    }

    public static PMML toPMML(StructType schema, PipelineModel pipelineModel) {
        org.dmg.pmml.Model rootModel;
        Transformer[] stages;
        FeatureMapper featureMapper = new FeatureMapper(schema);
        LinkedHashMap<String, org.dmg.pmml.Model> models = new LinkedHashMap<String, org.dmg.pmml.Model>();
        for (Transformer stage : stages = pipelineModel.stages()) {
            TransformerConverter<Transformer> converter = ConverterUtil.createConverter(stage);
            if (converter instanceof FeatureConverter) {
                FeatureConverter featureConverter = (FeatureConverter)converter;
                featureMapper.append(featureConverter);
                continue;
            }
            if (converter instanceof ModelConverter) {
                ModelConverter modelConverter = (ModelConverter)converter;
                Schema featureSchema = featureMapper.createSchema((Model)stage);
                if (converter instanceof RegressionModelConverter) {
                    FieldName targetField = featureSchema.getTargetField();
                    DataField dataField = featureMapper.getDataField(targetField);
                    dataField.setOpType(OpType.CONTINUOUS).setDataType(DataType.DOUBLE);
                }
                org.dmg.pmml.Model model = modelConverter.encodeModel(featureSchema);
                featureMapper.append(modelConverter);
                models.put(((HasPredictionCol)stage).getPredictionCol(), model);
                continue;
            }
            throw new IllegalArgumentException();
        }
        if (models.size() == 1) {
            rootModel = (org.dmg.pmml.Model)Iterables.getOnlyElement(models.values());
        } else if (models.size() >= 2) {
            ArrayList<MiningField> targetMiningFields = new ArrayList<MiningField>();
            ArrayList entries = new ArrayList(models.entrySet());
            Iterator entryIt = entries.iterator();
            while (entryIt.hasNext()) {
                Map.Entry entry = (Map.Entry)entryIt.next();
                String predictionCol = (String)entry.getKey();
                org.dmg.pmml.Model model = (org.dmg.pmml.Model)entry.getValue();
                MiningSchema miningSchema = model.getMiningSchema();
                List miningFields = miningSchema.getMiningFields();
                for (MiningField miningField : miningFields) {
                    MiningField.FieldUsage fieldUsage = miningField.getFieldUsage();
                    switch (fieldUsage) {
                        case PREDICTED: 
                        case TARGET: {
                            targetMiningFields.add(miningField);
                            break;
                        }
                    }
                }
                if (!entryIt.hasNext()) break;
                FieldName name = FieldName.create((String)predictionCol);
                DataField dataField = featureMapper.getDataField(name);
                if (dataField == null) {
                    throw new IllegalArgumentException();
                }
                featureMapper.removeDataField(name);
                Output output = model.getOutput();
                if (output == null) {
                    output = new Output();
                    model.setOutput(output);
                }
                OutputField outputField = new OutputField(name, dataField.getDataType()).setOpType(dataField.getOpType()).setResultFeature(ResultFeature.PREDICTED_VALUE);
                output.addOutputFields(new OutputField[]{outputField});
            }
            MiningSchema miningSchema = new MiningSchema(targetMiningFields);
            ArrayList memberModels = new ArrayList(models.values());
            MiningModel miningModel = MiningModelUtil.createModelChain(null, Collections.emptyList(), memberModels).setMiningSchema(miningSchema);
            rootModel = miningModel;
        } else {
            throw new IllegalArgumentException();
        }
        PMML pmml = featureMapper.encodePMML(rootModel);
        return pmml;
    }

    public static byte[] toPMMLByteArray(StructType schema, PipelineModel pipelineModel) {
        PMML pmml = ConverterUtil.toPMML(schema, pipelineModel);
        ByteArrayOutputStream os = new ByteArrayOutputStream(0x100000);
        try {
            MetroJAXBUtil.marshalPMML((PMML)pmml, (OutputStream)os);
        }
        catch (JAXBException je) {
            throw new RuntimeException(je);
        }
        return os.toByteArray();
    }

    public static FeatureConverter<?> createFeatureConverter(Transformer transformer) {
        return (FeatureConverter)ConverterUtil.createConverter(transformer);
    }

    public static ModelConverter<?> createModelConverter(Transformer transformer) {
        return (ModelConverter)ConverterUtil.createConverter(transformer);
    }

    public static <T extends Transformer> TransformerConverter<T> createConverter(T transformer) {
        Class<?> clazz = transformer.getClass();
        Class<? extends TransformerConverter> converterClazz = ConverterUtil.getConverterClazz(clazz);
        if (converterClazz == null) {
            throw new IllegalArgumentException("Transformer class " + clazz.getName() + " is not supported");
        }
        try {
            Constructor<? extends TransformerConverter> constructor = converterClazz.getDeclaredConstructor(clazz);
            return constructor.newInstance(transformer);
        }
        catch (Exception e) {
            throw new IllegalArgumentException(e);
        }
    }

    public static Class<? extends TransformerConverter> getConverterClazz(Class<? extends Transformer> clazz) {
        return converters.get(clazz);
    }

    public static void putConverterClazz(Class<? extends Transformer> clazz, Class<? extends TransformerConverter<?>> converterClazz) {
        converters.put(clazz, converterClazz);
    }

    static {
        converters.put(Binarizer.class, BinarizerConverter.class);
        converters.put(Bucketizer.class, BucketizerConverter.class);
        converters.put(ChiSqSelectorModel.class, ChiSqSelectorModelConverter.class);
        converters.put(ColumnPruner.class, ColumnPrunerConverter.class);
        converters.put(IndexToString.class, IndexToStringConverter.class);
        converters.put(MinMaxScalerModel.class, MinMaxScalerModelConverter.class);
        converters.put(OneHotEncoder.class, OneHotEncoderConverter.class);
        converters.put(PCAModel.class, PCAModelConverter.class);
        converters.put(RFormulaModel.class, RFormulaModelConverter.class);
        converters.put(StandardScalerModel.class, StandardScalerModelConverter.class);
        converters.put(StringIndexerModel.class, StringIndexerModelConverter.class);
        converters.put(VectorAssembler.class, VectorAssemblerConverter.class);
        converters.put(VectorAttributeRewriter.class, VectorAttributeRewriterConverter.class);
        converters.put(VectorIndexerModel.class, VectorIndexerModelConverter.class);
        converters.put(VectorSlicer.class, VectorSlicerConverter.class);
        converters.put(DecisionTreeClassificationModel.class, DecisionTreeClassificationModelConverter.class);
        converters.put(DecisionTreeRegressionModel.class, DecisionTreeRegressionModelConverter.class);
        converters.put(GBTClassificationModel.class, GBTClassificationModelConverter.class);
        converters.put(GBTRegressionModel.class, GBTRegressionModelConverter.class);
        converters.put(KMeansModel.class, KMeansModelConverter.class);
        converters.put(LinearRegressionModel.class, LinearRegressionModelConverter.class);
        converters.put(LogisticRegressionModel.class, LogisticRegressionModelConverter.class);
        converters.put(MultilayerPerceptronClassificationModel.class, MultilayerPerceptronClassificationModelConverter.class);
        converters.put(RandomForestClassificationModel.class, RandomForestClassificationModelConverter.class);
        converters.put(RandomForestRegressionModel.class, RandomForestRegressionModelConverter.class);
    }
}

