/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.utils.DataUtils;
import org.kie.kogito.explainability.utils.LocalSaliencyStability;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static utility methods computing metrics that evaluate predictions, saliency explanations and local explainers.
 */
public class ExplainabilityMetrics {
    private static final Logger LOGGER = LoggerFactory.getLogger(ExplainabilityMetrics.class);

    /**
     * In {@link #impactScore}, an output whose score falls below this fraction of the original score counts as impacted.
     */
    private static final double CONFIDENCE_DROP_RATIO = 0.2;

    private ExplainabilityMetrics() {
    }

    /**
     * Quantifies the explainability of an explanation as the equally weighted ({@code 0.333}) sum of
     * {@code 1 / inputCognitiveChunks}, {@code 1 / outputCognitiveChunks} and {@code 1 - interactionRatio}.
     *
     * @param inputCognitiveChunks number of input cognitive chunks
     * @param outputCognitiveChunks number of output cognitive chunks
     * @param interactionRatio interaction ratio
     * @return the explainability score, or {@code 0} when either chunk count is not positive
     */
    public static double quantifyExplainability(int inputCognitiveChunks, int outputCognitiveChunks, double interactionRatio) {
        // Both counts are used as divisors, so each must be positive on its own; guarding only on their
        // sum let e.g. (0, n) divide by zero and return Infinity.
        if (inputCognitiveChunks <= 0 || outputCognitiveChunks <= 0) {
            return 0.0;
        }
        return 0.333 / inputCognitiveChunks + 0.333 / outputCognitiveChunks + 0.333 * (1.0 - interactionRatio);
    }

    /**
     * Measures how much the given features matter to a prediction by dropping them from its input and re-running
     * the model. Each output of the new prediction contributes {@code 1 / numberOfOutputs} to the impact when its
     * value (compared as a string) differs from the original one, or when its score drops below
     * {@link #CONFIDENCE_DROP_RATIO} times the original score.
     *
     * @param model the model to query
     * @param prediction the original prediction
     * @param topFeatures the features to drop from the prediction input
     * @return the impact score
     * @throws IllegalStateException if the model prediction fails, times out or the thread is interrupted
     */
    public static double impactScore(PredictionProvider model, Prediction prediction, List<FeatureImportance> topFeatures) throws InterruptedException, ExecutionException, TimeoutException {
        List<Feature> remainingFeatures = List.copyOf(prediction.getInput().getFeatures());
        for (FeatureImportance featureImportance : topFeatures) {
            remainingFeatures = DataUtils.dropFeature(remainingFeatures, featureImportance.getFeature());
        }
        PredictionInput predictionInput = new PredictionInput(remainingFeatures);

        List<PredictionOutput> predictionOutputs;
        try {
            predictionOutputs = model.predictAsync(List.of(predictionInput))
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        } catch (ExecutionException | TimeoutException e) {
            LOGGER.error("Impossible to obtain prediction {}", e.getMessage());
            throw new IllegalStateException("Impossible to obtain prediction", e);
        } catch (InterruptedException e) {
            // restore the interrupt flag so callers up the stack can still observe it
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Impossible to obtain prediction (Thread interrupted)", e);
        }

        List<Output> originalOutputs = prediction.getOutput().getOutputs();
        double impact = 0.0;
        for (PredictionOutput predictionOutput : predictionOutputs) {
            List<Output> modifiedOutputs = predictionOutput.getOutputs();
            double size = modifiedOutputs.size();
            // NOTE(review): assumes the new prediction has no more outputs than the original one, paired by index.
            for (int i = 0; i < modifiedOutputs.size(); i++) {
                Output original = originalOutputs.get(i);
                Output modified = modifiedOutputs.get(i);
                boolean valueChanged = !original.getValue().asString().equals(modified.getValue().asString());
                boolean confidenceDropped = modified.getScore() < original.getScore() * CONFIDENCE_DROP_RATIO;
                if (valueChanged || confidenceDropped) {
                    impact += 1.0 / size;
                }
            }
        }
        return impact;
    }

