/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.testing;

import com.google.common.base.Equivalence;
import com.google.common.io.ByteStreams;
import com.google.common.io.CharStreams;
import com.google.common.io.MoreFiles;
import com.google.common.io.RecursiveDeleteOption;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.PMML;
import org.jpmml.converter.testing.ModelEncoderBatch;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.testing.PMMLEquivalence;
import org.jpmml.sparkml.PMMLBuilder;
import org.jpmml.sparkml.ZipUtil;
import org.jpmml.sparkml.testing.SparkMLEncoderBatchTest;

/**
 * Base class for Spark ML to PMML conversion test batches.
 *
 * <p>A batch reads a Spark SQL schema (JSON), a zipped {@link PipelineModel} and an input CSV file
 * from test resources. It converts the pipeline to PMML with {@link PMMLBuilder}, optionally
 * verifies the result against a sample of the input data, and validates the produced PMML.
 */
public abstract class SparkMLEncoderBatch
extends ModelEncoderBatch {
    /** Default verification precision, used unless the batch equivalence is a {@link PMMLEquivalence}. */
    private static final double DEFAULT_PRECISION = 1.0E-14;

    /** Default verification zero threshold, used unless the batch equivalence is a {@link PMMLEquivalence}. */
    private static final double DEFAULT_ZERO_THRESHOLD = 1.0E-14;

    /**
     * @param algorithm algorithm name; also used as a prefix of resource paths and temp file names
     * @param dataset dataset name; also used in resource paths and temp file names
     * @param columnFilter filter for result fields, passed through to the superclass
     * @param equivalence value equivalence; if it is a {@link PMMLEquivalence}, its precision and
     *     zero threshold are also used for PMML verification
     */
    public SparkMLEncoderBatch(String algorithm, String dataset, Predicate<ResultField> columnFilter, Equivalence<Object> equivalence) {
        super(algorithm, dataset, columnFilter, equivalence);
    }

    /**
     * Returns the test that owns this batch. {@link #getPMML()} obtains its {@link SparkSession}
     * from it.
     */
    public abstract SparkMLEncoderBatchTest getArchiveBatchTest();

    /**
     * Returns the converter option combinations to exercise.
     *
     * @return a single combination that sets {@code "lookup_threshold"} to {@code 5}
     */
    public List<Map<String, Object>> getOptionsMatrix() {
        // Declared as Map<String, Object> so that it matches the element type of the returned list
        Map<String, Object> options = new LinkedHashMap<>();
        options.put("lookup_threshold", 5);
        return Collections.singletonList(options);
    }

    /** Returns the resource path of the Spark SQL schema (JSON) for this batch's dataset. */
    public String getSchemaJsonPath() {
        return "/schema/" + this.getDataset() + ".json";
    }

    /** Returns the resource path of the zipped pipeline model for this batch's algorithm and dataset. */
    public String getPipelineZipPath() {
        return "/pipeline/" + this.getAlgorithm() + this.getDataset() + ".zip";
    }

    /**
     * Prepares the dataset used to verify the converted PMML.
     *
     * <p>Every column named in {@code schema} is cast to its declared data type. In
     * {@link #getPMML()} the input CSV is read with {@code inferSchema} disabled. A reproducible
     * sample (fixed seed, without replacement, fraction 0.05) of the result is then returned.
     * Subclasses may return {@code null} to skip verification.
     *
     * @param schema the Spark SQL schema of the pipeline input
     * @param inputDataset the raw input dataset
     * @return the verification dataset, or {@code null} to disable verification
     */
    public Dataset<Row> getVerificationDataset(StructType schema, Dataset<Row> inputDataset) {
        Dataset<Row> result = inputDataset;
        for (StructField field : schema.fields()) {
            String name = field.name();
            String tmpName = "tmp_" + name;
            Column column = result.apply(name).cast(field.dataType());
            // Add the cast copy under a temporary name, drop the original, then restore the name
            result = result.withColumn(tmpName, column).drop(name).withColumnRenamed(tmpName, name);
        }
        return result.sample(false, 0.05, 63317L);
    }

    /**
     * Converts this batch's Spark ML pipeline to PMML.
     *
     * <p>All temporary files and directories created while loading resources are deleted before
     * this method returns, whether it succeeds or fails. If cleanup fails after a successful
     * conversion, the cleanup {@link IOException} is thrown. Otherwise it is attached as a
     * suppressed exception to the primary failure.
     *
     * @return the built and validated PMML
     * @throws IllegalStateException if the owning test has no Spark session
     * @throws Exception if reading resources, conversion, verification or validation fails
     */
    public PMML getPMML() throws Exception {
        SparkMLEncoderBatchTest archiveBatchTest = this.getArchiveBatchTest();

        SparkSession sparkSession = archiveBatchTest.getSparkSession();
        if (sparkSession == null) {
            throw new IllegalStateException("Spark session is not available");
        }

        List<File> tmpResources = new ArrayList<>();

        Throwable failure = null;
        try {
            return this.buildPMML(sparkSession, tmpResources);
        } catch (Throwable t) {
            failure = t;
            // Precise rethrow: only the checked exceptions of buildPMML (Exception) can escape here
            throw t;
        } finally {
            deleteTmpResources(tmpResources, failure);
        }
    }

    /** Loads all inputs, builds (and optionally verifies) the PMML, then validates it. */
    private PMML buildPMML(SparkSession sparkSession, List<File> tmpResources) throws Exception {
        StructType schema = this.readSchema();
        PipelineModel pipelineModel = this.readPipelineModel(sparkSession, tmpResources);
        Dataset<Row> inputDataset = this.readInputDataset(sparkSession, tmpResources);

        PMMLBuilder pmmlBuilder = new PMMLBuilder(schema, pipelineModel)
            .putOptions(this.getOptions());

        Dataset<Row> verificationDataset = this.getVerificationDataset(schema, inputDataset);
        if (verificationDataset != null) {
            double precision = DEFAULT_PRECISION;
            double zeroThreshold = DEFAULT_ZERO_THRESHOLD;

            Equivalence<?> equivalence = this.getEquivalence();
            if (equivalence instanceof PMMLEquivalence) {
                PMMLEquivalence pmmlEquivalence = (PMMLEquivalence)equivalence;
                precision = pmmlEquivalence.getPrecision();
                zeroThreshold = pmmlEquivalence.getZeroThreshold();
            }

            pmmlBuilder = pmmlBuilder.verify(verificationDataset, precision, zeroThreshold);
        }

        PMML pmml = pmmlBuilder.build();
        this.validatePMML(pmml);
        return pmml;
    }

    /** Parses the UTF-8 encoded JSON schema resource into a {@link StructType}. */
    private StructType readSchema() throws Exception {
        try (InputStream is = this.open(this.getSchemaJsonPath())) {
            String json = CharStreams.toString(new InputStreamReader(is, StandardCharsets.UTF_8));
            return (StructType)DataType.fromJson(json);
        }
    }

    /** Extracts the zipped pipeline resource to a temp directory and loads it with the given session. */
    private PipelineModel readPipelineModel(SparkSession sparkSession, List<File> tmpResources) throws Exception {
        String prefix = this.getAlgorithm() + this.getDataset();

        File tmpZipFile = this.copyToTmpFile(this.getPipelineZipPath(), prefix, ".zip", tmpResources);

        // Reserve a unique name, then remove the placeholder file so the path is free for the
        // extracted pipeline. NOTE(review): presumes ZipUtil.uncompress creates the target directory.
        File tmpPipelineDir = File.createTempFile(prefix, "");
        if (!tmpPipelineDir.delete()) {
            throw new IOException("Failed to delete placeholder file " + tmpPipelineDir);
        }
        tmpResources.add(tmpPipelineDir);

        ZipUtil.uncompress(tmpZipFile, tmpPipelineDir);

        MLReader<PipelineModel> mlReader = PipelineModel.read();
        mlReader.session(sparkSession);
        return mlReader.load(tmpPipelineDir.getAbsolutePath());
    }

    /** Copies the input CSV resource to a temp file and reads it (with header, no schema inference). */
    private Dataset<Row> readInputDataset(SparkSession sparkSession, List<File> tmpResources) throws Exception {
        File tmpCsvFile = this.copyToTmpFile(this.getInputCsvPath(), this.getDataset(), ".csv", tmpResources);

        return sparkSession.read()
            .format("csv")
            .option("header", true)
            .option("inferSchema", false)
            .load(tmpCsvFile.getAbsolutePath());
    }

    /** Copies a classpath resource into a new temp file, which is registered in {@code tmpResources}. */
    private File copyToTmpFile(String path, String prefix, String suffix, List<File> tmpResources) throws Exception {
        try (InputStream is = this.open(path)) {
            File tmpFile = File.createTempFile(prefix, suffix);
            tmpResources.add(tmpFile);

            try (OutputStream os = new FileOutputStream(tmpFile)) {
                ByteStreams.copy(is, os);
            }

            return tmpFile;
        }
    }

    /**
     * Recursively deletes every registered temp resource that exists, attempting all of them even
     * if some deletions fail.
     *
     * @param tmpResources the resources to delete
     * @param failure the primary failure of the caller, or {@code null} if it succeeded
     * @throws IOException the first deletion failure (later ones suppressed), only when
     *     {@code failure} is {@code null}; otherwise it is added to {@code failure} as suppressed
     */
    private static void deleteTmpResources(List<File> tmpResources, Throwable failure) throws IOException {
        IOException deleteFailure = null;

        for (File tmpResource : tmpResources) {
            // A resource may be registered before it is materialized (e.g. the pipeline directory)
            if (!tmpResource.exists()) {
                continue;
            }

            try {
                MoreFiles.deleteRecursively(tmpResource.toPath());
            } catch (IOException ioe) {
                if (deleteFailure == null) {
                    deleteFailure = ioe;
                } else {
                    deleteFailure.addSuppressed(ioe);
                }
            }
        }

        if (deleteFailure == null) {
            return;
        }

        if (failure != null) {
            // Never let cleanup problems mask the original failure
            failure.addSuppressed(deleteFailure);
            return;
        }

        throw deleteFailure;
    }
}

