/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureDistribution;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.IndependentFeaturesDataDistribution;
import org.kie.kogito.explainability.model.NumericFeatureDistribution;
import org.kie.kogito.explainability.model.PartialDependenceGraph;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;

public class DataUtils {
    private DataUtils() {
    }

    public static double[] generateData(double mean, double stdDeviation, int size, Random random) {
        double[] data = new double[size];
        for (int i = 0; i < size; ++i) {
            data[i] = random.nextGaussian() * stdDeviation + mean;
        }
        double generatedDataMean = DataUtils.getMean(data);
        double generatedDataStdDev = DataUtils.getStdDev(data, generatedDataMean);
        double newStdDeviation = generatedDataStdDev != 0.0 ? stdDeviation / generatedDataStdDev : stdDeviation;
        int i = 0;
        while (i < size) {
            int n = i++;
            data[n] = data[n] * newStdDeviation;
        }
        double newMean = generatedDataStdDev != 0.0 ? generatedDataMean * stdDeviation / generatedDataStdDev : generatedDataMean * stdDeviation;
        int i2 = 0;
        while (i2 < size) {
            int n = i2++;
            data[n] = data[n] + (mean - newMean);
        }
        return data;
    }

    public static double getMean(double[] data) {
        double m = 0.0;
        for (double datum : data) {
            m += datum;
        }
        return m /= (double)data.length;
    }

    public static double getStdDev(double[] data, double mean) {
        double d = 0.0;
        for (double datum : data) {
            d += Math.pow(datum - mean, 2.0);
        }
        d /= (double)data.length;
        d = Math.sqrt(d);
        return d;
    }

    public static double[] generateSamples(double min, double max, int size) {
        double[] data = new double[size];
        double val = min;
        double sum = max / (double)size;
        for (int i = 0; i < size; ++i) {
            data[i] = val;
            val += sum;
        }
        return data;
    }

    public static List<Feature> doublesToFeatures(double[] inputs) {
        return DoubleStream.of(inputs).mapToObj(DataUtils::doubleToFeature).collect(Collectors.toList());
    }

    static Feature doubleToFeature(double d) {
        return FeatureFactory.newNumericalFeature(String.valueOf(d), d);
    }

    public static List<Feature> perturbFeatures(List<Feature> originalFeatures, PerturbationContext perturbationContext) {
        ArrayList<Feature> newFeatures = new ArrayList<Feature>(originalFeatures);
        if (!newFeatures.isEmpty()) {
            int lowerBound = (int)Math.min((double)perturbationContext.getNoOfPerturbations(), 0.5 * (double)newFeatures.size());
            int upperBound = (int)Math.max((double)perturbationContext.getNoOfPerturbations(), 0.5 * (double)newFeatures.size());
            upperBound = Math.min(upperBound, newFeatures.size());
            lowerBound = Math.max(1, lowerBound);
            int perturbationSize = 0;
            if (lowerBound == upperBound) {
                perturbationSize = lowerBound;
            } else if (upperBound > lowerBound) {
                perturbationSize = perturbationContext.getRandom().ints(1L, lowerBound, 1 + upperBound).findFirst().orElse(1);
            }
            if (perturbationSize > 0) {
                int[] indexesToBePerturbed;
                for (int index : indexesToBePerturbed = perturbationContext.getRandom().ints(0, newFeatures.size()).distinct().limit(perturbationSize).toArray()) {
                    Feature feature = (Feature)newFeatures.get(index);
                    Feature perturbedFeature = FeatureFactory.copyOf(feature, feature.getType().perturb(feature.getValue(), perturbationContext));
                    newFeatures.set(index, perturbedFeature);
                }
            }
        }
        return newFeatures;
    }

    public static List<Feature> dropFeature(List<Feature> features, Feature target) {
        ArrayList<Feature> newList = new ArrayList<Feature>(features.size());
        for (Feature sourceFeature : features) {
            Feature f;
            String sourceFeatureName = sourceFeature.getName();
            Type sourceFeatureType = sourceFeature.getType();
            Value sourceFeatureValue = sourceFeature.getValue();
            if (target.getName().equals(sourceFeatureName)) {
                if (target.getType().equals((Object)sourceFeatureType) && target.getValue().equals(sourceFeatureValue)) {
                    Value droppedValue = sourceFeatureType.drop(sourceFeatureValue);
                    f = FeatureFactory.copyOf(sourceFeature, droppedValue);
                } else {
                    f = DataUtils.dropOnLinearizedFeatures(target, sourceFeature);
                }
            } else if (Type.COMPOSITE.equals((Object)sourceFeatureType)) {
                List nestedFeatures = (List)sourceFeatureValue.getUnderlyingObject();
                f = FeatureFactory.newCompositeFeature(sourceFeatureName, DataUtils.dropFeature(nestedFeatures, target));
            } else {
                f = FeatureFactory.copyOf(sourceFeature, sourceFeatureValue);
            }
            newList.add(f);
        }
        return newList;
    }

    protected static Feature dropOnLinearizedFeatures(Feature target, Feature sourceFeature) {
        Feature f = null;
        List<Feature> linearizedFeatures = DataUtils.getLinearizedFeatures(List.of(sourceFeature));
        int i = 0;
        for (Feature linearizedFeature : linearizedFeatures) {
            if (target.getValue().equals(linearizedFeature.getValue())) {
                linearizedFeatures.set(i, FeatureFactory.copyOf(linearizedFeature, linearizedFeature.getType().drop(target.getValue())));
                f = FeatureFactory.newCompositeFeature(target.getName(), linearizedFeatures);
                break;
            }
            ++i;
        }
        if (f == null) {
            f = FeatureFactory.copyOf(sourceFeature, sourceFeature.getValue());
        }
        return f;
    }