    /**
     * Computes the fraction of boolean outputs whose value has the same sign as the sum of all feature importance
     * scores in the paired saliency. Non-boolean outputs are ignored.
     *
     * @param pairs saliencies paired with the predictions they explain
     * @return the agreement ratio, or {@code 0} when no boolean output was evaluated
     */
    public static double classificationFidelity(List<Pair<Saliency, Prediction>> pairs) {
        double acc = 0.0;
        double evals = 0.0;
        for (Pair<Saliency, Prediction> pair : pairs) {
            Saliency saliency = pair.getLeft();
            Prediction prediction = pair.getRight();
            // the summed importance does not depend on the output, so compute it once per pair
            double predictorOutput = saliency.getPerFeatureImportance().stream()
                    .mapToDouble(FeatureImportance::getScore)
                    .sum();
            for (Output output : prediction.getOutput().getOutputs()) {
                if (!Type.BOOLEAN.equals(output.getType())) {
                    continue;
                }
                // NOTE(review): sign test on asNumber() of a boolean value; confirm the numeric encoding of booleans.
                double v = output.getValue().asNumber();
                boolean sameSign = (v >= 0.0 && predictorOutput >= 0.0) || (v < 0.0 && predictorOutput < 0.0);
                if (sameSign) {
                    acc += 1.0;
                }
                evals += 1.0;
            }
        }
        return evals == 0.0 ? 0.0 : acc / evals;
    }

    /**
     * Evaluates the stability of a local saliency explainer. The explainer is run {@code runs} times on the same
     * prediction. For each output (decision) and each {@code k} in {@code 1..topK}, the method records the most
     * frequent list of top-k positive and top-k negative feature names. Each list is recorded with the fraction of
     * retained saliencies in which it occurred. Saliencies with no top feature, or with a zero-scored top feature,
     * are not retained.
     *
     * @param model the model to explain
     * @param prediction the prediction to explain
     * @param saliencyLocalExplainer the explainer under evaluation
     * @param topK the largest number of top features to consider
     * @param runs how many times the explainer is run
     * @return the collected stability information
     */
    public static LocalSaliencyStability getLocalSaliencyStability(PredictionProvider model, Prediction prediction, LocalExplainer<Map<String, Saliency>> saliencyLocalExplainer, int topK, int runs) throws InterruptedException, ExecutionException, TimeoutException {
        Map<String, List<Saliency>> saliencies = getMultipleSaliencies(model, prediction, saliencyLocalExplainer, runs);
        LocalSaliencyStability saliencyStability = new LocalSaliencyStability(saliencies.keySet());
        for (Map.Entry<String, List<Saliency>> entry : saliencies.entrySet()) {
            String decision = entry.getKey();
            List<Saliency> perDecisionSaliencies = entry.getValue();
            double retainedCount = perDecisionSaliencies.size();
            for (int k = 1; k <= topK; k++) {
                int finalK = k;
                Pair<List<String>, Long> positiveMostFrequent = getMostFrequent(
                        getTopKFeaturesFrequency(perDecisionSaliencies, s -> s.getPositiveFeatures(finalK)));
                Pair<List<String>, Long> negativeMostFrequent = getMostFrequent(
                        getTopKFeaturesFrequency(perDecisionSaliencies, s -> s.getNegativeFeatures(finalK)));
                double positiveFrequencyRate = positiveMostFrequent.getValue() / retainedCount;
                double negativeFrequencyRate = negativeMostFrequent.getValue() / retainedCount;
                saliencyStability.add(decision, k, positiveMostFrequent.getKey(), positiveFrequencyRate,
                        negativeMostFrequent.getKey(), negativeFrequencyRate);
            }
        }
        return saliencyStability;
    }

