/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import java.lang.reflect.Constructor;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.feature.Binarizer;
import org.apache.spark.ml.feature.Bucketizer;
import org.apache.spark.ml.feature.ChiSqSelectorModel;
import org.apache.spark.ml.feature.ColumnPruner;
import org.apache.spark.ml.feature.MinMaxScalerModel;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.ml.feature.PCAModel;
import org.apache.spark.ml.feature.RFormulaModel;
import org.apache.spark.ml.feature.StandardScalerModel;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorAttributeRewriter;
import org.apache.spark.ml.feature.VectorSlicer;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.GBTRegressionModel;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.PMML;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.FeatureMapper;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.TransformerConverter;
import org.jpmml.sparkml.feature.BinarizerConverter;
import org.jpmml.sparkml.feature.BucketizerConverter;
import org.jpmml.sparkml.feature.ChiSqSelectorModelConverter;
import org.jpmml.sparkml.feature.ColumnPrunerConverter;
import org.jpmml.sparkml.feature.MinMaxScalerModelConverter;
import org.jpmml.sparkml.feature.OneHotEncoderConverter;
import org.jpmml.sparkml.feature.PCAModelConverter;
import org.jpmml.sparkml.feature.RFormulaModelConverter;
import org.jpmml.sparkml.feature.StandardScalerModelConverter;
import org.jpmml.sparkml.feature.StringIndexerModelConverter;
import org.jpmml.sparkml.feature.VectorAssemblerConverter;
import org.jpmml.sparkml.feature.VectorAttributeRewriterConverter;
import org.jpmml.sparkml.feature.VectorSlicerConverter;
import org.jpmml.sparkml.model.DecisionTreeClassificationModelConverter;
import org.jpmml.sparkml.model.DecisionTreeRegressionModelConverter;
import org.jpmml.sparkml.model.GBTClassificationModelConverter;
import org.jpmml.sparkml.model.GBTRegressionModelConverter;
import org.jpmml.sparkml.model.KMeansModelConverter;
import org.jpmml.sparkml.model.LinearRegressionModelConverter;
import org.jpmml.sparkml.model.LogisticRegressionModelConverter;
import org.jpmml.sparkml.model.RandomForestClassificationModelConverter;
import org.jpmml.sparkml.model.RandomForestRegressionModelConverter;

/**
 * Static helpers for turning a fitted Spark ML {@link PipelineModel} into a {@link PMML} document.
 *
 * <p>The class keeps a registry that maps a Spark ML {@link Transformer} class to the
 * {@link TransformerConverter} class that handles it. The registry is pre-populated in the static
 * initializer at the bottom of the class and can be extended through
 * {@link #putConverterClazz(Class, Class)}.
 *
 * <p>The registry is a plain {@link LinkedHashMap}; this class performs no synchronization around it.
 */
public class ConverterUtil {
    /** Registry of supported transformer classes, keyed by the exact runtime class of the transformer. */
    private static final Map<Class<? extends Transformer>, Class<? extends TransformerConverter>> converters = new LinkedHashMap<Class<? extends Transformer>, Class<? extends TransformerConverter>>();

    /** Not instantiable: the class only exposes static members. */
    private ConverterUtil() {
    }

    /**
     * Converts a fitted pipeline into a PMML document.
     *
     * <p>Stages are visited in order. Every stage whose converter is a {@link FeatureConverter} is
     * appended to a {@link FeatureMapper}. The first stage whose converter is a {@link ModelConverter}
     * is encoded and the resulting document is returned immediately, so any stages after it are not
     * visited.
     *
     * @param schema the schema handed to the {@link FeatureMapper}
     * @param pipelineModel the fitted pipeline whose stages are converted
     * @return the PMML document produced for the first model stage
     * @throws IllegalArgumentException if a stage is not supported, if a stage's converter is neither
     *         a {@link FeatureConverter} nor a {@link ModelConverter}, or if no stage yields a
     *         {@link ModelConverter}
     */
    public static PMML toPMML(StructType schema, PipelineModel pipelineModel) {
        FeatureMapper featureMapper = new FeatureMapper(schema);
        for (Transformer stage : pipelineModel.stages()) {
            TransformerConverter<Transformer> converter = ConverterUtil.createConverter(stage);
            if (converter instanceof FeatureConverter) {
                // Feature stages only accumulate state in the mapper; keep scanning for the model stage.
                featureMapper.append((FeatureConverter)converter);
                continue;
            }
            if (converter instanceof ModelConverter) {
                ModelConverter modelConverter = (ModelConverter)converter;
                Schema featureSchema = featureMapper.createSchema((Model)stage);
                org.dmg.pmml.Model model = modelConverter.encodeModel(featureSchema);
                return featureMapper.encodePMML(model).setHeader(PMMLUtil.createHeader("JPMML-SparkML", "1.0-SNAPSHOT"));
            }
            throw new IllegalArgumentException("Converter class " + converter.getClass().getName() + " for transformer class " + stage.getClass().getName() + " is neither a feature converter nor a model converter");
        }
        // Reached only when the loop finished without encountering a model converter.
        throw new IllegalArgumentException("The pipeline model does not contain a model stage");
    }

