/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.Reader;
import java.nio.charset.MalformedInputException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.csv.CSVRecord;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.lime.HighScoreNumericFeatureZones;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureDistribution;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.IndependentFeaturesDataDistribution;
import org.kie.kogito.explainability.model.NumericFeatureDistribution;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PartialDependenceGraph;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionInputsDataDistribution;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;

/**
 * Static helpers shared by the explainability algorithms: synthetic data generation, basic statistics,
 * feature perturbation / dropping / linearization, distance and kernel functions, CSV import/export and
 * prediction sorting.
 */
public class DataUtils {

    /**
     * How long to wait for asynchronous predictions, in {@code Config.DEFAULT_ASYNC_TIMEUNIT} units.
     * NOTE(review): this literal looks like it was inlined from a timeout constant in {@code Config};
     * confirm and reference that constant directly if it exists.
     */
    private static final long PREDICTION_TIMEOUT = 5L;

    private DataUtils() {
    }

    /**
     * Generate {@code size} Gaussian samples, then rescale and shift them so that the sample's empirical
     * (population) standard deviation and mean match the requested ones, up to floating point error.
     * If the raw sample is constant (zero empirical std dev) it is multiplied by {@code stdDeviation} instead.
     *
     * @param mean requested mean
     * @param stdDeviation requested standard deviation
     * @param size number of samples to generate
     * @param random source of randomness
     * @return the generated samples
     */
    public static double[] generateData(double mean, double stdDeviation, int size, Random random) {
        double[] data = new double[size];
        for (int i = 0; i < size; i++) {
            data[i] = random.nextGaussian() * stdDeviation + mean;
        }

        double generatedDataMean = getMean(data);
        double generatedDataStdDev = getStdDev(data, generatedDataMean);

        // factor that maps the empirical std dev onto the requested one
        double scale = generatedDataStdDev != 0.0 ? stdDeviation / generatedDataStdDev : stdDeviation;
        // mean of the data once rescaled, used to shift it onto the requested mean
        double scaledMean = generatedDataStdDev != 0.0
                ? generatedDataMean * stdDeviation / generatedDataStdDev
                : generatedDataMean * stdDeviation;
        double shift = mean - scaledMean;
        for (int i = 0; i < size; i++) {
            data[i] = data[i] * scale + shift;
        }
        return data;
    }

    /**
     * Arithmetic mean of the given values.
     *
     * @param data the values
     * @return their mean; {@code NaN} for an empty array
     */
    public static double getMean(double[] data) {
        double sum = 0.0;
        for (double datum : data) {
            sum += datum;
        }
        return sum / data.length;
    }

    /**
     * Population standard deviation (squared deviations divided by {@code n}) of the given values.
     *
     * @param data the values
     * @param mean the mean of {@code data}
     * @return the standard deviation; {@code NaN} for an empty array
     */
    public static double getStdDev(double[] data, double mean) {
        double squaredDeviations = 0.0;
        for (double datum : data) {
            squaredDeviations += Math.pow(datum - mean, 2.0);
        }
        return Math.sqrt(squaredDeviations / data.length);
    }

    /**
     * Generate {@code size} evenly spaced values covering {@code [min, max)}: the first value is {@code min}
     * and the step is {@code (max - min) / size}, so {@code max} itself is excluded.
     *
     * @param min first value
     * @param max exclusive upper end of the range
     * @param size number of values
     * @return the samples
     */
    public static double[] generateSamples(double min, double max, int size) {
        double[] data = new double[size];
        // the step must span the whole [min, max) range; dividing max alone is only right when min == 0
        double step = (max - min) / size;
        for (int i = 0; i < size; i++) {
            // derived from the index instead of accumulated, so rounding errors do not build up
            data[i] = min + i * step;
        }
        return data;
    }

    /**
     * Wrap each value in a numerical feature named after the value itself.
     *
     * @param inputs the values
     * @return one feature per value, in the same order
     */
    public static List<Feature> doublesToFeatures(double[] inputs) {
        return DoubleStream.of(inputs).mapToObj(DataUtils::doubleToFeature).collect(Collectors.toList());
    }

    /** Create a numerical feature whose name is the string form of {@code d} and whose value is {@code d}. */
    static Feature doubleToFeature(double d) {
        return FeatureFactory.newNumericalFeature(String.valueOf(d), d);
    }