    /**
     * Runs the explainer {@code runs} times and groups the resulting saliencies by output name. A saliency is
     * skipped when it has no top feature or when its top feature has a zero score.
     */
    private static Map<String, List<Saliency>> getMultipleSaliencies(PredictionProvider model, Prediction prediction, LocalExplainer<Map<String, Saliency>> saliencyLocalExplainer, int runs) throws InterruptedException, ExecutionException, TimeoutException {
        Map<String, List<Saliency>> saliencies = new HashMap<>();
        int skipped = 0;
        for (int i = 0; i < runs; i++) {
            Map<String, Saliency> saliencyMap = saliencyLocalExplainer.explainAsync(prediction, model)
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            for (Map.Entry<String, Saliency> saliencyEntry : saliencyMap.entrySet()) {
                List<FeatureImportance> topFeatures = saliencyEntry.getValue().getTopFeatures(1);
                if (topFeatures.isEmpty() || topFeatures.get(0).getScore() == 0.0) {
                    LOGGER.debug("skipping empty / zero saliency for {}", saliencyEntry.getKey());
                    skipped++;
                    continue;
                }
                // append in place instead of copying the whole list on every addition
                saliencies.computeIfAbsent(saliencyEntry.getKey(), key -> new ArrayList<>())
                        .add(saliencyEntry.getValue());
            }
        }
        LOGGER.debug("skipped {} useless saliencies", skipped);
        return saliencies;
    }

    /**
     * Counts how often each list of feature names occurs. Each list is extracted from a saliency via
     * {@code saliencyListFunction}, so two lists are equal only when they contain the same names in the same order.
     */
    private static Map<List<String>, Long> getTopKFeaturesFrequency(List<Saliency> saliencies, Function<Saliency, List<FeatureImportance>> saliencyListFunction) {
        return saliencies.stream()
                .map(saliencyListFunction)
                .map(features -> features.stream()
                        .map(f -> f.getFeature().getName())
                        .collect(Collectors.toList()))
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
    }

    /**
     * Returns the entry with the highest count. Ties are broken by the map's iteration order. The map must not be
     * empty, otherwise {@link Collections#max} throws.
     */
    private static Pair<List<String>, Long> getMostFrequent(Map<List<String>, Long> frequencies) {
        Map.Entry<List<String>, Long> maxEntry = Collections.max(frequencies.entrySet(), Map.Entry.comparingByValue());
        return Pair.of(maxEntry.getKey(), maxEntry.getValue());
    }

    /**
     * Evaluates the recall of a local saliency explainer. Predictions over all samples of
     * {@code dataDistribution} are sorted by the score of {@code outputName}, highest first. Each of the
     * {@code chunkSize} highest-scored predictions is explained, and its {@code k} highest-scored features are
     * written into the input of the next lowest-scored prediction. The model is then re-run on that masked input.
     * The result is a true positive when the masked prediction yields the same {@code outputName} value as the
     * high-scored prediction, and a false negative otherwise.
     *
     * @return {@code truePositives / (truePositives + falseNegatives)}, or {@code NaN} when nothing was evaluated
     * @throws IndexOutOfBoundsException if {@code chunkSize} is negative or exceeds the number of samples
     */
    public static double getLocalSaliencyRecall(String outputName, PredictionProvider predictionProvider, LocalExplainer<Map<String, Saliency>> localExplainer, DataDistribution dataDistribution, int k, int chunkSize) throws InterruptedException, ExecutionException, TimeoutException {
        List<Prediction> sorted = getScoreSortedPredictions(outputName, predictionProvider, dataDistribution);
        List<Prediction> topChunk = new ArrayList<>(sorted.subList(0, chunkSize));
        List<Prediction> bottomChunk = new ArrayList<>(sorted.subList(sorted.size() - chunkSize, sorted.size()));
        double truePositives = 0.0;
        double falseNegatives = 0.0;
        int currentChunk = 0;
        for (Prediction prediction : topChunk) {
            Optional<Output> optionalOutput = prediction.getOutput().getByName(outputName);
            if (!optionalOutput.isPresent()) {
                continue;
            }
            Output output = optionalOutput.get();
            Map<String, Saliency> stringSaliencyMap = localExplainer.explainAsync(prediction, predictionProvider)
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            if (!stringSaliencyMap.containsKey(outputName)) {
                continue;
            }
            Saliency saliency = stringSaliencyMap.get(outputName);
            // the k features with the highest importance score (descending order)
            List<FeatureImportance> topFeatures = saliency.getPerFeatureImportance().stream()
                    .sorted((f1, f2) -> Double.compare(f2.getScore(), f1.getScore()))
                    .limit(k)
                    .collect(Collectors.toList());
            Optional<Output> optionalNewOutput = predictMaskedOutput(predictionProvider, topFeatures,
                    bottomChunk.get(currentChunk).getInput(), outputName);
            if (optionalNewOutput.isPresent()) {
                // compare against the output produced by the masked input, not the original output itself
                Output newOutput = optionalNewOutput.get();
                if (output.getValue().equals(newOutput.getValue())) {
                    truePositives += 1.0;
                } else {
                    falseNegatives += 1.0;
                }
            }
            currentChunk++;
        }
        if (truePositives + falseNegatives > 0.0) {
            return truePositives / (truePositives + falseNegatives);
        }
        return Double.NaN;
    }

