/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.testing;

import com.google.common.base.Equivalence;
import com.google.common.io.ByteStreams;
import com.google.common.io.MoreFiles;
import com.google.common.io.RecursiveDeleteOption;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.PMML;
import org.jpmml.converter.testing.ModelEncoderBatch;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.testing.PMMLEquivalence;
import org.jpmml.sparkml.ArchiveUtil;
import org.jpmml.sparkml.DatasetUtil;
import org.jpmml.sparkml.PMMLBuilder;
import org.jpmml.sparkml.PipelineModelUtil;
import org.jpmml.sparkml.testing.SparkMLEncoderBatchTest;

/**
 * Base class for SparkML-to-PMML conversion test batches.
 *
 * <p>A batch loads a dataset schema, a zipped {@link PipelineModel} and an input CSV file, converts
 * the pipeline model to PMML with {@link PMMLBuilder}, and optionally verifies the conversion against
 * a sample of the input data. Each resource is first copied to a temporary file; all temporary files
 * and directories are deleted once {@link #getPMML()} finishes, whether it succeeds or fails.
 */
public abstract class SparkMLEncoderBatch extends ModelEncoderBatch {

    public SparkMLEncoderBatch(String algorithm, String dataset, Predicate<ResultField> columnFilter, Equivalence<Object> equivalence) {
        super(algorithm, dataset, columnFilter, equivalence);
    }

    /**
     * Returns the test object whose Spark session is used to load schemas, pipeline models and datasets.
     */
    public abstract SparkMLEncoderBatchTest getArchiveBatchTest();

    /**
     * Returns the converter option combinations to test.
     *
     * <p>The default is a single, empty, mutable option map.
     */
    public List<Map<String, Object>> getOptionsMatrix() {
        Map<String, Object> options = new LinkedHashMap<>();
        return Collections.singletonList(options);
    }

    /**
     * Returns the resource path of the zipped pipeline model, {@code /pipeline/<algorithm><dataset>.zip}.
     */
    public String getPipelineModelZipPath() {
        return "/pipeline/" + this.getAlgorithm() + this.getDataset() + ".zip";
    }

    /**
     * Returns the resource path of the dataset schema, {@code /schema/<dataset>.json}.
     */
    public String getSchemaJsonPath() {
        return "/schema/" + this.getDataset() + ".json";
    }

    /**
     * Selects the rows used to verify the generated PMML.
     *
     * <p>The default draws an approximately 5% sample without replacement, using the fixed seed
     * {@code 63317} so the selection is reproducible. Returning {@code null} disables verification.
     *
     * @param inputDataset the full input dataset, already cast to the (updated) schema
     * @return the verification rows, or {@code null} to skip verification
     */
    public Dataset<Row> getVerificationDataset(Dataset<Row> inputDataset) {
        return inputDataset.sample(false, 0.05, 63317L);
    }

    /**
     * Converts this batch's pipeline model to PMML, verifying and validating the result.
     *
     * <p>Temporary files created while loading resources are deleted before this method returns.
     * If conversion fails, the original exception is rethrown. Any cleanup failure in that case is
     * attached to it as a suppressed exception.
     *
     * @return the validated PMML document
     * @throws IllegalStateException if the archive batch test does not provide a Spark session
     * @throws Exception if loading, conversion, verification, validation or cleanup fails
     */
    public PMML getPMML() throws Exception {
        SparkMLEncoderBatchTest archiveBatchTest = this.getArchiveBatchTest();
        SparkSession sparkSession = archiveBatchTest.getSparkSession();
        if (sparkSession == null) {
            throw new IllegalStateException("Spark session is not available");
        }

        List<File> tmpResources = new ArrayList<>();
        PMML pmml;
        try {
            pmml = this.buildPMML(sparkSession, tmpResources);
        } catch (Throwable t) {
            // Clean up on failure as well, without letting a cleanup error mask the real cause
            try {
                deleteTmpResources(tmpResources);
            } catch (IOException ioe) {
                t.addSuppressed(ioe);
            }
            throw t;
        }
        deleteTmpResources(tmpResources);
        return pmml;
    }

