/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.testing;

import com.google.common.base.Equivalence;
import com.google.common.io.ByteStreams;
import com.google.common.io.MoreFiles;
import com.google.common.io.RecursiveDeleteOption;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.PMML;
import org.jpmml.converter.testing.ModelEncoderBatch;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.testing.PMMLEquivalence;
import org.jpmml.sparkml.DatasetUtil;
import org.jpmml.sparkml.PMMLBuilder;
import org.jpmml.sparkml.PipelineModelUtil;
import org.jpmml.sparkml.testing.SparkMLEncoderBatchTest;

/**
 * A {@link ModelEncoderBatch} that converts a stored Spark ML {@link PipelineModel}
 * into a {@link PMML} document using {@link PMMLBuilder}.
 *
 * <p>The schema (JSON), the pipeline (zip archive) and the input data (CSV) are opened
 * via the inherited {@code open(String)} method, copied to temporary files, and removed
 * again once {@link #getPMML()} finishes, whether it succeeds or fails.
 */
public abstract class SparkMLEncoderBatch
extends ModelEncoderBatch {
    /**
     * Passes all arguments through to the {@link ModelEncoderBatch} constructor.
     *
     * @param algorithm the algorithm identifier
     * @param dataset the dataset identifier
     * @param columnFilter the result field filter
     * @param equivalence the equivalence used for comparing values
     */
    public SparkMLEncoderBatch(String algorithm, String dataset, Predicate<ResultField> columnFilter, Equivalence<Object> equivalence) {
        super(algorithm, dataset, columnFilter, equivalence);
    }

    /**
     * @return the test object that supplies the {@link SparkSession} used by {@link #getPMML()}
     */
    public abstract SparkMLEncoderBatchTest getArchiveBatchTest();

    /**
     * @return a single-element list holding one options map with
     *         {@code "lookup_threshold"} set to {@code 5}
     */
    public List<Map<String, Object>> getOptionsMatrix() {
        // The value type must be Object so that the list matches the declared return type.
        Map<String, Object> options = new LinkedHashMap<>();
        options.put("lookup_threshold", 5);
        return Collections.singletonList(options);
    }

    /**
     * @return the resource path of the schema file: {@code /schema/<dataset>.json}
     */
    public String getSchemaJsonPath() {
        return "/schema/" + this.getDataset() + ".json";
    }

    /**
     * @return the resource path of the pipeline archive: {@code /pipeline/<algorithm><dataset>.zip}
     */
    public String getPipelineZipPath() {
        return "/pipeline/" + this.getAlgorithm() + this.getDataset() + ".zip";
    }

    /**
     * Selects the rows that are embedded into the PMML document for verification.
     *
     * @param inputDataset the full input dataset
     * @return a sample of {@code inputDataset} taken without replacement, with fraction
     *         {@code 0.05} and seed {@code 63317}; {@link #getPMML()} skips verification
     *         when an override returns {@code null}
     */
    public Dataset<Row> getVerificationDataset(Dataset<Row> inputDataset) {
        return inputDataset.sample(false, 0.05, 63317L);
    }

    /**
     * Loads the schema, pipeline model and input data, builds the PMML document and validates it.
     *
     * <p>All temporary files and directories created along the way are deleted before this
     * method returns or throws. If the build fails and the cleanup fails as well, the cleanup
     * failure is attached to the build failure as a suppressed exception.
     *
     * @return the validated PMML document
     * @throws IllegalStateException if the archive batch test returns a {@code null} Spark session
     * @throws Exception if loading, building, validation or temporary resource cleanup fails
     */
    public PMML getPMML() throws Exception {
        SparkMLEncoderBatchTest archiveBatchTest = this.getArchiveBatchTest();
        SparkSession sparkSession = archiveBatchTest.getSparkSession();
        if (sparkSession == null) {
            throw new IllegalStateException("The archive batch test returned a null SparkSession");
        }

        List<File> tmpResources = new ArrayList<>();
        PMML pmml;
        try {
            pmml = this.encodePMML(sparkSession, tmpResources);
        } catch (Throwable primary) {
            // Clean up on failure too, but never let a cleanup problem hide the real cause.
            try {
                SparkMLEncoderBatch.deleteTmpResources(tmpResources);
            } catch (Throwable cleanupFailure) {
                primary.addSuppressed(cleanupFailure);
            }
            throw primary;
        }
        SparkMLEncoderBatch.deleteTmpResources(tmpResources);
        return pmml;
    }

    /**
     * Hook for adjusting the loaded schema before it is used; this implementation
     * returns {@code schema} unchanged.
     *
     * @param schema the schema loaded from {@link #getSchemaJsonPath()}
     * @param pipelineModel the loaded pipeline model
     * @return the schema to use for column casting and PMML building
     */
    protected StructType updateSchema(StructType schema, PipelineModel pipelineModel) {
        return schema;
    }

    /**
     * Performs the actual conversion. Every temporary file or directory created here is
     * appended to {@code tmpResources} so the caller can remove it afterwards.
     */
    private PMML encodePMML(SparkSession sparkSession, List<File> tmpResources) throws Exception {
        StructType schema;
        try (InputStream is = this.open(this.getSchemaJsonPath())) {
            File tmpSchemaFile = SparkMLEncoderBatch.toTmpFile(is, this.getDataset(), ".json");
            tmpResources.add(tmpSchemaFile);
            schema = DatasetUtil.loadSchema(tmpSchemaFile);
        }

        PipelineModel pipelineModel;
        try (InputStream is = this.open(this.getPipelineZipPath())) {
            File tmpZipFile = SparkMLEncoderBatch.toTmpFile(is, this.getAlgorithm() + this.getDataset(), ".zip");
            tmpResources.add(tmpZipFile);
            File tmpPipelineDir = PipelineModelUtil.uncompress(tmpZipFile);
            tmpResources.add(tmpPipelineDir);
            pipelineModel = PipelineModelUtil.load(sparkSession, tmpPipelineDir);
        }

        schema = this.updateSchema(schema, pipelineModel);

        Dataset<Row> inputDataset;
        try (InputStream is = this.open(this.getInputCsvPath())) {
            File tmpCsvFile = SparkMLEncoderBatch.toTmpFile(is, this.getDataset(), ".csv");
            tmpResources.add(tmpCsvFile);
            inputDataset = DatasetUtil.loadCsv(sparkSession, tmpCsvFile);
        }
        inputDataset = DatasetUtil.castColumns(inputDataset, schema);

        PMMLBuilder pmmlBuilder = new PMMLBuilder(schema, pipelineModel).putOptions(this.getOptions());

        Dataset<Row> verificationDataset = this.getVerificationDataset(inputDataset);
        if (verificationDataset != null) {
            // Defaults apply unless the batch equivalence supplies its own tolerances.
            double precision = 1.0E-14;
            double zeroThreshold = 1.0E-14;
            Equivalence<?> equivalence = this.getEquivalence();
            if (equivalence instanceof PMMLEquivalence) {
                PMMLEquivalence pmmlEquivalence = (PMMLEquivalence)equivalence;
                precision = pmmlEquivalence.getPrecision();
                zeroThreshold = pmmlEquivalence.getZeroThreshold();
            }
            pmmlBuilder = pmmlBuilder.verify(verificationDataset, precision, zeroThreshold);
        }

        // NOTE(review): the temporary files are kept until after build(), matching the original
        // ordering; presumably the CSV-backed dataset is read lazily - confirm before moving cleanup.
        PMML pmml = pmmlBuilder.build();
        this.validatePMML(pmml);
        return pmml;
    }

    /**
     * Recursively deletes every listed resource. All deletions are attempted; the first
     * failure is thrown at the end with any later failures attached as suppressed exceptions.
     */
    private static void deleteTmpResources(List<File> tmpResources) throws IOException {
        IOException failure = null;
        for (File tmpResource : tmpResources) {
            try {
                Path tmpPath = tmpResource.toPath();
                MoreFiles.deleteRecursively(tmpPath);
            } catch (IOException ioe) {
                if (failure == null) {
                    failure = ioe;
                } else {
                    failure.addSuppressed(ioe);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Copies the stream into a newly created temporary file.
     *
     * @param is the stream to copy; it is not closed by this method
     * @param prefix the temporary file name prefix ({@link File#createTempFile(String, String)}
     *        requires at least three characters)
     * @param suffix the temporary file name suffix
     * @return the temporary file holding the stream content
     * @throws IOException if the file cannot be created or written; a partially written
     *         file is removed before the exception propagates
     */
    private static File toTmpFile(InputStream is, String prefix, String suffix) throws IOException {
        File tmpFile = File.createTempFile(prefix, suffix);
        try (OutputStream os = new FileOutputStream(tmpFile)) {
            ByteStreams.copy(is, os);
        } catch (IOException | RuntimeException e) {
            // The caller never learns about this file, so it has to be removed here.
            if (!tmpFile.delete()) {
                tmpFile.deleteOnExit();
            }
            throw e;
        }
        return tmpFile;
    }
}