    /**
     * Builds a copy of {@code input} in which the features of {@code topFeatures} replace the matching features.
     */
    private static PredictionInput maskInput(List<FeatureImportance> topFeatures, PredictionInput input) {
        List<Feature> importantFeatures = topFeatures.stream()
                .map(FeatureImportance::getFeature)
                .collect(Collectors.toList());
        return replaceAllFeatures(importantFeatures, input);
    }

    /**
     * Masks {@code input} with {@code features}, runs the model on the masked input, and returns the output named
     * {@code outputName} of the first resulting prediction, if any.
     */
    private static Optional<Output> predictMaskedOutput(PredictionProvider predictionProvider, List<FeatureImportance> features, PredictionInput input, String outputName) throws InterruptedException, ExecutionException, TimeoutException {
        PredictionInput maskedInput = maskInput(features, input);
        List<PredictionOutput> predictionOutputList = predictionProvider.predictAsync(List.of(maskedInput))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        if (predictionOutputList.isEmpty()) {
            return Optional.empty();
        }
        return predictionOutputList.get(0).getByName(outputName);
    }

    /**
     * Predicts every sample of {@code dataDistribution} and sorts the predictions by the score of
     * {@code outputName}, highest first. Predictions lacking that output are placed last.
     */
    private static List<Prediction> getScoreSortedPredictions(String outputName, PredictionProvider predictionProvider, DataDistribution dataDistribution) throws InterruptedException, ExecutionException, TimeoutException {
        List<PredictionInput> inputs = dataDistribution.getAllSamples();
        List<PredictionOutput> predictionOutputs = predictionProvider.predictAsync(inputs)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        List<Prediction> predictions = DataUtils.getPredictions(inputs, predictionOutputs);
        // A comparator returning 0 whenever an output is missing is not transitive, and sorting with it may throw
        // IllegalArgumentException; ordering missing outputs last keeps the comparator consistent.
        Comparator<Prediction> byScoreDescending = Comparator.comparing(
                (Prediction p) -> p.getOutput().getByName(outputName).map(Output::getScore).orElse(null),
                Comparator.nullsLast(Comparator.<Double>reverseOrder()));
        return predictions.stream()
                .sorted(byScoreDescending)
                .collect(Collectors.toList());
    }

