/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.lime;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;

class LimeStabilityTest {
    static final double TOP_FEATURE_THRESHOLD = 0.9;

    LimeStabilityTest() {
    }

    @Test
    void testStabilityWithNumericData() throws Exception {
        PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(0);
        LinkedList<Feature> featureList = new LinkedList<Feature>();
        for (int i = 0; i < 5; ++i) {
            featureList.add(TestUtils.getMockedNumericFeature(i));
        }
        this.assertStable(sumSkipModel, featureList);
    }

    @Test
    void testStabilityWithTextData() throws Exception {
        PredictionProvider sumSkipModel = TestUtils.getDummyTextClassifier();
        LinkedList<Feature> featureList = new LinkedList<Feature>();
        for (int i = 0; i < 4; ++i) {
            featureList.add(TestUtils.getMockedTextFeature("foo " + i));
        }
        featureList.add(TestUtils.getMockedTextFeature("money"));
        this.assertStable(sumSkipModel, featureList);
    }

    private void assertStable(PredictionProvider model, List<Feature> featureList) throws Exception {
        Random random = new Random();
        for (int seed = 0; seed < 5; ++seed) {
            random.setSeed(seed);
            LimeExplainer limeExplainer = new LimeExplainer(10, 1, random);
            PredictionInput input = new PredictionInput(featureList);
            List predictionOutputs = (List)model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            for (PredictionOutput predictionOutput : predictionOutputs) {
                Map.Entry entry2;
                Prediction prediction = new Prediction(input, predictionOutput);
                LinkedList saliencies = new LinkedList();
                for (int i = 0; i < 100; ++i) {
                    Map saliencyMap = (Map)limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
                    saliencies.addAll(saliencyMap.values());
                }
                LinkedList names = new LinkedList();
                saliencies.stream().map(s -> s.getPositiveFeatures(1)).forEach(f -> names.add(((FeatureImportance)f.get(0)).getFeature().getName()));
                Map frequencyMap = names.stream().collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
                boolean topFeature = false;
                for (Map.Entry entry2 : frequencyMap.entrySet()) {
                    if (!((double)entry2.getValue().longValue() >= 0.9)) continue;
                    topFeature = true;
                    break;
                }
                Assertions.assertTrue((boolean)topFeature);
                ArrayList<Double> impacts = new ArrayList<Double>(saliencies.size());
                entry2 = saliencies.iterator();
                while (entry2.hasNext()) {
                    Saliency saliency = (Saliency)entry2.next();
                    double v = ExplainabilityMetrics.impactScore((PredictionProvider)model, (Prediction)prediction, (List)saliency.getTopFeatures(2));
                    impacts.add(v);
                }
                Map impactMap = impacts.stream().collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
                boolean topImpact = false;
                for (Map.Entry entry3 : impactMap.entrySet()) {
                    if (!((double)entry3.getValue().longValue() >= 0.9)) continue;
                    topImpact = true;
                    break;
                }
                Assertions.assertTrue((boolean)topImpact);
            }
        }
    }
}

