/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.utils.DataUtils;
import org.kie.kogito.explainability.utils.LocalSaliencyStability;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ExplainabilityMetrics {
    private static final Logger LOGGER = LoggerFactory.getLogger(ExplainabilityMetrics.class);
    private static final double CONFIDENCE_DROP_RATIO = 0.2;

    private ExplainabilityMetrics() {
    }

    public static double quantifyExplainability(int inputCognitiveChunks, int outputCognitiveChunks, double interactionRatio) {
        return inputCognitiveChunks + outputCognitiveChunks > 0 ? 0.333 / (double)inputCognitiveChunks + 0.333 / (double)outputCognitiveChunks + 0.333 * (1.0 - interactionRatio) : 0.0;
    }

    public static double impactScore(PredictionProvider model, Prediction prediction, List<FeatureImportance> topFeatures) throws InterruptedException, ExecutionException, TimeoutException {
        List<PredictionOutput> predictionOutputs;
        List<Feature> copy = List.copyOf(prediction.getInput().getFeatures());
        for (FeatureImportance featureImportance : topFeatures) {
            copy = DataUtils.dropFeature(copy, featureImportance.getFeature());
        }
        PredictionInput predictionInput = new PredictionInput(copy);
        try {
            predictionOutputs = model.predictAsync(List.of(predictionInput)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        }
        catch (InterruptedException | ExecutionException | TimeoutException e) {
            LOGGER.error("Impossible to obtain prediction {}", (Object)e.getMessage());
            throw e;
        }
        double impact = 0.0;
        for (PredictionOutput predictionOutput : predictionOutputs) {
            double size = predictionOutput.getOutputs().size();
            int i = 0;
            while ((double)i < size) {
                Output original = prediction.getOutput().getOutputs().get(i);
                Output modified = predictionOutput.getOutputs().get(i);
                impact += !original.getValue().asString().equals(modified.getValue().asString()) || modified.getScore() < original.getScore() * 0.2 ? 1.0 / size : 0.0;
                ++i;
            }
        }
        return impact;
    }

    public static double classificationFidelity(List<Pair<Saliency, Prediction>> pairs) {
        double acc = 0.0;
        double evals = 0.0;
        for (Pair<Saliency, Prediction> pair : pairs) {
            Saliency saliency = (Saliency)pair.getLeft();
            Prediction prediction = (Prediction)pair.getRight();
            for (Output output : prediction.getOutput().getOutputs()) {
                Type type = output.getType();
                if (!Type.BOOLEAN.equals((Object)type)) continue;
                double predictorOutput = saliency.getPerFeatureImportance().stream().map(FeatureImportance::getScore).mapToDouble(d -> d).sum();
                double v = output.getValue().asNumber();
                if (v >= 0.0 && predictorOutput >= 0.0 || v < 0.0 && predictorOutput < 0.0) {
                    acc += 1.0;
                }
                evals += 1.0;
            }
        }
        return evals == 0.0 ? 0.0 : acc / evals;
    }

    public static LocalSaliencyStability getLocalSaliencyStability(PredictionProvider model, Prediction prediction, LocalExplainer<Map<String, Saliency>> saliencyLocalExplainer, int topK, int runs) throws InterruptedException, ExecutionException, TimeoutException {
        Map<String, List<Saliency>> saliencies = ExplainabilityMetrics.getMultipleSaliencies(model, prediction, saliencyLocalExplainer, runs);
        LocalSaliencyStability saliencyStability = new LocalSaliencyStability(saliencies.keySet());
        for (Map.Entry<String, List<Saliency>> entry : saliencies.entrySet()) {
            for (int k = 1; k <= topK; ++k) {
                String decision = entry.getKey();
                List<Saliency> perDecisionSaliencies = entry.getValue();
                int finalK = k;
                Map<List<String>, Long> topKPositive = ExplainabilityMetrics.getTopKFeaturesFrequency(perDecisionSaliencies, s -> s.getPositiveFeatures(finalK));
                Pair<List<String>, Long> positiveMostFrequent = ExplainabilityMetrics.getMostFrequent(topKPositive);
                double positiveFrequencyRate = (double)((Long)positiveMostFrequent.getValue()).longValue() / (double)perDecisionSaliencies.size();
                Map<List<String>, Long> topKNegative = ExplainabilityMetrics.getTopKFeaturesFrequency(perDecisionSaliencies, s -> s.getNegativeFeatures(finalK));
                Pair<List<String>, Long> negativeMostFrequent = ExplainabilityMetrics.getMostFrequent(topKNegative);
                double negativeFrequencyRate = (double)((Long)negativeMostFrequent.getValue()).longValue() / (double)perDecisionSaliencies.size();
                List positiveFeatureNames = (List)positiveMostFrequent.getKey();
                List negativeFeatureNames = (List)negativeMostFrequent.getKey();
                saliencyStability.add(decision, k, positiveFeatureNames, positiveFrequencyRate, negativeFeatureNames, negativeFrequencyRate);
            }
        }
        return saliencyStability;
    }

    private static Map<String, List<Saliency>> getMultipleSaliencies(PredictionProvider model, Prediction prediction, LocalExplainer<Map<String, Saliency>> saliencyLocalExplainer, int runs) throws InterruptedException, ExecutionException, TimeoutException {
        HashMap<String, List<Saliency>> saliencies = new HashMap<String, List<Saliency>>();
        int skipped = 0;
        for (int i = 0; i < runs; ++i) {
            Map<String, Saliency> saliencyMap = saliencyLocalExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            for (Map.Entry<String, Saliency> saliencyEntry : saliencyMap.entrySet()) {
                List<FeatureImportance> topFeatures = saliencyEntry.getValue().getTopFeatures(1);
                if (!topFeatures.isEmpty() && topFeatures.get(0).getScore() != 0.0) {
                    if (saliencies.containsKey(saliencyEntry.getKey())) {
                        List localSaliencies = (List)saliencies.get(saliencyEntry.getKey());
                        ArrayList<Saliency> updatedSaliencies = new ArrayList<Saliency>(localSaliencies);
                        updatedSaliencies.add(saliencyEntry.getValue());
                        saliencies.put(saliencyEntry.getKey(), updatedSaliencies);
                        continue;
                    }
                    saliencies.put(saliencyEntry.getKey(), List.of(saliencyEntry.getValue()));
                    continue;
                }
                LOGGER.debug("skipping empty / zero saliency for {}", (Object)saliencyEntry.getKey());
                ++skipped;
            }
        }
        LOGGER.debug("skipped {} useless saliencies", (Object)skipped);
        return saliencies;
    }

    private static Map<List<String>, Long> getTopKFeaturesFrequency(List<Saliency> saliencies, Function<Saliency, List<FeatureImportance>> saliencyListFunction) {
        return saliencies.stream().map(saliencyListFunction).map(l -> l.stream().map(f -> f.getFeature().getName()).collect(Collectors.toList())).collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
    }

    private static Pair<List<String>, Long> getMostFrequent(Map<List<String>, Long> collect) {
        Map.Entry maxEntry = Collections.max(collect.entrySet(), Map.Entry.comparingByValue());
        return Pair.of((Object)((List)maxEntry.getKey()), (Object)((Long)maxEntry.getValue()));
    }
}

