/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.lime;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;

/**
 * Checks that {@link LimeExplainer} yields stable explanations: when the same prediction is
 * explained many times, the topmost positive feature and the impact score of the top features
 * should be (almost) always the same.
 */
class LimeStabilityTest {

    /**
     * Minimum fraction (in {@code [0, 1]}) of repeated explanations that must agree on a value
     * for that value to be considered stable.
     */
    static final double TOP_FEATURE_THRESHOLD = 0.9;

    /** Number of different random seeds every test scenario is repeated with. */
    private static final int SEEDS = 5;

    /** Number of times the very same prediction is explained by {@link #assertStable}. */
    private static final int EXPLANATION_RUNS = 100;

    /** Number of top features passed to {@code ExplainabilityMetrics.impactScore}. */
    private static final int IMPACT_TOP_FEATURES = 2;

    /**
     * Explains a prediction of the model returned by {@code TestUtils.getSumSkipModel(0)} over
     * five mocked numeric features and asserts the explanations are stable.
     */
    @Test
    void testStabilityWithNumericData() throws Exception {
        Random random = new Random();
        for (int seed = 0; seed < SEEDS; seed++) {
            random.setSeed(seed);
            PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(0);
            List<Feature> featureList = new ArrayList<>();
            for (int i = 0; i < 5; i++) {
                featureList.add(TestUtils.getMockedNumericFeature(i));
            }
            LimeConfig limeConfig = new LimeConfig()
                    .withSamples(10)
                    .withPerturbationContext(new PerturbationContext(random, 1));
            LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
            assertStable(limeExplainer, sumSkipModel, featureList);
        }
    }

    /**
     * Explains a prediction of the model returned by {@code TestUtils.getDummyTextClassifier()}
     * over five mocked text features ("foo 0" .. "foo 3" and "money") and asserts the
     * explanations are stable.
     */
    @Test
    void testStabilityWithTextData() throws Exception {
        Random random = new Random();
        for (int seed = 0; seed < SEEDS; seed++) {
            random.setSeed(seed);
            PredictionProvider textClassifier = TestUtils.getDummyTextClassifier();
            List<Feature> featureList = new ArrayList<>();
            for (int i = 0; i < 4; i++) {
                featureList.add(TestUtils.getMockedTextFeature("foo " + i));
            }
            featureList.add(TestUtils.getMockedTextFeature("money"));
            LimeConfig limeConfig = new LimeConfig()
                    .withSamples(10)
                    .withPerturbationContext(new PerturbationContext(random, 1));
            LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
            assertStable(limeExplainer, textClassifier, featureList);
        }
    }

    /**
     * Same stability check with a LIME configuration using a single sample, four retries and
     * adaptive variance enabled, against the model returned by
     * {@code TestUtils.getEvenSumModel(0)}.
     */
    @Test
    void testAdaptiveVariance() throws Exception {
        Random random = new Random();
        for (int seed = 0; seed < SEEDS; seed++) {
            random.setSeed(seed);
            PerturbationContext perturbationContext = new PerturbationContext(random, 1);
            int samples = 1;
            int retries = 4;
            LimeConfig limeConfig = new LimeConfig()
                    .withSamples(samples)
                    .withPerturbationContext(perturbationContext)
                    .withRetries(retries)
                    .withAdaptiveVariance(true);
            LimeExplainer adaptiveVarianceLE = new LimeExplainer(limeConfig);
            List<Feature> features = new ArrayList<>();
            for (int i = 0; i < 4; i++) {
                features.add(FeatureFactory.newNumericalFeature("f-" + i, 2));
            }
            PredictionProvider model = TestUtils.getEvenSumModel(0);
            assertStable(adaptiveVarianceLE, model, features);
        }
    }

    /**
     * Explains each output of the model's prediction for {@code featureList}
     * {@value #EXPLANATION_RUNS} times and asserts that
     * <ul>
     * <li>a single feature is the top positive feature in at least
     * {@link #TOP_FEATURE_THRESHOLD} of the explanations having a positive feature, and</li>
     * <li>a single impact score (over the top {@value #IMPACT_TOP_FEATURES} features) is
     * obtained in at least {@link #TOP_FEATURE_THRESHOLD} of the explanations.</li>
     * </ul>
     *
     * @param limeExplainer the explainer under test
     * @param model the model whose predictions are explained
     * @param featureList the features of the single input being predicted and explained
     * @throws Exception if the prediction, the explanation or the impact score computation fails
     *         or times out
     */
    private void assertStable(LimeExplainer limeExplainer, PredictionProvider model, List<Feature> featureList) throws Exception {
        PredictionInput input = new PredictionInput(featureList);
        List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        for (PredictionOutput predictionOutput : predictionOutputs) {
            Prediction prediction = new Prediction(input, predictionOutput);

            // Explain the very same prediction over and over again.
            List<Saliency> saliencies = new ArrayList<>();
            for (int i = 0; i < EXPLANATION_RUNS; i++) {
                Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
                        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
                saliencies.addAll(saliencyMap.values());
            }

            // Name of the topmost positive feature of every explanation that has one.
            List<String> topFeatureNames = saliencies.stream()
                    .map(saliency -> saliency.getPositiveFeatures(1))
                    .filter(positives -> !positives.isEmpty())
                    .map(positives -> positives.get(0).getFeature().getName())
                    .collect(Collectors.toList());
            Assertions.assertTrue(hasDominantValue(topFeatureNames),
                    "no feature is the top positive feature in at least " + TOP_FEATURE_THRESHOLD
                            + " of the explanations: " + topFeatureNames);

            // Impact score of the top features of every explanation. A plain loop is used
            // because impactScore is invoked from a method declared to throw Exception.
            List<Double> impacts = new ArrayList<>(saliencies.size());
            for (Saliency saliency : saliencies) {
                impacts.add(ExplainabilityMetrics.impactScore(model, prediction,
                        saliency.getTopFeatures(IMPACT_TOP_FEATURES)));
            }
            Assertions.assertTrue(hasDominantValue(impacts),
                    "no impact score is obtained in at least " + TOP_FEATURE_THRESHOLD
                            + " of the explanations: " + impacts);
        }
    }

    /**
     * Tells whether one single value accounts for at least {@link #TOP_FEATURE_THRESHOLD} of the
     * given observations. The occurrence count is divided by the number of observations, as
     * comparing the raw count with a fractional threshold would hold for any non-empty list.
     *
     * @param values the observed values, compared via {@code equals}/{@code hashCode}
     * @param <T> the type of the observed values
     * @return {@code true} if the relative frequency of some value reaches the threshold,
     *         {@code false} otherwise (including when {@code values} is empty)
     */
    private static <T> boolean hasDominantValue(List<T> values) {
        if (values.isEmpty()) {
            return false;
        }
        Map<T, Long> frequencies = values.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        double total = values.size();
        return frequencies.values().stream()
                .anyMatch(count -> count / total >= TOP_FEATURE_THRESHOLD);
    }
}