    /**
     * Perturb a random subset of the given features, using each feature type's own perturbation.
     *
     * @see #perturbFeatures(List, PerturbationContext, Map)
     */
    public static List<Feature> perturbFeatures(List<Feature> originalFeatures, PerturbationContext perturbationContext) {
        return perturbFeatures(originalFeatures, perturbationContext, Collections.emptyMap());
    }

    /**
     * Perturb a random subset of the given features.
     * <p>
     * The number of perturbed features is drawn between {@code min(noOfPerturbations, size / 2)} (at least 1)
     * and {@code max(noOfPerturbations, size / 2)} (at most {@code size}). A selected feature whose name has an
     * entry in {@code featureDistributionsMap} gets a value sampled from that distribution; otherwise its
     * {@link Type} perturbs the current value.
     *
     * @param originalFeatures features to perturb; the list itself is not modified
     * @param perturbationContext source of randomness and requested number of perturbations
     * @param featureDistributionsMap optional per-feature-name distributions to sample replacement values from
     * @return a new list holding the (partially) perturbed features
     */
    public static List<Feature> perturbFeatures(List<Feature> originalFeatures, PerturbationContext perturbationContext,
            Map<String, FeatureDistribution> featureDistributionsMap) {
        List<Feature> newFeatures = new ArrayList<>(originalFeatures);
        if (newFeatures.isEmpty()) {
            return newFeatures;
        }
        int perturbationSize = pickPerturbationSize(newFeatures.size(), perturbationContext);
        if (perturbationSize > 0) {
            // distinct random indexes; terminates because perturbationSize <= newFeatures.size()
            int[] indexesToBePerturbed = perturbationContext.getRandom().ints(0, newFeatures.size())
                    .distinct()
                    .limit(perturbationSize)
                    .toArray();
            for (int index : indexesToBePerturbed) {
                Feature feature = newFeatures.get(index);
                Value newValue = featureDistributionsMap.containsKey(feature.getName())
                        ? featureDistributionsMap.get(feature.getName()).sample()
                        : feature.getType().perturb(feature.getValue(), perturbationContext);
                newFeatures.set(index, FeatureFactory.copyOf(feature, newValue));
            }
        }
        return newFeatures;
    }

    /** Decide how many of {@code featureCount} features to perturb (0 when no valid range exists). */
    private static int pickPerturbationSize(int featureCount, PerturbationContext perturbationContext) {
        double requested = perturbationContext.getNoOfPerturbations();
        double half = 0.5 * featureCount;
        int lowerBound = Math.max(1, (int) Math.min(requested, half));
        int upperBound = Math.min(featureCount, (int) Math.max(requested, half));
        if (lowerBound == upperBound) {
            return lowerBound;
        }
        if (upperBound > lowerBound) {
            // kept as a stream draw so seeded runs consume the random generator exactly as before
            return perturbationContext.getRandom().ints(1L, lowerBound, 1 + upperBound).findFirst().orElse(1);
        }
        return 0;
    }

    /**
     * Return a copy of {@code features} where the feature matching {@code target} has its value "dropped" via
     * {@link Type#drop}.
     * <p>
     * A feature with the same name, type and value as {@code target} is dropped directly. A same-named feature
     * with a different type or value is searched through its linearized sub-features. Composite features with
     * another name are handled recursively. All other features are copied unchanged.
     *
     * @param features the features to scan
     * @param target the feature to drop
     * @return a new list of (copied) features
     */
    public static List<Feature> dropFeature(List<Feature> features, Feature target) {
        List<Feature> newList = new ArrayList<>(features.size());
        for (Feature sourceFeature : features) {
            String sourceFeatureName = sourceFeature.getName();
            Type sourceFeatureType = sourceFeature.getType();
            Value sourceFeatureValue = sourceFeature.getValue();
            Feature f;
            if (target.getName().equals(sourceFeatureName)) {
                if (target.getType().equals(sourceFeatureType) && target.getValue().equals(sourceFeatureValue)) {
                    f = FeatureFactory.copyOf(sourceFeature, sourceFeatureType.drop(sourceFeatureValue));
                } else {
                    f = dropOnLinearizedFeatures(target, sourceFeature);
                }
            } else if (Type.COMPOSITE.equals(sourceFeatureType)) {
                @SuppressWarnings("unchecked") // composite features wrap a list of features
                List<Feature> nestedFeatures = (List<Feature>) sourceFeatureValue.getUnderlyingObject();
                f = FeatureFactory.newCompositeFeature(sourceFeatureName, dropFeature(nestedFeatures, target));
            } else {
                f = FeatureFactory.copyOf(sourceFeature, sourceFeatureValue);
            }
            newList.add(f);
        }
        return newList;
    }