    public static double hammingDistance(double[] x, double[] y) {
        if (x.length != y.length) {
            return Double.NaN;
        }
        double h = 0.0;
        for (int i = 0; i < x.length; ++i) {
            if (x[i] == y[i]) continue;
            h += 1.0;
        }
        return h;
    }

    public static double hammingDistance(String x, String y) {
        if (x.length() != y.length()) {
            return Double.NaN;
        }
        double h = 0.0;
        for (int i = 0; i < x.length(); ++i) {
            if (x.charAt(i) == y.charAt(i)) continue;
            h += 1.0;
        }
        return h;
    }

    public static double euclideanDistance(double[] x, double[] y) {
        if (x.length != y.length) {
            return Double.NaN;
        }
        double e = 0.0;
        for (int i = 0; i < x.length; ++i) {
            e += Math.pow(x[i] - y[i], 2.0);
        }
        return Math.sqrt(e);
    }

    public static double gaussianKernel(double x, double mu, double sigma) {
        return Math.exp(-Math.pow((x - mu) / sigma, 2.0) / 2.0) / (sigma * Math.sqrt(Math.PI * 2));
    }

    public static double exponentialSmoothingKernel(double x, double width) {
        return Math.sqrt(Math.exp(-Math.pow(x, 2.0) / Math.pow(width, 2.0)));
    }

    public static DataDistribution generateRandomDataDistribution(int noOfFeatures, int distributionSize, Random random) {
        LinkedList<FeatureDistribution> featureDistributions = new LinkedList<FeatureDistribution>();
        for (int i = 0; i < noOfFeatures; ++i) {
            double[] doubles = DataUtils.generateData(random.nextDouble(), random.nextDouble(), distributionSize, random);
            Feature feature = FeatureFactory.newNumericalFeature("f_" + i, Double.NaN);
            NumericFeatureDistribution featureDistribution = new NumericFeatureDistribution(feature, doubles);
            featureDistributions.add(featureDistribution);
        }
        return new IndependentFeaturesDataDistribution(featureDistributions);
    }

    public static List<PredictionInput> linearizeInputs(List<PredictionInput> predictionInputs) {
        LinkedList<PredictionInput> newInputs = new LinkedList<PredictionInput>();
        for (PredictionInput predictionInput : predictionInputs) {
            List<Feature> originalFeatures = predictionInput.getFeatures();
            List<Feature> flattenedFeatures = DataUtils.getLinearizedFeatures(originalFeatures);
            newInputs.add(new PredictionInput(flattenedFeatures));
        }
        return newInputs;
    }

    public static List<Feature> getLinearizedFeatures(List<Feature> originalFeatures) {
        LinkedList<Feature> flattenedFeatures = new LinkedList<Feature>();
        for (Feature f : originalFeatures) {
            DataUtils.linearizeFeature(flattenedFeatures, f);
        }
        return flattenedFeatures;
    }

    private static void linearizeFeature(List<Feature> flattenedFeatures, Feature f) {
        if (Type.UNDEFINED.equals((Object)f.getType())) {
            if (f.getValue().getUnderlyingObject() instanceof Feature) {
                DataUtils.linearizeFeature(flattenedFeatures, (Feature)f.getValue().getUnderlyingObject());
            } else {
                flattenedFeatures.add(f);
            }
        } else if (Type.COMPOSITE.equals((Object)f.getType())) {
            if (f.getValue().getUnderlyingObject() instanceof List) {
                List features = (List)f.getValue().getUnderlyingObject();
                for (Feature feature : features) {
                    DataUtils.linearizeFeature(flattenedFeatures, feature);
                }
            } else {
                flattenedFeatures.add(f);
            }
        } else {
            flattenedFeatures.add(f);
        }
    }

    public static List<Prediction> getPredictions(List<PredictionInput> inputs, List<PredictionOutput> os) {
        return IntStream.range(0, os.size()).mapToObj(i -> new Prediction((PredictionInput)inputs.get(i), (PredictionOutput)os.get(i))).collect(Collectors.toList());
    }

    public static <T> List<T> sampleWithReplacement(List<T> values, int sampleSize, Random random) {
        if (sampleSize <= 0 || values.isEmpty()) {
            return Collections.emptyList();
        }
        return random.ints(sampleSize, 0, values.size()).mapToObj(values::get).collect(Collectors.toList());
    }

    public static void toCSV(PartialDependenceGraph partialDependenceGraph, Path path) throws IOException {
        try (OutputStream outputStream = Files.newOutputStream(path, new OpenOption[0]);){
            List<Value> xAxis = partialDependenceGraph.getX();
            List<Value> yAxis = partialDependenceGraph.getY();
            outputStream.write("feature,output\n".getBytes(StandardCharsets.UTF_8));
            for (int i = 0; i < xAxis.size(); ++i) {
                String line = xAxis.get(i).asString().replace(",", "") + "," + yAxis.get(i).asString().replace(",", "") + "\n";
                outputStream.write(line.getBytes(StandardCharsets.UTF_8));
            }
            outputStream.flush();
        }
    }
}