    /**
     * Creates the converter for a transformer and casts it to {@link FeatureConverter}.
     *
     * @param transformer the transformer to create a converter for
     * @return the converter, typed as a feature converter
     * @throws IllegalArgumentException if the transformer class is not registered or the converter
     *         cannot be instantiated
     * @throws ClassCastException if the registered converter is not a {@link FeatureConverter}
     */
    public static FeatureConverter<?> createFeatureConverter(Transformer transformer) {
        return (FeatureConverter)ConverterUtil.createConverter(transformer);
    }

    /**
     * Creates the converter for a transformer and casts it to {@link ModelConverter}.
     *
     * @param transformer the transformer to create a converter for
     * @return the converter, typed as a model converter
     * @throws IllegalArgumentException if the transformer class is not registered or the converter
     *         cannot be instantiated
     * @throws ClassCastException if the registered converter is not a {@link ModelConverter}
     */
    public static ModelConverter<?> createModelConverter(Transformer transformer) {
        return (ModelConverter)ConverterUtil.createConverter(transformer);
    }

    /**
     * Instantiates the converter registered for the runtime class of the given transformer.
     *
     * <p>The converter class is looked up by the transformer's exact runtime class and must declare
     * a constructor whose single parameter is of that same class.
     *
     * @param transformer the transformer to create a converter for
     * @param <T> the static type of the transformer
     * @return a new converter instance wrapping {@code transformer}
     * @throws IllegalArgumentException if no converter class is registered for the transformer's
     *         class, or if looking up or invoking the constructor fails (the original failure is
     *         kept as the cause)
     */
    public static <T extends Transformer> TransformerConverter<T> createConverter(T transformer) {
        // getClass() on a T extends Transformer receiver yields Class<? extends Transformer>,
        // which is the type the registry lookup below requires.
        Class<? extends Transformer> clazz = transformer.getClass();
        Class<? extends TransformerConverter> converterClazz = ConverterUtil.getConverterClazz(clazz);
        if (converterClazz == null) {
            throw new IllegalArgumentException("Transformer class " + clazz.getName() + " is not supported");
        }
        try {
            Constructor<? extends TransformerConverter> constructor = converterClazz.getDeclaredConstructor(clazz);
            return constructor.newInstance(transformer);
        }
        catch (Exception e) {
            throw new IllegalArgumentException("Failed to instantiate converter class " + converterClazz.getName() + " for transformer class " + clazz.getName(), e);
        }
    }

    /**
     * Looks up the converter class registered for a transformer class.
     *
     * @param clazz the transformer class to look up
     * @return the registered converter class, or {@code null} if there is no entry for {@code clazz}
     */
    public static Class<? extends TransformerConverter> getConverterClazz(Class<? extends Transformer> clazz) {
        return converters.get(clazz);
    }

    /**
     * Registers a converter class for a transformer class, replacing any existing entry for it.
     *
     * @param clazz the transformer class to register
     * @param converterClazz the converter class to associate with {@code clazz}
     */
    public static void putConverterClazz(Class<? extends Transformer> clazz, Class<? extends TransformerConverter<?>> converterClazz) {
        converters.put(clazz, converterClazz);
    }

    static {
        // Feature transformers.
        converters.put(Binarizer.class, BinarizerConverter.class);
        converters.put(Bucketizer.class, BucketizerConverter.class);
        converters.put(ChiSqSelectorModel.class, ChiSqSelectorModelConverter.class);
        converters.put(ColumnPruner.class, ColumnPrunerConverter.class);
        converters.put(MinMaxScalerModel.class, MinMaxScalerModelConverter.class);
        converters.put(OneHotEncoder.class, OneHotEncoderConverter.class);
        converters.put(PCAModel.class, PCAModelConverter.class);
        converters.put(RFormulaModel.class, RFormulaModelConverter.class);
        converters.put(StandardScalerModel.class, StandardScalerModelConverter.class);
        converters.put(StringIndexerModel.class, StringIndexerModelConverter.class);
        converters.put(VectorAssembler.class, VectorAssemblerConverter.class);
        converters.put(VectorAttributeRewriter.class, VectorAttributeRewriterConverter.class);
        converters.put(VectorSlicer.class, VectorSlicerConverter.class);
        // Models.
        converters.put(DecisionTreeClassificationModel.class, DecisionTreeClassificationModelConverter.class);
        converters.put(DecisionTreeRegressionModel.class, DecisionTreeRegressionModelConverter.class);
        converters.put(GBTClassificationModel.class, GBTClassificationModelConverter.class);
        converters.put(GBTRegressionModel.class, GBTRegressionModelConverter.class);
        converters.put(KMeansModel.class, KMeansModelConverter.class);
        converters.put(LinearRegressionModel.class, LinearRegressionModelConverter.class);
        converters.put(LogisticRegressionModel.class, LogisticRegressionModelConverter.class);
        converters.put(RandomForestClassificationModel.class, RandomForestClassificationModelConverter.class);
        converters.put(RandomForestRegressionModel.class, RandomForestRegressionModelConverter.class);
    }
}