    /**
     * Linearize {@code sourceFeature} and drop the first linearized sub-feature whose value equals the target's.
     *
     * @param target the feature whose value should be dropped
     * @param sourceFeature the (possibly nested) feature to search
     * @return a composite feature named after {@code target} containing the linearized sub-features with the
     *         match dropped, or an unchanged copy of {@code sourceFeature} when nothing matches
     */
    protected static Feature dropOnLinearizedFeatures(Feature target, Feature sourceFeature) {
        List<Feature> linearizedFeatures = getLinearizedFeatures(List.of(sourceFeature));
        for (int i = 0; i < linearizedFeatures.size(); i++) {
            Feature linearizedFeature = linearizedFeatures.get(i);
            if (target.getValue().equals(linearizedFeature.getValue())) {
                Value droppedValue = linearizedFeature.getType().drop(target.getValue());
                linearizedFeatures.set(i, FeatureFactory.copyOf(linearizedFeature, droppedValue));
                return FeatureFactory.newCompositeFeature(target.getName(), linearizedFeatures);
            }
        }
        return FeatureFactory.copyOf(sourceFeature, sourceFeature.getValue());
    }

    /**
     * Number of positions at which the two arrays differ (compared with {@code ==}, so {@code NaN} entries
     * always count as different).
     *
     * @return the distance, or {@code NaN} when the lengths differ
     */
    public static double hammingDistance(double[] x, double[] y) {
        if (x.length != y.length) {
            return Double.NaN;
        }
        double h = 0.0;
        for (int i = 0; i < x.length; i++) {
            if (x[i] != y[i]) {
                h += 1.0;
            }
        }
        return h;
    }

    /**
     * Number of positions at which the two strings have different {@code char}s.
     *
     * @return the distance, or {@code NaN} when the lengths differ
     */
    public static double hammingDistance(String x, String y) {
        if (x.length() != y.length()) {
            return Double.NaN;
        }
        double h = 0.0;
        for (int i = 0; i < x.length(); i++) {
            if (x.charAt(i) != y.charAt(i)) {
                h += 1.0;
            }
        }
        return h;
    }

    /**
     * Euclidean distance between two vectors.
     *
     * @return the distance, or {@code NaN} when the lengths differ
     */
    public static double euclideanDistance(double[] x, double[] y) {
        if (x.length != y.length) {
            return Double.NaN;
        }
        double squaredSum = 0.0;
        for (int i = 0; i < x.length; i++) {
            squaredSum += Math.pow(x[i] - y[i], 2.0);
        }
        return Math.sqrt(squaredSum);
    }

    /** Density of a normal distribution with mean {@code mu} and standard deviation {@code sigma} at {@code x}. */
    public static double gaussianKernel(double x, double mu, double sigma) {
        return Math.exp(-Math.pow((x - mu) / sigma, 2.0) / 2.0) / (sigma * Math.sqrt(Math.PI * 2));
    }

    /** Exponential kernel {@code sqrt(exp(-x^2 / width^2))}. */
    public static double exponentialSmoothingKernel(double x, double width) {
        return Math.sqrt(Math.exp(-Math.pow(x, 2.0) / Math.pow(width, 2.0)));
    }

    /**
     * Build a distribution of independent numerical features {@code f_0 … f_(noOfFeatures-1)}, each populated with
     * {@code distributionSize} Gaussian samples whose mean and std dev are themselves drawn from {@code random}.
     *
     * @param noOfFeatures number of features
     * @param distributionSize number of samples per feature
     * @param random source of randomness
     * @return the generated distribution
     */
    public static DataDistribution generateRandomDataDistribution(int noOfFeatures, int distributionSize, Random random) {
        List<FeatureDistribution> featureDistributions = new ArrayList<>(Math.max(0, noOfFeatures));
        for (int i = 0; i < noOfFeatures; i++) {
            double[] doubles = generateData(random.nextDouble(), random.nextDouble(), distributionSize, random);
            Feature feature = FeatureFactory.newNumericalFeature("f_" + i, Double.NaN);
            featureDistributions.add(new NumericFeatureDistribution(feature, doubles));
        }
        return new IndependentFeaturesDataDistribution(featureDistributions);
    }

