/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml;

import com.google.common.io.CharStreams;
import com.google.common.io.Files;
import com.google.common.io.MoreFiles;
import com.google.common.io.RecursiveDeleteOption;
import java.io.File;
import java.io.FileFilter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrameWriter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalog.Catalog;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.execution.QueryExecution;
import org.apache.spark.sql.types.AtomicType;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.FractionalType;
import org.apache.spark.sql.types.IntegralType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DataType;

/**
 * Static helpers for persisting Spark SQL datasets and schemas, casting dataset columns,
 * analyzing SQL statements against a schema, and translating Spark SQL data types to
 * PMML data types.
 */
public class DatasetUtil {

    /** Counter that gives every temporary SQL view created by this class a distinct name. */
    private static final AtomicInteger ID = new AtomicInteger(1);

    private DatasetUtil() {
    }

    /**
     * Reads a Spark SQL schema from a UTF-8 encoded JSON file (for example, one written by
     * {@link #storeSchema(StructType, File)}).
     *
     * @param file the JSON schema file
     * @return the parsed schema
     * @throws IOException if the file cannot be read
     * @throws ClassCastException if the JSON describes a data type that is not a struct
     */
    public static StructType loadSchema(File file) throws IOException {
        try (InputStream is = new FileInputStream(file);
             InputStreamReader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {
            String json = CharStreams.toString(reader);

            return (StructType) StructType.fromJson(json);
        }
    }

    /**
     * Writes the schema of {@code dataset} to {@code file} as UTF-8 encoded JSON.
     *
     * @param dataset the dataset whose schema is stored
     * @param file the target file; overwritten if it exists
     * @throws IOException if the file cannot be written
     */
    public static void storeSchema(Dataset<Row> dataset, File file) throws IOException {
        storeSchema(dataset.schema(), file);
    }

    /**
     * Writes {@code schema} to {@code file} as UTF-8 encoded JSON ({@link StructType#json()}).
     *
     * @param schema the schema to store
     * @param file the target file; overwritten if it exists
     * @throws IOException if the file cannot be written
     */
    public static void storeSchema(StructType schema, File file) throws IOException {
        try (OutputStream os = new FileOutputStream(file)) {
            os.write(schema.json().getBytes(StandardCharsets.UTF_8));
        }
    }

    /**
     * Loads a CSV file that has a header row, letting Spark infer the column types.
     *
     * @param sparkSession the session used for reading
     * @param file the CSV file
     * @return the loaded dataset
     * @throws IOException declared for API compatibility; this implementation does not throw it
     */
    public static Dataset<Row> loadCsv(SparkSession sparkSession, File file) throws IOException {
        return sparkSession.read()
            .format("csv")
            .option("header", true)
            .option("inferSchema", true)
            .load(file.getAbsolutePath());
    }

    /**
     * Writes {@code dataset} to {@code file} as a single CSV file with a header row.
     *
     * <p>Spark writes a directory of part files, so the dataset is first coalesced to one
     * partition and saved into a temporary directory. The single resulting {@code .csv} part
     * is then copied to {@code file}. The temporary directory is removed whether or not
     * writing succeeds.
     *
     * @param dataset the dataset to store
     * @param file the target file; overwritten if it exists
     * @throws IOException if the temporary location cannot be prepared, Spark does not
     *     produce exactly one CSV part file, or copying/cleanup fails
     */
    public static void storeCsv(Dataset<Row> dataset, File file) throws IOException {
        // Reserve a unique, currently non-existent path for Spark to write into. Presumably the
        // path must not exist because Spark's default save mode rejects existing paths.
        File tmpDir = File.createTempFile("Dataset", "");
        if (!tmpDir.delete()) {
            throw new IOException("Failed to delete temporary file " + tmpDir.getAbsolutePath());
        }

        try {
            writeSingleCsv(dataset, tmpDir, file);
        } catch (IOException | RuntimeException e) {
            // Best-effort cleanup that never hides the original failure
            try {
                deleteRecursivelyIfExists(tmpDir);
            } catch (IOException suppressed) {
                e.addSuppressed(suppressed);
            }
            throw e;
        }

        deleteRecursivelyIfExists(tmpDir);
    }

    /** Saves {@code dataset} as one CSV part file under {@code tmpDir} and copies it to {@code file}. */
    private static void writeSingleCsv(Dataset<Row> dataset, File tmpDir, File file) throws IOException {
        // One partition -> one part file
        DataFrameWriter<Row> writer = dataset.coalesce(1)
            .write()
            .format("csv")
            .option("header", "true");
        writer.save(tmpDir.getAbsolutePath());

        FileFilter csvFileFilter = candidate -> candidate.getName().endsWith(".csv");
        File[] csvFiles = tmpDir.listFiles(csvFileFilter);
        if (csvFiles == null) {
            throw new IOException("Failed to list directory " + tmpDir.getAbsolutePath());
        }
        if (csvFiles.length != 1) {
            throw new IOException("Expected exactly one CSV part file in " + tmpDir.getAbsolutePath() + ", got " + csvFiles.length);
        }

        Files.copy(csvFiles[0], file);
    }

    /** Recursively deletes {@code dir} if it exists; does nothing otherwise. */
    private static void deleteRecursivelyIfExists(File dir) throws IOException {
        if (dir.exists()) {
            MoreFiles.deleteRecursively(dir.toPath());
        }
    }

    /**
     * Returns a dataset in which column {@code name} is cast to {@code sparkDataType}.
     *
     * <p>The cast value is added under a scratch name, the original column is dropped, and the
     * scratch column is renamed back. Because {@code withColumn} appends new columns, the cast
     * column ends up in the last position.
     *
     * @param dataset the input dataset
     * @param name the column to cast
     * @param sparkDataType the target Spark SQL data type
     * @return the dataset with the cast column
     */
    public static Dataset<Row> castColumn(Dataset<Row> dataset, String name, org.apache.spark.sql.types.DataType sparkDataType) {
        Column column = dataset.apply(name).cast(sparkDataType);

        // The scratch column must not clash with an existing column. A clash would silently
        // overwrite that column and then drop it under the final rename.
        String tmpName = "tmp_" + name;
        while (hasColumn(dataset, tmpName)) {
            tmpName = "tmp_" + tmpName;
        }

        return dataset.withColumn(tmpName, column).drop(name).withColumnRenamed(tmpName, name);
    }

    /**
     * Returns {@code true} if {@code dataset} has a column whose name equals {@code name},
     * ignoring case.
     *
     * <p>The check ignores case because Spark resolves column names case-insensitively unless
     * {@code spark.sql.caseSensitive} is enabled.
     */
    private static boolean hasColumn(Dataset<Row> dataset, String name) {
        for (String column : dataset.columns()) {
            if (column.equalsIgnoreCase(name)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Casts the columns of {@code dataset} to the data types declared in {@code schema}.
     *
     * <p>Fields of {@code schema} that have no same-named column in the dataset are ignored.
     * Columns whose type already equals the declared type are left untouched.
     *
     * @param dataset the input dataset
     * @param schema the target schema
     * @return the dataset with the mismatching columns cast
     */
    public static Dataset<Row> castColumns(Dataset<Row> dataset, StructType schema) {
        StructType prevSchema = dataset.schema();

        // Name -> field lookup; on duplicate names the last field wins, as with StructType.apply
        Map<String, StructField> prevFields = new HashMap<>();
        for (StructField prevField : prevSchema.fields()) {
            prevFields.put(prevField.name(), prevField);
        }

        for (StructField field : schema.fields()) {
            StructField prevField = prevFields.get(field.name());
            if (prevField == null || Objects.equals(field.dataType(), prevField.dataType())) {
                continue;
            }

            dataset = castColumn(dataset, field.name(), field.dataType());
        }

        return dataset;
    }

    /**
     * Analyzes a SQL statement against an empty temporary view that has the given schema.
     *
     * <p>Every occurrence of {@code __THIS__} in {@code statement} is replaced with the name of
     * a freshly registered temporary view. The view is dropped afterwards, whether or not
     * analysis succeeds.
     *
     * @param sparkSession the session in which the view is registered and the SQL is parsed
     * @param schema the schema of the temporary view
     * @param statement the SQL statement, referring to the view as {@code __THIS__}
     * @return the analyzed logical plan of the statement
     */
    public static LogicalPlan createAnalyzedLogicalPlan(SparkSession sparkSession, StructType schema, String statement) {
        String tableName = "sql2pmml_" + ID.getAndIncrement();

        statement = statement.replace("__THIS__", tableName);

        Dataset<Row> dataset = sparkSession.createDataFrame(Collections.<Row>emptyList(), schema);
        dataset.createOrReplaceTempView(tableName);

        try {
            QueryExecution queryExecution = sparkSession.sql(statement).queryExecution();

            return queryExecution.analyzed();
        } finally {
            Catalog catalog = sparkSession.catalog();
            catalog.dropTempView(tableName);
        }
    }

    /**
     * Translates a Spark SQL data type to a PMML data type.
     *
     * @param sparkDataType the Spark SQL data type; must be an {@link AtomicType}
     * @return the corresponding PMML data type
     * @throws IllegalArgumentException if the type is not atomic or is otherwise unsupported
     */
    public static DataType translateDataType(org.apache.spark.sql.types.DataType sparkDataType) {
        if (sparkDataType instanceof AtomicType) {
            return translateAtomicType((AtomicType) sparkDataType);
        }

        throw new IllegalArgumentException("Expected atomic data type, got " + sparkDataType.typeName() + " data type");
    }

    /**
     * Translates a Spark SQL atomic type to a PMML data type.
     *
     * <p>Supported types are string, integral, fractional and boolean types.
     *
     * @param atomicType the Spark SQL atomic type
     * @return the corresponding PMML data type
     * @throws IllegalArgumentException if the type is not supported
     */
    public static DataType translateAtomicType(AtomicType atomicType) {
        if (atomicType instanceof StringType) {
            return DataType.STRING;
        }
        if (atomicType instanceof IntegralType) {
            return translateIntegralType((IntegralType) atomicType);
        }
        if (atomicType instanceof FractionalType) {
            return translateFractionalType((FractionalType) atomicType);
        }
        if (atomicType instanceof BooleanType) {
            return DataType.BOOLEAN;
        }

        throw new IllegalArgumentException("Expected string, integral, fractional or boolean data type, got " + atomicType.typeName() + " data type");
    }

    /**
     * Translates a Spark SQL integral type to a PMML data type.
     *
     * @param integralType the Spark SQL integral type (its concrete subtype is not inspected)
     * @return always {@link DataType#INTEGER}
     */
    public static DataType translateIntegralType(IntegralType integralType) {
        return DataType.INTEGER;
    }

    /**
     * Translates a Spark SQL fractional type to a PMML data type.
     *
     * @param fractionalType the Spark SQL fractional type
     * @return {@link DataType#FLOAT} for float, {@link DataType#DOUBLE} for double
     * @throws IllegalArgumentException for any other fractional type
     */
    public static DataType translateFractionalType(FractionalType fractionalType) {
        if (fractionalType instanceof FloatType) {
            return DataType.FLOAT;
        }
        if (fractionalType instanceof DoubleType) {
            return DataType.DOUBLE;
        }

        throw new IllegalArgumentException("Expected float or double data type, got " + fractionalType.typeName() + " data type");
    }
}

