/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.lime;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;
import org.kie.kogito.explainability.utils.LocalSaliencyStability;

/**
 * Stability tests for {@link LimeExplainer}.
 *
 * <p>Repeatedly explaining the same prediction is expected to agree, in at least
 * {@link #TOP_FEATURE_THRESHOLD} of the runs, on both the top positive feature and the impact
 * score of the top two features. Two explainers configured with the same seed are expected to
 * report identical local saliency stability scores.
 */
class LimeStabilityTest {

    /** Minimum fraction of runs that must agree on a value for that value to count as stable. */
    static final double TOP_FEATURE_THRESHOLD = 0.9;

    /** How many times {@link #assertStable} explains each prediction. */
    private static final int EXPLANATION_RUNS = 100;

    @ParameterizedTest
    @ValueSource(longs = { 0L })
    void testStabilityWithNumericData(long seed) throws Exception {
        PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(0);
        LimeConfig limeConfig = new LimeConfig()
                .withSamples(10)
                .withPerturbationContext(seededContext(seed));
        assertStable(new LimeExplainer(limeConfig), sumSkipModel, mockedNumericFeatures(5));
    }

    @ParameterizedTest
    @ValueSource(longs = { 0L })
    void testStabilityWithTextData(long seed) throws Exception {
        PredictionProvider textClassifier = TestUtils.getDummyTextClassifier();
        List<Feature> featureList = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            featureList.add(TestUtils.getMockedTextFeature("foo " + i));
        }
        featureList.add(TestUtils.getMockedTextFeature("money"));
        LimeConfig limeConfig = new LimeConfig()
                .withSamples(10)
                .withPerturbationContext(seededContext(seed));
        assertStable(new LimeExplainer(limeConfig), textClassifier, featureList);
    }

    @ParameterizedTest
    @ValueSource(longs = { 0L })
    void testAdaptiveVariance(long seed) throws Exception {
        int samples = 1;
        int retries = 4;
        LimeConfig limeConfig = new LimeConfig()
                .withSamples(samples)
                .withPerturbationContext(seededContext(seed))
                .withRetries(retries)
                .withAdaptiveVariance(true);
        LimeExplainer adaptiveVarianceLE = new LimeExplainer(limeConfig);
        List<Feature> features = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            features.add(FeatureFactory.newNumericalFeature("f-" + i, 2));
        }
        PredictionProvider model = TestUtils.getEvenSumModel(0);
        assertStable(adaptiveVarianceLE, model, features);
    }

    @ParameterizedTest
    @ValueSource(longs = { 0L, 1L, 2L, 3L, 4L })
    void testStabilityDeterministic(long seed) throws Exception {
        List<LocalSaliencyStability> stabilities = new ArrayList<>();
        // build two fully independent model/explainer pairs sharing only the seed
        for (int run = 0; run < 2; run++) {
            PredictionProvider model = TestUtils.getSumSkipModel(0);
            PredictionInput input = new PredictionInput(mockedNumericFeatures(5));
            List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input))
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            Prediction prediction = new SimplePrediction(input, predictionOutputs.get(0));
            LimeConfig limeConfig = new LimeConfig()
                    .withSamples(10)
                    .withPerturbationContext(seededContext(seed));
            LimeExplainer explainer = new LimeExplainer(limeConfig);
            stabilities.add(ExplainabilityMetrics.getLocalSaliencyStability(model, prediction, explainer, 2, 10));
        }
        LocalSaliencyStability first = stabilities.get(0);
        LocalSaliencyStability second = stabilities.get(1);
        String decisionName = "sum-but0";
        Assertions.assertThat(first.getNegativeStabilityScore(decisionName, 1))
                .isEqualTo(second.getNegativeStabilityScore(decisionName, 1));
        Assertions.assertThat(first.getPositiveStabilityScore(decisionName, 1))
                .isEqualTo(second.getPositiveStabilityScore(decisionName, 1));
        Assertions.assertThat(first.getNegativeStabilityScore(decisionName, 2))
                .isEqualTo(second.getNegativeStabilityScore(decisionName, 2));
        Assertions.assertThat(first.getPositiveStabilityScore(decisionName, 2))
                .isEqualTo(second.getPositiveStabilityScore(decisionName, 2));
    }

    /**
     * Explains every output of {@code model} for {@code featureList} {@link #EXPLANATION_RUNS}
     * times, then asserts two things across all collected saliencies:
     * <ul>
     *   <li>a single feature name is the top positive feature in at least
     *       {@link #TOP_FEATURE_THRESHOLD} of the saliencies that have any positive feature;</li>
     *   <li>a single {@link ExplainabilityMetrics#impactScore} value for the top two features
     *       occurs in at least {@link #TOP_FEATURE_THRESHOLD} of the saliencies.</li>
     * </ul>
     *
     * @param limeExplainer explainer under test
     * @param model         model whose predictions are explained
     * @param featureList   input features of the single prediction to explain
     * @throws Exception if prediction or explanation fails or times out
     */
    private void assertStable(LimeExplainer limeExplainer, PredictionProvider model, List<Feature> featureList)
            throws Exception {
        PredictionInput input = new PredictionInput(featureList);
        List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        for (PredictionOutput predictionOutput : predictionOutputs) {
            Prediction prediction = new SimplePrediction(input, predictionOutput);
            List<Saliency> saliencies = new ArrayList<>();
            for (int i = 0; i < EXPLANATION_RUNS; i++) {
                Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
                        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
                saliencies.addAll(saliencyMap.values());
            }

            // the top positive feature should be the same in most runs; runs without any
            // positive feature are ignored, as before
            List<String> topFeatureNames = saliencies.stream()
                    .map(s -> s.getPositiveFeatures(1))
                    .filter(f -> !f.isEmpty())
                    .map(f -> f.get(0).getFeature().getName())
                    .collect(Collectors.toList());
            org.junit.jupiter.api.Assertions.assertTrue(hasDominantValue(topFeatureNames),
                    () -> "top feature is not stable: " + frequencies(topFeatureNames));

            // the impact score of the top two features should be the same in most runs
            List<Double> impacts = new ArrayList<>(saliencies.size());
            for (Saliency saliency : saliencies) {
                impacts.add(ExplainabilityMetrics.impactScore(model, prediction, saliency.getTopFeatures(2)));
            }
            org.junit.jupiter.api.Assertions.assertTrue(hasDominantValue(impacts),
                    () -> "impact score is not stable: " + frequencies(impacts));
        }
    }

    /** Creates a perturbation context seeded with {@code seed}, one perturbation per sample. */
    private static PerturbationContext seededContext(long seed) {
        return new PerturbationContext(seed, new Random(), 1);
    }

    /** Creates {@code count} mocked numeric features, indexed 0..count-1. */
    private static List<Feature> mockedNumericFeatures(int count) {
        List<Feature> features = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            features.add(TestUtils.getMockedNumericFeature(i));
        }
        return features;
    }

    /**
     * Returns whether one value makes up at least {@link #TOP_FEATURE_THRESHOLD} of
     * {@code values}. An empty list has no dominant value.
     *
     * <p>The check must be a fraction of the total. Comparing the raw count with the threshold
     * would pass for any non-empty list.
     */
    private static <T> boolean hasDominantValue(List<T> values) {
        if (values.isEmpty()) {
            return false;
        }
        long maxCount = frequencies(values).values().stream()
                .mapToLong(Long::longValue)
                .max()
                .orElse(0L);
        return (double) maxCount / values.size() >= TOP_FEATURE_THRESHOLD;
    }

    /** Counts how many times each distinct value occurs in {@code values}. */
    private static <T> Map<T, Long> frequencies(List<T> values) {
        return values.stream().collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
    }
}