    /**
     * Flatten the (possibly nested) features of each input.
     *
     * @param predictionInputs inputs to flatten
     * @return new inputs, one per original input, holding only linearized features
     */
    public static List<PredictionInput> linearizeInputs(List<PredictionInput> predictionInputs) {
        List<PredictionInput> newInputs = new ArrayList<>(predictionInputs.size());
        for (PredictionInput predictionInput : predictionInputs) {
            newInputs.add(new PredictionInput(getLinearizedFeatures(predictionInput.getFeatures())));
        }
        return newInputs;
    }

    /**
     * Flatten the given features: composite features are replaced by their (recursively flattened) elements and
     * undefined features wrapping another feature are replaced by that (flattened) feature.
     *
     * @param originalFeatures features to flatten
     * @return a new, mutable list of leaf features
     */
    public static List<Feature> getLinearizedFeatures(List<Feature> originalFeatures) {
        List<Feature> flattenedFeatures = new ArrayList<>(originalFeatures.size());
        for (Feature f : originalFeatures) {
            linearizeFeature(flattenedFeatures, f);
        }
        return flattenedFeatures;
    }

    /** Append {@code f}, or its recursively flattened nested features, to {@code flattenedFeatures}. */
    private static void linearizeFeature(List<Feature> flattenedFeatures, Feature f) {
        if (Type.UNDEFINED.equals(f.getType())) {
            Object underlying = f.getValue().getUnderlyingObject();
            if (underlying instanceof Feature) {
                linearizeFeature(flattenedFeatures, (Feature) underlying);
                return;
            }
        } else if (Type.COMPOSITE.equals(f.getType())) {
            Object underlying = f.getValue().getUnderlyingObject();
            if (underlying instanceof List) {
                for (Object nested : (List<?>) underlying) {
                    linearizeFeature(flattenedFeatures, (Feature) nested);
                }
                return;
            }
        }
        flattenedFeatures.add(f);
    }

    /**
     * Pair inputs and outputs by position.
     *
     * @param inputs prediction inputs; must have at least as many elements as {@code os}
     * @param os prediction outputs
     * @return one prediction per output
     */
    public static List<Prediction> getPredictions(List<PredictionInput> inputs, List<PredictionOutput> os) {
        return IntStream.range(0, os.size())
                .mapToObj(i -> new SimplePrediction(inputs.get(i), os.get(i)))
                .collect(Collectors.toList());
    }

    /**
     * Draw {@code sampleSize} elements uniformly at random, with replacement.
     *
     * @return the sampled elements, or an empty list when {@code sampleSize <= 0} or {@code values} is empty
     */
    public static <T> List<T> sampleWithReplacement(List<T> values, int sampleSize, Random random) {
        if (sampleSize <= 0 || values.isEmpty()) {
            return Collections.emptyList();
        }
        return random.ints(sampleSize, 0, values.size()).mapToObj(values::get).collect(Collectors.toList());
    }

    /**
     * Copy {@code existingFeatures}, giving every feature named like {@code featureToUse} (at any composite
     * nesting level) the value of {@code featureToUse}.
     *
     * @param featureToUse feature providing the name to match and the replacement value
     * @param existingFeatures features to copy
     * @return a new list of copied features
     */
    public static List<Feature> replaceFeatures(Feature featureToUse, List<Feature> existingFeatures) {
        List<Feature> newFeatures = new ArrayList<>(existingFeatures.size());
        for (Feature f : existingFeatures) {
            Feature newFeature;
            if (f.getName().equals(featureToUse.getName())) {
                newFeature = FeatureFactory.copyOf(f, featureToUse.getValue());
            } else if (Type.COMPOSITE == f.getType()) {
                @SuppressWarnings("unchecked") // composite features wrap a list of features
                List<Feature> elements = (List<Feature>) f.getValue().getUnderlyingObject();
                newFeature = FeatureFactory.newCompositeFeature(f.getName(), replaceFeatures(featureToUse, elements));
            } else {
                newFeature = FeatureFactory.copyOf(f, f.getValue());
            }
            newFeatures.add(newFeature);
        }
        return newFeatures;
    }

