/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.apache.commons.lang3.tuple.Pair;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.utils.DataUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ExplainabilityMetrics {
    private static final Logger LOGGER = LoggerFactory.getLogger(ExplainabilityMetrics.class);

    /**
     * Fraction of an original output's score below which the corresponding re-computed output
     * is considered to have lost its confidence (used by {@link #impactScore}).
     */
    private static final double CONFIDENCE_DROP_RATIO = 0.2;

    /** Weight applied to each of the three terms summed by {@link #quantifyExplainability}. */
    private static final double TERM_WEIGHT = 0.333;

    private ExplainabilityMetrics() {
        // static utility class: not meant to be instantiated
    }

    /**
     * Computes a weighted explainability score from cognitive-chunk counts and an interaction ratio.
     *
     * <p>The score is {@code 0.333 / inputCognitiveChunks + 0.333 / outputCognitiveChunks
     * + 0.333 * (1 - interactionRatio)}, so fewer chunks and a lower interaction ratio give a higher score.
     *
     * @param inputCognitiveChunks number of input cognitive chunks; must be positive to get a non-zero score
     * @param outputCognitiveChunks number of output cognitive chunks; must be positive to get a non-zero score
     * @param interactionRatio interaction ratio, subtracted from 1 in the third term
     * @return the weighted score, or {@code 0.0} when either chunk count is not strictly positive
     */
    public static double quantifyExplainability(int inputCognitiveChunks, int outputCognitiveChunks, double interactionRatio) {
        // Each count is checked on its own. The previous "input + output > 0" guard still allowed
        // a single zero count (division by zero -> Infinity), and the int sum could overflow.
        if (inputCognitiveChunks <= 0 || outputCognitiveChunks <= 0) {
            return 0.0;
        }
        return TERM_WEIGHT / inputCognitiveChunks
                + TERM_WEIGHT / outputCognitiveChunks
                + TERM_WEIGHT * (1.0 - interactionRatio);
    }

    /**
     * Measures how much the model's prediction is affected by removing the given top features.
     *
     * <p>Each feature in {@code topFeatures} is passed to {@code DataUtils.dropFeature} on a copy of the
     * prediction's input features. The model is then re-run asynchronously, with the timeout configured in
     * {@code Config.INSTANCE}. Each re-computed output adds {@code 1 / n} to the score, where {@code n} is
     * the number of outputs in its {@code PredictionOutput}, when either of these holds:
     * <ul>
     *   <li>its value (compared as a string) differs from the original output at the same index, or</li>
     *   <li>its score falls below {@value #CONFIDENCE_DROP_RATIO} times the original score.</li>
     * </ul>
     *
     * @param model the model to re-query
     * @param prediction the original prediction, whose input and outputs are the baseline
     * @param topFeatures the features to drop from the input
     * @return the accumulated impact over all returned prediction outputs
     * @throws InterruptedException if interrupted while waiting for the model
     * @throws ExecutionException if the model's asynchronous computation failed
     * @throws TimeoutException if the model did not answer within the configured timeout
     */
    public static double impactScore(PredictionProvider model, Prediction prediction, List<FeatureImportance> topFeatures)
            throws InterruptedException, ExecutionException, TimeoutException {
        List<Feature> copy = List.copyOf(prediction.getInput().getFeatures());
        for (FeatureImportance featureImportance : topFeatures) {
            copy = DataUtils.dropFeature(copy, featureImportance.getFeature());
        }
        PredictionInput predictionInput = new PredictionInput(copy);

        List<PredictionOutput> predictionOutputs;
        try {
            predictionOutputs = model.predictAsync(List.of(predictionInput))
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            // The trailing Throwable argument makes SLF4J log the stack trace as well as the message.
            LOGGER.error("Impossible to obtain prediction {}", e.getMessage(), e);
            throw e;
        }

        List<Output> originalOutputs = prediction.getOutput().getOutputs();
        double impact = 0.0;
        for (PredictionOutput predictionOutput : predictionOutputs) {
            List<Output> modifiedOutputs = predictionOutput.getOutputs();
            double size = modifiedOutputs.size();
            // NOTE(review): assumes the model returns no more outputs than the original prediction had;
            // otherwise originalOutputs.get(i) throws IndexOutOfBoundsException.
            for (int i = 0; i < modifiedOutputs.size(); i++) {
                if (isImpacted(originalOutputs.get(i), modifiedOutputs.get(i))) {
                    impact += 1.0 / size;
                }
            }
        }
        return impact;
    }

    /** Returns whether {@code modified} changed value, or lost confidence, relative to {@code original}. */
    private static boolean isImpacted(Output original, Output modified) {
        return !original.getValue().asString().equals(modified.getValue().asString())
                || modified.getScore() < original.getScore() * CONFIDENCE_DROP_RATIO;
    }

    /**
     * Computes the fraction of BOOLEAN outputs whose numeric value has the same sign as the sum of the
     * corresponding saliency's per-feature importance scores.
     *
     * <p>"Same sign" means both values are {@code >= 0} or both are {@code < 0}. A NaN on either side never
     * agrees, but it is still counted as an evaluation. Outputs of any other type are ignored.
     *
     * <p>NOTE(review): if {@code Value.asNumber()} maps booleans to 0/1, the output value is never
     * negative, so only a non-negative importance sum counts as agreement. Confirm this is intended.
     *
     * @param pairs saliency/prediction pairs to evaluate
     * @return agreements divided by evaluated BOOLEAN outputs, or {@code 0.0} when there are none
     */
    public static double classificationFidelity(List<Pair<Saliency, Prediction>> pairs) {
        double acc = 0.0;
        double evals = 0.0;
        for (Pair<Saliency, Prediction> pair : pairs) {
            List<Output> outputs = pair.getRight().getOutput().getOutputs();
            if (outputs.stream().noneMatch(ExplainabilityMetrics::isBoolean)) {
                // Nothing to evaluate: skip the saliency computation, as the original code did.
                continue;
            }
            // The importance sum depends only on the saliency, so compute it once per pair.
            double predictorOutput = pair.getLeft().getPerFeatureImportance().stream()
                    .mapToDouble(FeatureImportance::getScore)
                    .sum();
            for (Output output : outputs) {
                if (!isBoolean(output)) {
                    continue;
                }
                if (sameSign(output.getValue().asNumber(), predictorOutput)) {
                    acc += 1.0;
                }
                evals += 1.0;
            }
        }
        return evals == 0.0 ? 0.0 : acc / evals;
    }

    /** Returns whether the output is of type {@code Type.BOOLEAN}. */
    private static boolean isBoolean(Output output) {
        return Type.BOOLEAN.equals(output.getType());
    }

    /** Returns true when both values are non-negative or both are negative; false if either is NaN. */
    private static boolean sameSign(double a, double b) {
        return (a >= 0.0 && b >= 0.0) || (a < 0.0 && b < 0.0);
    }
}