    /**
     * Loads all resources, builds the PMML document, optionally verifies it, and validates it.
     * Every temporary file or directory created along the way is appended to {@code tmpResources}.
     */
    private PMML buildPMML(SparkSession sparkSession, List<File> tmpResources) throws Exception {
        StructType schema = this.loadSchema(sparkSession, tmpResources);
        PipelineModel pipelineModel = this.loadPipelineModel(sparkSession, tmpResources);
        schema = this.updateSchema(schema, pipelineModel);

        Dataset<Row> inputDataset = this.loadInput(sparkSession, tmpResources);
        inputDataset = DatasetUtil.castColumns(inputDataset, schema);

        PMMLBuilder pmmlBuilder = new PMMLBuilder(schema, pipelineModel)
            .putOptions(this.getOptions());

        Dataset<Row> verificationDataset = this.getVerificationDataset(inputDataset);
        if (verificationDataset != null) {
            Equivalence<?> equivalence = this.getEquivalence();

            // Default tolerances; overridden when the batch compares results with a PMMLEquivalence
            double precision = 1.0E-14;
            double zeroThreshold = 1.0E-14;
            if (equivalence instanceof PMMLEquivalence) {
                PMMLEquivalence pmmlEquivalence = (PMMLEquivalence) equivalence;
                precision = pmmlEquivalence.getPrecision();
                zeroThreshold = pmmlEquivalence.getZeroThreshold();
            }

            pmmlBuilder = pmmlBuilder.verify(verificationDataset, precision, zeroThreshold);
        }

        PMML pmml = pmmlBuilder.build();
        this.validatePMML(pmml);
        return pmml;
    }

    /**
     * Recursively deletes every given temporary file or directory, attempting all of them even if
     * some deletions fail.
     *
     * @throws IOException the first deletion failure, with any later failures attached as suppressed exceptions
     */
    private static void deleteTmpResources(List<File> tmpResources) throws IOException {
        IOException failure = null;
        for (File tmpResource : tmpResources) {
            try {
                MoreFiles.deleteRecursively(tmpResource.toPath(), RecursiveDeleteOption.ALLOW_INSECURE);
            } catch (IOException ioe) {
                if (failure == null) {
                    failure = ioe;
                } else {
                    failure.addSuppressed(ioe);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Copies the schema JSON resource to a temporary file and parses it with {@link DatasetUtil#loadSchema}.
     *
     * @param tmpResources receives the temporary schema file so the caller can delete it later
     */
    protected StructType loadSchema(SparkSession sparkSession, List<File> tmpResources) throws IOException {
        try (InputStream is = this.open(this.getSchemaJsonPath())) {
            File tmpSchemaFile = toTmpFile(is, this.getDataset(), ".json");
            tmpResources.add(tmpSchemaFile);

            return DatasetUtil.loadSchema(tmpSchemaFile);
        }
    }

    /**
     * Copies the pipeline model ZIP resource to a temporary file, uncompresses it, and loads the
     * pipeline model from the resulting directory.
     *
     * @param tmpResources receives both the temporary ZIP file and the uncompressed directory
     */
    protected PipelineModel loadPipelineModel(SparkSession sparkSession, List<File> tmpResources) throws IOException {
        try (InputStream is = this.open(this.getPipelineModelZipPath())) {
            File tmpZipFile = toTmpFile(is, this.getAlgorithm() + this.getDataset(), ".zip");
            tmpResources.add(tmpZipFile);

            File tmpPipelineModelDir = ArchiveUtil.uncompress(tmpZipFile);
            tmpResources.add(tmpPipelineModelDir);

            return PipelineModelUtil.load(sparkSession, tmpPipelineModelDir);
        }
    }

    /**
     * Hook for subclasses to adjust the input schema once the pipeline model is known.
     * The default returns {@code schema} unchanged.
     */
    protected StructType updateSchema(StructType schema, PipelineModel pipelineModel) {
        return schema;
    }

    /**
     * Copies the input CSV resource to a temporary file and loads it with {@link DatasetUtil#loadCsv}.
     *
     * @param tmpResources receives the temporary CSV file so the caller can delete it later
     */
    protected Dataset<Row> loadInput(SparkSession sparkSession, List<File> tmpResources) throws IOException {
        try (InputStream is = this.open(this.getInputCsvPath())) {
            File tmpCsvFile = toTmpFile(is, this.getDataset(), ".csv");
            tmpResources.add(tmpCsvFile);

            return DatasetUtil.loadCsv(sparkSession, tmpCsvFile);
        }
    }

    /**
     * Copies the remaining contents of {@code is} into a new temporary file.
     *
     * <p>The stream is not closed. If copying fails, the partially written file is deleted (or
     * scheduled for deletion on JVM exit) before the exception propagates.
     *
     * @param prefix temporary file name prefix, as accepted by {@link File#createTempFile(String, String)}
     * @param suffix temporary file name suffix, e.g. {@code ".csv"}
     * @return the newly created temporary file
     */
    protected static File toTmpFile(InputStream is, String prefix, String suffix) throws IOException {
        File tmpFile = File.createTempFile(prefix, suffix);
        try (OutputStream os = new FileOutputStream(tmpFile)) {
            ByteStreams.copy(is, os);
        } catch (IOException | RuntimeException e) {
            // The caller never receives the file on failure, so it must be cleaned up here
            if (!tmpFile.delete()) {
                tmpFile.deleteOnExit();
            }
            throw e;
        }
        return tmpFile;
    }
}