    /**
     * Write a partial dependence graph as CSV. The header is (feature name, output name), followed by one row per
     * point holding the string forms of the x and y values.
     *
     * @param partialDependenceGraph graph to export
     * @param path destination file
     * @throws IOException if the file cannot be written
     */
    public static void toCSV(PartialDependenceGraph partialDependenceGraph, Path path) throws IOException {
        List<Value> xAxis = partialDependenceGraph.getX();
        List<Value> yAxis = partialDependenceGraph.getY();
        CSVFormat format = CSVFormat.DEFAULT.withHeader(partialDependenceGraph.getFeature().getName(),
                partialDependenceGraph.getOutput().getName());
        // the printer is a resource too: closing it flushes and closes the underlying writer
        try (BufferedWriter writer = Files.newBufferedWriter(path);
                CSVPrinter printer = new CSVPrinter(writer, format)) {
            for (int i = 0; i < xAxis.size(); i++) {
                printer.printRecord(xAxis.get(i).asString(), yAxis.get(i).asString());
            }
        }
    }

    /**
     * Read an RFC 4180 CSV file whose first record is the header into a data distribution. The i-th column becomes
     * a feature named after the i-th header, typed with {@code schema.get(i)} and valued with the raw cell string.
     *
     * @param file CSV file to read
     * @param schema one type per column
     * @return a distribution holding one prediction input per data row
     * @throws MalformedInputException if a row's column count differs from {@code schema.size()}
     * @throws IOException if the file cannot be read
     */
    public static DataDistribution readCSV(Path file, List<Type> schema) throws IOException {
        List<PredictionInput> inputs = new ArrayList<>();
        try (BufferedReader reader = Files.newBufferedReader(file);
                CSVParser records = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(reader)) {
            List<String> headerNames = records.getHeaderNames();
            for (CSVRecord record : records) {
                int size = record.size();
                if (schema.size() != size) {
                    throw new MalformedInputException(size);
                }
                List<Feature> features = new ArrayList<>(size);
                for (int i = 0; i < size; i++) {
                    features.add(new Feature(headerNames.get(i), schema.get(i), new Value(record.get(i))));
                }
                inputs.add(new PredictionInput(features));
            }
        }
        return new PredictionInputsDataDistribution(inputs);
    }

    /**
     * Estimate a synthetic distribution for every numerical feature of {@code dataDistribution} by bootstrapping.
     * <p>
     * For each numerical feature, {@code draws} samples of {@code sampleSize} values are drawn with replacement.
     * Their means, variances, minima and maxima are averaged. Then {@code featureDistributionSize} Gaussian values
     * with the averaged mean and pooled std dev are generated and clamped to the averaged [min, max] range. If
     * {@code numericFeatureZonesMap} has zones for the feature, only values inside those zones are kept, provided
     * more than half of {@code featureDistributionSize} values survive the filter. Non-numerical features are
     * skipped. (The method name's spelling is kept for compatibility.)
     *
     * @return map from feature name to the generated numerical distribution
     */
    public static Map<String, FeatureDistribution> boostrapFeatureDistributions(DataDistribution dataDistribution,
            PerturbationContext perturbationContext, int featureDistributionSize, int draws, int sampleSize,
            Map<String, HighScoreNumericFeatureZones> numericFeatureZonesMap) {
        Map<String, FeatureDistribution> featureDistributions = new HashMap<>();
        for (FeatureDistribution featureDistribution : dataDistribution.asFeatureDistributions()) {
            Feature feature = featureDistribution.getFeature();
            if (!Type.NUMBER.equals(feature.getType())) {
                continue;
            }
            List<Value> values = featureDistribution.getAllSamples();
            double[] means = new double[draws];
            double[] variances = new double[draws];
            double[] mins = new double[draws];
            double[] maxs = new double[draws];
            for (int i = 0; i < draws; i++) {
                List<Value> sampledValues = sampleWithReplacement(values, sampleSize, perturbationContext.getRandom());
                double[] data = sampledValues.stream().mapToDouble(Value::asNumber).toArray();
                double mean = getMean(data);
                means[i] = mean;
                variances[i] = Math.pow(getStdDev(data, mean), 2.0);
                // an empty sample has no bounds: default to the widest finite range
                // (Double.MIN_VALUE would be the smallest *positive* double, not a lower bound)
                mins[i] = Arrays.stream(data).min().orElse(-Double.MAX_VALUE);
                maxs[i] = Arrays.stream(data).max().orElse(Double.MAX_VALUE);
            }
            double finalMean = getMean(means);
            // pooled standard deviation: square root of the average variance
            double finalStdDev = Math.sqrt(getMean(variances));
            double finalMin = getMean(mins);
            double finalMax = getMean(maxs);
            double[] doubles = generateData(finalMean, finalStdDev, featureDistributionSize, perturbationContext.getRandom());
            double[] boundedData = Arrays.stream(doubles)
                    .map(d -> Math.min(Math.max(d, finalMin), finalMax))
                    .toArray();
            double[] finalData = restrictToHighScoreZones(boundedData, numericFeatureZonesMap.get(feature.getName()),
                    featureDistributionSize);
            featureDistributions.put(feature.getName(), new NumericFeatureDistribution(feature, finalData));
        }
        return featureDistributions;
    }