    /**
     * Evaluates the precision of a local saliency explainer. Predictions over all samples of
     * {@code dataDistribution} are sorted by the score of {@code outputName}, highest first. Each of the
     * {@code chunkSize} lowest-scored predictions is explained, and its {@code k} lowest-scored features (ascending
     * order) are written into the input of the next highest-scored prediction. The model is then re-run on that
     * masked input. The result is a true positive when the masked prediction yields the same {@code outputName}
     * value as the high-scored prediction, and a false positive otherwise.
     *
     * @return {@code truePositives / (truePositives + falsePositives)}, or {@code NaN} when nothing was evaluated
     * @throws IndexOutOfBoundsException if {@code chunkSize} is negative or exceeds the number of samples
     */
    public static double getLocalSaliencyPrecision(String outputName, PredictionProvider predictionProvider, LocalExplainer<Map<String, Saliency>> localExplainer, DataDistribution dataDistribution, int k, int chunkSize) throws InterruptedException, ExecutionException, TimeoutException {
        List<Prediction> sorted = getScoreSortedPredictions(outputName, predictionProvider, dataDistribution);
        List<Prediction> topChunk = new ArrayList<>(sorted.subList(0, chunkSize));
        List<Prediction> bottomChunk = new ArrayList<>(sorted.subList(sorted.size() - chunkSize, sorted.size()));
        double truePositives = 0.0;
        double falsePositives = 0.0;
        int currentChunk = 0;
        for (Prediction prediction : bottomChunk) {
            Map<String, Saliency> stringSaliencyMap = localExplainer.explainAsync(prediction, predictionProvider)
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            if (!stringSaliencyMap.containsKey(outputName)) {
                continue;
            }
            Saliency saliency = stringSaliencyMap.get(outputName);
            // NOTE(review): ascending sort keeps the k *lowest*-scored features; confirm this is intended.
            List<FeatureImportance> topFeatures = saliency.getPerFeatureImportance().stream()
                    .sorted(Comparator.comparingDouble(FeatureImportance::getScore))
                    .limit(k)
                    .collect(Collectors.toList());
            Prediction topPrediction = topChunk.get(currentChunk);
            Optional<Output> newOptionalOutput = predictMaskedOutput(predictionProvider, topFeatures,
                    topPrediction.getInput(), outputName);
            Optional<Output> optionalOutput = topPrediction.getOutput().getByName(outputName);
            if (newOptionalOutput.isPresent() && optionalOutput.isPresent()) {
                if (optionalOutput.get().getValue().equals(newOptionalOutput.get().getValue())) {
                    truePositives += 1.0;
                } else {
                    falsePositives += 1.0;
                }
            }
            currentChunk++;
        }
        if (truePositives + falsePositives > 0.0) {
            return truePositives / (truePositives + falsePositives);
        }
        return Double.NaN;
    }

    /**
     * Computes the F1 score (harmonic mean) of {@link #getLocalSaliencyPrecision} and {@link #getLocalSaliencyRecall}.
     *
     * @return the F1 score, or {@code NaN} when either metric is not finite or both are zero
     */
    public static double getLocalSaliencyF1(String outputName, PredictionProvider predictionProvider, LocalExplainer<Map<String, Saliency>> localExplainer, DataDistribution dataDistribution, int k, int chunkSize) throws InterruptedException, ExecutionException, TimeoutException {
        double precision = getLocalSaliencyPrecision(outputName, predictionProvider, localExplainer, dataDistribution, k, chunkSize);
        double recall = getLocalSaliencyRecall(outputName, predictionProvider, localExplainer, dataDistribution, k, chunkSize);
        double sum = precision + recall;
        if (Double.isFinite(sum) && sum > 0.0) {
            return 2.0 * precision * recall / sum;
        }
        return Double.NaN;
    }

    /**
     * Returns a new input built by applying {@code DataUtils.replaceFeatures} for each of {@code importantFeatures},
     * in order, to a copy of the features of {@code input}.
     */
    private static PredictionInput replaceAllFeatures(List<Feature> importantFeatures, PredictionInput input) {
        List<Feature> features = List.copyOf(input.getFeatures());
        for (Feature importantFeature : importantFeatures) {
            // presumably swaps the same-named feature in `features` for `importantFeature` — see DataUtils
            features = DataUtils.replaceFeatures(importantFeature, features);
        }
        return new PredictionInput(features);
    }
}

