/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.utils.DataUtils;

/**
 * Static helpers computing metrics about explanations of predictions.
 *
 * <p>This is a utility class and cannot be instantiated.
 */
public class ExplainabilityMetrics {

    /**
     * Fraction of an output's original score below which the score obtained after dropping the important
     * features is treated as a significant confidence drop (used by {@link #impactScore}).
     */
    private static final double CONFIDENCE_DROP_RATIO = 0.2;

    /** Weight applied to each of the three terms combined by {@link #quantifyExplainability}. */
    private static final double COMPONENT_WEIGHT = 0.333;

    private ExplainabilityMetrics() {
    }

    /**
     * Combines the number of input and output cognitive chunks and an interaction ratio into a single score:
     * {@code 0.333 / inputCognitiveChunks + 0.333 / outputCognitiveChunks + 0.333 * (1 - interactionRatio)}.
     *
     * <p>Fewer chunks and a lower interaction ratio yield a higher score.
     *
     * @param inputCognitiveChunks  number of input cognitive chunks; must be positive for a non-zero score
     * @param outputCognitiveChunks number of output cognitive chunks; must be positive for a non-zero score
     * @param interactionRatio      interaction ratio (presumably in [0, 1] — TODO confirm with callers)
     * @return the combined score, or {@code 0.0} when either chunk count is not positive
     */
    public static double quantifyExplainability(int inputCognitiveChunks, int outputCognitiveChunks, double interactionRatio) {
        // Each count is used as a divisor, so both must be strictly positive; checking only their sum
        // would let a single zero count produce an infinite score.
        if (inputCognitiveChunks <= 0 || outputCognitiveChunks <= 0) {
            return 0.0;
        }
        return COMPONENT_WEIGHT / inputCognitiveChunks
                + COMPONENT_WEIGHT / outputCognitiveChunks
                + COMPONENT_WEIGHT * (1.0 - interactionRatio);
    }

    /**
     * Measures how much the model's outputs change when the given top features are dropped from the input.
     *
     * <p>Every feature of {@code prediction}'s input is passed through {@link DataUtils#dropFeature} together
     * with the names of {@code topFeatures}, and {@code model} is queried once with the resulting input. The
     * new outputs are then compared position by position with the outputs of {@code prediction}. An output
     * counts as impacted when its value (compared via {@code asString()}) differs, or when its score falls
     * below {@link #CONFIDENCE_DROP_RATIO} times the original score.
     *
     * @param model       the model to query with the modified input
     * @param prediction  the original prediction (input and outputs) to compare against
     * @param topFeatures the important features whose names are passed to {@link DataUtils#dropFeature}
     * @return the fraction of compared outputs that were impacted, or {@code 0.0} when there is nothing
     *         to compare
     */
    public static double impactScore(PredictionProvider model, Prediction prediction, List<FeatureImportance> topFeatures) {
        List<String> importantFeatureNames = topFeatures.stream()
                .map(f -> f.getFeature().getName())
                .collect(Collectors.toList());
        List<Feature> newFeatures = new ArrayList<>();
        for (Feature feature : prediction.getInput().getFeatures()) {
            newFeatures.add(DataUtils.dropFeature(feature, importantFeatureNames));
        }
        PredictionInput predictionInput = new PredictionInput(newFeatures);
        List<PredictionOutput> predictionOutputs = model.predict(List.of(predictionInput));
        PredictionOutput predictionOutput = predictionOutputs.get(0);

        List<Output> originalOutputs = prediction.getOutput().getOutputs();
        List<Output> modifiedOutputs = predictionOutput.getOutputs();
        // Outputs are matched by position. Only positions present in both lists can be compared; iterating
        // over the modified outputs alone would overrun the original list if the model returned more outputs.
        int compared = Math.min(originalOutputs.size(), modifiedOutputs.size());
        if (compared == 0) {
            // Nothing to compare: avoid returning 0.0 / 0.0 = NaN.
            return 0.0;
        }
        int impacted = 0;
        for (int i = 0; i < compared; i++) {
            if (isImpacted(originalOutputs.get(i), modifiedOutputs.get(i))) {
                impacted++;
            }
        }
        return (double) impacted / compared;
    }

    /**
     * Returns whether {@code modified} has a different value than {@code original}, or a score lower than
     * {@link #CONFIDENCE_DROP_RATIO} times the original score.
     */
    private static boolean isImpacted(Output original, Output modified) {
        boolean valueChanged = !original.getValue().asString().equals(modified.getValue().asString());
        return valueChanged || modified.getScore() < original.getScore() * CONFIDENCE_DROP_RATIO;
    }

    /**
     * Computes the fraction of boolean outputs whose sign agrees with the sum of the per-feature importance
     * scores of the paired saliency.
     *
     * <p>Only outputs of type {@link Type#BOOLEAN} are evaluated. For each of them, the output value (via
     * {@code asNumber()}) and the saliency's importance-score sum count as agreeing when both are {@code >= 0}
     * or both are {@code < 0}.
     *
     * @param pairs saliency/prediction pairs to evaluate
     * @return the fraction of agreeing boolean outputs, or {@code 0.0} when no boolean output was found
     */
    public static double classificationFidelity(List<Pair<Saliency, Prediction>> pairs) {
        int agreements = 0;
        int evaluations = 0;
        for (Pair<Saliency, Prediction> pair : pairs) {
            Saliency saliency = pair.getLeft();
            List<Output> booleanOutputs = pair.getRight().getOutput().getOutputs().stream()
                    .filter(output -> Type.BOOLEAN.equals(output.getType()))
                    .collect(Collectors.toList());
            if (booleanOutputs.isEmpty()) {
                continue;
            }
            // The importance sum does not depend on the output, so compute it once per pair (and only
            // when the pair actually has boolean outputs to evaluate).
            double predictorOutput = saliency.getPerFeatureImportance().stream()
                    .mapToDouble(FeatureImportance::getScore)
                    .sum();
            for (Output output : booleanOutputs) {
                double v = output.getValue().asNumber();
                // NOTE(review): 0 is used as the sign threshold for a boolean value. If asNumber() maps booleans
                // to 1.0/0.0, v >= 0 always holds and only predictorOutput's sign matters — confirm intent.
                if ((v >= 0.0 && predictorOutput >= 0.0) || (v < 0.0 && predictorOutput < 0.0)) {
                    agreements++;
                }
                evaluations++;
            }
        }
        return evaluations == 0 ? 0.0 : (double) agreements / evaluations;
    }
}

