/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import com.google.common.collect.Iterables;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.xml.bind.JAXBException;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.param.shared.HasProbabilityCol;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.VerificationField;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.model.MetroJAXBUtil;
import org.jpmml.sparkml.ConverterFactory;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.RegexKey;
import org.jpmml.sparkml.SparkMLEncoder;
import org.jpmml.sparkml.TransformerConverter;

/**
 * Converts a fitted Spark ML {@link PipelineModel} into a {@link PMML} document.
 *
 * <p>Converter options are registered with the {@code putOption}/{@code putOptions} methods, and
 * model verification data can be attached with {@link #verify(Dataset)}. The document itself is
 * produced by {@link #build()}, {@link #buildByteArray()} or {@link #buildFile(File)}.
 */
public class PMMLBuilder {
    private StructType schema = null;
    private PipelineModel pipelineModel = null;
    private Map<RegexKey, Map<String, Object>> options = new LinkedHashMap<>();
    private Verification verification = null;

    /**
     * Creates a builder for the given pipeline.
     *
     * @param schema the schema handed to the {@link SparkMLEncoder}; must not be {@code null}
     * @param pipelineModel the fitted pipeline to convert; must not be {@code null}
     * @throws IllegalArgumentException if either argument is {@code null}
     */
    public PMMLBuilder(StructType schema, PipelineModel pipelineModel) {
        this.setSchema(schema);
        this.setPipelineModel(pipelineModel);
    }

    /**
     * Converts the pipeline into a PMML document.
     *
     * <p>The pipeline is first flattened (nested pipelines are expanded, tuning models are
     * replaced by their best model). Every resulting stage must be handled by either a
     * {@link FeatureConverter} or a {@link ModelConverter}. A single model becomes the root model;
     * several models are chained into a {@link MiningModel}. Derived fields that were created after
     * the last model stage are removed from the encoder and re-emitted as transformed-value
     * {@link OutputField}s of the root model. If {@link #verify} was called and at least one
     * prediction or probability column was collected, a model verification section is attached.
     *
     * @return the PMML document
     * @throws IllegalArgumentException if a stage has no feature or model converter, or if the
     *     pipeline contains no model stages
     */
    public PMML build() {
        StructType schema = this.getSchema();
        PipelineModel pipelineModel = this.getPipelineModel();
        Map<RegexKey, Map<String, Object>> options = this.getOptions();
        Verification verification = this.getVerification();

        ConverterFactory converterFactory = new ConverterFactory(options);
        SparkMLEncoder encoder = new SparkMLEncoder(schema, converterFactory);

        // NOTE: this map is re-read after stages have been registered, so it is assumed to be the
        // encoder's live derived-field map rather than a snapshot.
        Map<FieldName, DerivedField> derivedFields = encoder.getDerivedFields();

        List<Model> models = new ArrayList<>();
        List<String> predictionColumns = new ArrayList<>();
        List<String> probabilityColumns = new ArrayList<>();

        // Derived field names as of the most recently registered model stage
        List<FieldName> preProcessorNames = Collections.emptyList();

        for (Transformer transformer : getTransformers(pipelineModel)) {
            TransformerConverter<?> converter = converterFactory.newConverter(transformer);

            if (converter instanceof FeatureConverter) {
                FeatureConverter<?> featureConverter = (FeatureConverter<?>) converter;
                featureConverter.registerFeatures(encoder);
            } else if (converter instanceof ModelConverter) {
                ModelConverter<?> modelConverter = (ModelConverter<?>) converter;
                Model model = modelConverter.registerModel(encoder);
                models.add(model);

                if (transformer instanceof HasPredictionCol) {
                    // GLM stages that produced a classification model do not contribute their
                    // prediction column to verification.
                    boolean glmClassification = (transformer instanceof GeneralizedLinearRegressionModel)
                            && MiningFunction.CLASSIFICATION.equals(model.getMiningFunction());
                    if (!glmClassification) {
                        predictionColumns.add(((HasPredictionCol) transformer).getPredictionCol());
                    }
                }
                if (transformer instanceof HasProbabilityCol) {
                    probabilityColumns.add(((HasProbabilityCol) transformer).getProbabilityCol());
                }

                preProcessorNames = new ArrayList<>(derivedFields.keySet());
            } else {
                throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + converter);
            }
        }