    /**
     * Keep only the values accepted by {@code zones}, unless no zones are given or no more than half of
     * {@code featureDistributionSize} values would survive, in which case {@code data} is returned as is.
     */
    private static double[] restrictToHighScoreZones(double[] data, HighScoreNumericFeatureZones zones,
            int featureDistributionSize) {
        if (zones == null) {
            return data;
        }
        double[] filteredData = DoubleStream.of(data).filter(zones::test).toArray();
        return filteredData.length > featureDistributionSize / 2 ? filteredData : data;
    }

    /**
     * Predict every sample of {@code dataDistribution} and sort the predictions by the score of the output named
     * {@code outputName}, highest first. Predictions lacking that output are placed last.
     *
     * @throws InterruptedException if interrupted while waiting for the predictions
     * @throws ExecutionException if the prediction provider fails
     * @throws TimeoutException if predictions are not available in time
     */
    public static List<Prediction> getScoreSortedPredictions(String outputName, PredictionProvider predictionProvider,
            DataDistribution dataDistribution) throws InterruptedException, ExecutionException, TimeoutException {
        // Missing outputs map to null and sort last. Treating them as "equal to everything" (returning 0) would
        // break comparator transitivity and can make the sort throw IllegalArgumentException.
        Comparator<Prediction> byOutputScoreDescending = Comparator.comparing(
                (Prediction p) -> p.getOutput().getByName(outputName).map(Output::getScore).orElse(null),
                Comparator.nullsLast(Comparator.<Double>reverseOrder()));
        return predictAll(predictionProvider, dataDistribution).stream()
                .sorted(byOutputScoreDescending)
                .collect(Collectors.toList());
    }

    /**
     * Predict every sample of {@code dataDistribution} and sort the predictions by the sum of all their output
     * scores, highest first.
     *
     * @throws InterruptedException if interrupted while waiting for the predictions
     * @throws ExecutionException if the prediction provider fails
     * @throws TimeoutException if predictions are not available in time
     */
    public static List<Prediction> getScoreSortedPredictions(PredictionProvider predictionProvider,
            DataDistribution dataDistribution) throws InterruptedException, ExecutionException, TimeoutException {
        Comparator<Prediction> byTotalScore = Comparator.comparingDouble(DataUtils::totalScore);
        return predictAll(predictionProvider, dataDistribution).stream()
                .sorted(byTotalScore.reversed())
                .collect(Collectors.toList());
    }

    /** Run the provider on all samples of the distribution and pair each input with its output. */
    private static List<Prediction> predictAll(PredictionProvider predictionProvider, DataDistribution dataDistribution)
            throws InterruptedException, ExecutionException, TimeoutException {
        List<PredictionInput> inputs = dataDistribution.getAllSamples();
        List<PredictionOutput> predictionOutputs = predictionProvider.predictAsync(inputs)
                .get(PREDICTION_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
        return getPredictions(inputs, predictionOutputs);
    }

    /** Sum of the scores of all outputs of the given prediction. */
    private static double totalScore(Prediction prediction) {
        return prediction.getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum();
    }

    /**
     * Join the string values of the input's linearized features with single spaces. A separator is only written
     * once some text has been emitted, so leading empty values do not produce leading spaces.
     *
     * @param input the input to render
     * @return the space-separated text
     */
    public static String textify(PredictionInput input) {
        StringBuilder text = new StringBuilder();
        for (Feature f : getLinearizedFeatures(input.getFeatures())) {
            if (text.length() > 0) {
                text.append(' ');
            }
            text.append(f.getValue().asString());
        }
        return text.toString();
    }
}