        // Derived fields registered after the last model stage are post-processing steps
        List<FieldName> postProcessorNames = new ArrayList<>(derivedFields.keySet());
        postProcessorNames.removeAll(preProcessorNames);

        Model rootModel = createRootModel(models);

        for (FieldName postProcessorName : postProcessorNames) {
            DerivedField derivedField = derivedFields.get(postProcessorName);
            encoder.removeDerivedField(postProcessorName);

            Output output = ModelUtil.ensureOutput(rootModel);
            OutputField outputField = new OutputField(derivedField.getName(), derivedField.getDataType())
                    .setOpType(derivedField.getOpType())
                    .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                    .setExpression(derivedField.getExpression());
            output.addOutputFields(outputField);
        }

        PMML pmml = encoder.encodePMML(rootModel);

        // The verification section is attached to rootModel after encoding; the returned PMML is
        // assumed to reference the same rootModel instance.
        if (verification != null && (!predictionColumns.isEmpty() || !probabilityColumns.isEmpty())) {
            Map<VerificationField, List<?>> data = createVerificationData(encoder, rootModel, verification, predictionColumns, probabilityColumns);
            rootModel.setModelVerification(ModelUtil.createModelVerification(data));
        }

        return pmml;
    }

    /**
     * Picks the root model: the only model if there is exactly one, otherwise a model chain whose
     * mining schema consists of the PREDICTED/TARGET mining fields of all chained models.
     *
     * @throws IllegalArgumentException if {@code models} is empty
     */
    private static Model createRootModel(List<Model> models) {
        if (models.isEmpty()) {
            throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
        }
        if (models.size() == 1) {
            return Iterables.getOnlyElement(models);
        }

        List<MiningField> targetMiningFields = new ArrayList<>();
        for (Model model : models) {
            for (MiningField miningField : model.getMiningSchema().getMiningFields()) {
                switch (miningField.getUsageType()) {
                    case PREDICTED:
                    case TARGET:
                        targetMiningFields.add(miningField);
                        break;
                    default:
                        break;
                }
            }
        }

        MiningSchema miningSchema = new MiningSchema(targetMiningFields);
        MiningModel miningModel = MiningModelUtil.createModelChain(models, new Schema(null, Collections.emptyList()))
                .setMiningSchema(miningSchema);
        return miningModel;
    }

    /**
     * Assembles verification records, in insertion order: one input column per ACTIVE mining field
     * of the root model (read from the original dataset), then one expected-output column per
     * prediction column and per probability-vector element (read from the transformed dataset).
     * Expected-output fields carry the configured precision and zero threshold.
     */
    private static Map<VerificationField, List<?>> createVerificationData(SparkMLEncoder encoder, Model rootModel, Verification verification,
            List<String> predictionColumns, List<String> probabilityColumns) {
        Dataset<Row> dataset = verification.getDataset();
        Dataset<Row> transformedDataset = verification.getTransformedDataset();
        Double precision = verification.getPrecision();
        Double zeroThreshold = verification.getZeroThreshold();

        Map<VerificationField, List<?>> data = new LinkedHashMap<>();

        for (MiningField miningField : rootModel.getMiningSchema().getMiningFields()) {
            switch (miningField.getUsageType()) {
                case ACTIVE: {
                    String inputColumn = miningField.getName().getValue();
                    VerificationField verificationField = ModelUtil.createVerificationField(FieldName.create(inputColumn));
                    data.put(verificationField, getColumn(dataset, inputColumn));
                    break;
                }
                default:
                    break;
            }
        }

        for (String predictionColumn : predictionColumns) {
            Feature feature = encoder.getOnlyFeature(predictionColumn);
            VerificationField verificationField = ModelUtil.createVerificationField(feature.getName())
                    .setPrecision(precision)
                    .setZeroThreshold(zeroThreshold);
            data.put(verificationField, getColumn(transformedDataset, predictionColumn));
        }

        for (String probabilityColumn : probabilityColumns) {
            List<Feature> features = encoder.getFeatures(probabilityColumn);
            if (features.isEmpty()) {
                continue;
            }

            // Collect the vector column once and slice it per element, instead of running one
            // Spark query per element.
            List<?> probabilityVectors = getColumn(transformedDataset, probabilityColumn);
            for (int i = 0; i < features.size(); i++) {
                Feature feature = features.get(i);
                VerificationField verificationField = ModelUtil.createVerificationField(feature.getName())
                        .setPrecision(precision)
                        .setZeroThreshold(zeroThreshold);
                data.put(verificationField, getVectorElements(probabilityVectors, i));
            }
        }

        return data;
    }

    /**
     * Builds the PMML document and marshals it into a byte array (initial buffer of 1 MiB).
     *
     * @throws RuntimeException wrapping a {@link JAXBException} if marshalling fails
     */
    public byte[] buildByteArray() {
        return this.buildByteArray(1024 * 1024);
    }

    private byte[] buildByteArray(int size) {
        PMML pmml = this.build();

        ByteArrayOutputStream os = new ByteArrayOutputStream(size);
        try {
            MetroJAXBUtil.marshalPMML(pmml, os);
        } catch (JAXBException je) {
            throw new RuntimeException(je);
        }
        return os.toByteArray();
    }

    /**
     * Builds the PMML document and marshals it into {@code file}, closing the stream afterwards.
     * The document is built before the file is opened.
     *
     * @param file the destination file
     * @return {@code file}
     * @throws IOException if the file cannot be opened or written
     * @throws RuntimeException wrapping a {@link JAXBException} if marshalling fails
     */
    public File buildFile(File file) throws IOException {
        PMML pmml = this.build();

        try (OutputStream os = new FileOutputStream(file)) {
            MetroJAXBUtil.marshalPMML(pmml, os);
        } catch (JAXBException je) {
            // JAXBException is checked and not declared by this method; wrap it like buildByteArray
            throw new RuntimeException(je);
        }
        return file;
    }

    /** Registers a single option under the catch-all pattern {@code .*}. */
    public PMMLBuilder putOption(String key, Object value) {
        return this.putOptions(Collections.singletonMap(key, value));
    }

    /** Registers options under the catch-all pattern {@code .*}. */
    public PMMLBuilder putOptions(Map<String, ?> map) {
        return this.putOptions(Pattern.compile(".*"), map);
    }

    /** Registers a single option under a literal pattern of {@code pipelineStage}'s uid. */
    public PMMLBuilder putOption(PipelineStage pipelineStage, String key, Object value) {
        return this.putOptions(pipelineStage, Collections.singletonMap(key, value));
    }

    /** Registers options under a literal (non-regex) pattern of {@code pipelineStage}'s uid. */
    public PMMLBuilder putOptions(PipelineStage pipelineStage, Map<String, ?> map) {
        return this.putOptions(Pattern.compile(pipelineStage.uid(), Pattern.LITERAL), map);
    }

    /**
     * Merges {@code map} into the options registered for {@code pattern}, creating the entry if
     * needed; later values overwrite earlier values for the same option key.
     *
     * @return this builder
     */
    public PMMLBuilder putOptions(Pattern pattern, Map<String, ?> map) {
        Map<String, Object> patternOptions = this.getOptions()
                .computeIfAbsent(new RegexKey(pattern), key -> new LinkedHashMap<>());
        patternOptions.putAll(map);
        return this;
    }

    /** Same as {@link #verify(Dataset, double, double)} with precision and zero threshold of 1e-14. */
    public PMMLBuilder verify(Dataset<Row> dataset) {
        return this.verify(dataset, 1.0E-14, 1.0E-14);
    }

    /**
     * Transforms {@code dataset} with the pipeline and records both datasets so that
     * {@link #build()} can attach a model verification section.
     *
     * @return this builder
     */
    public PMMLBuilder verify(Dataset<Row> dataset, double precision, double zeroThreshold) {
        Dataset<Row> transformedDataset = this.getPipelineModel().transform(dataset);
        Verification verification = new Verification(dataset, transformedDataset)
                .setPrecision(precision)
                .setZeroThreshold(zeroThreshold);
        return this.setVerification(verification);
    }

    public StructType getSchema() {
        return this.schema;
    }

    /** @throws IllegalArgumentException if {@code schema} is {@code null} */
    public PMMLBuilder setSchema(StructType schema) {
        if (schema == null) {
            throw new IllegalArgumentException();
        }
        this.schema = schema;
        return this;
    }

    public PipelineModel getPipelineModel() {
        return this.pipelineModel;
    }

    /** @throws IllegalArgumentException if {@code pipelineModel} is {@code null} */
    public PMMLBuilder setPipelineModel(PipelineModel pipelineModel) {
        if (pipelineModel == null) {
            throw new IllegalArgumentException();
        }
        this.pipelineModel = pipelineModel;
        return this;
    }

    /** Returns the mutable options map (pattern key to option name/value pairs). */
    public Map<RegexKey, Map<String, Object>> getOptions() {
        return this.options;
    }

    private PMMLBuilder setOptions(Map<RegexKey, Map<String, Object>> options) {
        if (options == null) {
            throw new IllegalArgumentException();
        }
        this.options = options;
        return this;
    }

    public Verification getVerification() {
        return this.verification;
    }

    private PMMLBuilder setVerification(Verification verification) {
        this.verification = verification;
        return this;
    }

    /**
     * Flattens the pipeline into its leaf stages, in stage order: nested {@link PipelineModel}s are
     * expanded, and {@link CrossValidatorModel}/{@link TrainValidationSplitModel} are replaced by
     * their best model.
     */
    private static Iterable<Transformer> getTransformers(PipelineModel pipelineModel) {
        List<Transformer> result = new ArrayList<>();
        collectTransformers(pipelineModel, result);
        return result;
    }

    /** Appends the leaf stages of {@code transformer} to {@code result}, depth-first. */
    private static void collectTransformers(Transformer transformer, List<Transformer> result) {
        if (transformer instanceof PipelineModel) {
            for (Transformer stage : ((PipelineModel) transformer).stages()) {
                collectTransformers(stage, result);
            }
        } else if (transformer instanceof CrossValidatorModel) {
            collectTransformers(((CrossValidatorModel) transformer).bestModel(), result);
        } else if (transformer instanceof TrainValidationSplitModel) {
            collectTransformers(((TrainValidationSplitModel) transformer).bestModel(), result);
        } else {
            result.add(transformer);
        }
    }

    /** Collects the values of column {@code name} to the driver, in row order. */
    private static List<?> getColumn(Dataset<Row> dataset, String name) {
        List<Row> rows = dataset.select(name).collectAsList();
        return rows.stream()
                .map(row -> row.apply(0))
                .collect(Collectors.toList());
    }

    /** Extracts element {@code index} from each {@link Vector} of an already-collected column. */
    private static List<?> getVectorElements(List<?> vectors, int index) {
        return vectors.stream()
                .map(vector -> ((Vector) vector).apply(index))
                .collect(Collectors.toList());
    }

    private static void init() {
        ConverterFactory.checkVersion();
        ConverterFactory.checkApplicationClasspath();
        ConverterFactory.checkNoShading();
    }

    static {
        PMMLBuilder.init();
    }

    /**
     * Verification settings recorded by {@link PMMLBuilder#verify(Dataset, double, double)}: the
     * input dataset, the same dataset transformed by the pipeline, and the precision and zero
     * threshold applied to every expected-output {@link VerificationField}.
     */
    public static class Verification {
        private Dataset<Row> dataset = null;
        private Dataset<Row> transformedDataset = null;
        // NOTE(review): public (unlike the dataset fields); left public for source compatibility
        public Double precision = null;
        public Double zeroThreshold = null;

        private Verification(Dataset<Row> dataset, Dataset<Row> transformedDataset) {
            this.setDataset(dataset);
            this.setTransformedDataset(transformedDataset);
        }

        public Dataset<Row> getDataset() {
            return this.dataset;
        }

        private Verification setDataset(Dataset<Row> dataset) {
            this.dataset = dataset;
            return this;
        }

        public Dataset<Row> getTransformedDataset() {
            return this.transformedDataset;
        }

        private Verification setTransformedDataset(Dataset<Row> transformedDataset) {
            this.transformedDataset = transformedDataset;
            return this;
        }

        public Double getPrecision() {
            return this.precision;
        }

        public Verification setPrecision(Double precision) {
            this.precision = precision;
            return this;
        }

        public Double getZeroThreshold() {
            return this.zeroThreshold;
        }

        public Verification setZeroThreshold(Double zeroThreshold) {
            this.zeroThreshold = zeroThreshold;
            return this;
        }
    }
}

