/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.lime;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;

/**
 * Tests {@link LimeExplainer} against the dummy models provided by {@link TestUtils}.
 *
 * <p>Every test is parameterized over five seeds. The seed drives the {@link Random} handed to
 * the {@link PerturbationContext}, so each run is reproducible.
 */
class DummyModelsLimeExplainerTest {

    /** Number of top features considered by {@link TestUtils#assertLimeStability}. */
    private static final int STABILITY_TOP_K = 1;

    /** Minimum positive-saliency stability rate required across repeated explanations. */
    private static final double MIN_POSITIVE_STABILITY_RATE = 0.5;

    /** Minimum negative-saliency stability rate required across repeated explanations. */
    private static final double MIN_NEGATIVE_STABILITY_RATE = 0.5;

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2, 3, 4})
    void testMapOneFeatureToOutputRegression(int seed) throws Exception {
        Random random = new Random(seed);
        PredictionProvider model = TestUtils.getFeaturePassModel(1);
        Prediction prediction = predict(model, numericalFeatures(100, 20, 0.1));
        LimeExplainer limeExplainer = newExplainer(random, 100, 1);

        assertTopThreeFeaturesHaveFullImpact(explain(limeExplainer, prediction, model), model, prediction);
        assertStable(model, prediction, limeExplainer);
    }

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2, 3, 4})
    void testUnusedFeatureRegression(int seed) throws Exception {
        Random random = new Random(seed);
        PredictionProvider model = TestUtils.getSumSkipModel(2);
        Prediction prediction = predict(model, numericalFeatures(100, 20, 10));
        LimeExplainer limeExplainer = newExplainer(random, 1000, 1);

        assertTopThreeFeaturesHaveFullImpact(explain(limeExplainer, prediction, model), model, prediction);
        assertStable(model, prediction, limeExplainer);
    }

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2, 3, 4})
    void testMapOneFeatureToOutputClassification(int seed) throws Exception {
        Random random = new Random(seed);
        PredictionProvider model = TestUtils.getEvenFeatureModel(1);
        Prediction prediction = predict(model, numericalFeatures(1, 1, 3));
        LimeExplainer limeExplainer = newExplainer(random, 100, 2);

        // No stability assertion here, matching the original test.
        assertTopThreeFeaturesHaveFullImpact(explain(limeExplainer, prediction, model), model, prediction);
    }

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2, 3, 4})
    void testTextSpamClassification(int seed) throws Exception {
        Random random = new Random(seed);
        // Splits full-text features into whitespace-separated tokens.
        Function<String, List<String>> tokenizer = s -> Arrays.asList(s.split(" "));
        List<Feature> features = new LinkedList<>();
        features.add(FeatureFactory.newFulltextFeature("f1", "we go here and there", tokenizer));
        features.add(FeatureFactory.newFulltextFeature("f2", "please give me some money", tokenizer));
        features.add(FeatureFactory.newFulltextFeature("f3", "dear friend, please reply", tokenizer));
        PredictionProvider model = TestUtils.getDummyTextClassifier();
        Prediction prediction = predict(model, features);
        LimeExplainer limeExplainer = newExplainer(random, 1000, 1);

        Map<String, Saliency> saliencyMap = explain(limeExplainer, prediction, model);
        for (Saliency saliency : saliencyMap.values()) {
            Assertions.assertNotNull(saliency);
            List<FeatureImportance> topFeatures = saliency.getPositiveFeatures(1);
            Assertions.assertEquals(1, topFeatures.size());
            Assertions.assertEquals(1.0, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
        }
        assertStable(model, prediction, limeExplainer);
    }

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2, 3, 4})
    void testUnusedFeatureClassification(int seed) throws Exception {
        Random random = new Random(seed);
        PredictionProvider model = TestUtils.getEvenSumModel(2);
        Prediction prediction = predict(model, numericalFeatures(6, 3, 5));
        LimeExplainer limeExplainer = newExplainer(random, 1000, 1);

        assertTopThreeFeaturesHaveFullImpact(explain(limeExplainer, prediction, model), model, prediction);
        assertStable(model, prediction, limeExplainer);
    }

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2, 3, 4})
    void testFixedOutput(int seed) throws Exception {
        Random random = new Random(seed);
        PredictionProvider model = TestUtils.getFixedOutputClassifier();
        Prediction prediction = predict(model, numericalFeatures(6, 3, 5));
        LimeExplainer limeExplainer = newExplainer(random, 1000, 1);

        Map<String, Saliency> saliencyMap = explain(limeExplainer, prediction, model);
        for (Saliency saliency : saliencyMap.values()) {
            Assertions.assertNotNull(saliency);
            List<FeatureImportance> topFeatures = saliency.getTopFeatures(3);
            Assertions.assertEquals(3, topFeatures.size());
            // Expects every feature to get a zero score and zero impact,
            // presumably because this classifier's output never varies.
            for (FeatureImportance featureImportance : topFeatures) {
                Assertions.assertEquals(0.0, featureImportance.getScore());
            }
            Assertions.assertEquals(0.0, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
        }
        assertStable(model, prediction, limeExplainer);
    }

    /**
     * Builds numerical features named {@code f1}, {@code f2}, ... in argument order.
     *
     * @param values the feature values
     * @return a mutable list holding one numerical feature per value
     */
    private static List<Feature> numericalFeatures(Number... values) {
        List<Feature> features = new LinkedList<>();
        for (int i = 0; i < values.length; i++) {
            features.add(FeatureFactory.newNumericalFeature("f" + (i + 1), values[i]));
        }
        return features;
    }

    /**
     * Runs {@code model} on a single input built from {@code features}.
     *
     * <p>The call blocks for at most the configured async timeout.
     *
     * @param model the model to query
     * @param features the features of the single input
     * @return a prediction pairing the input with the first returned output
     * @throws Exception if the asynchronous prediction fails or times out
     */
    private static Prediction predict(PredictionProvider model, List<Feature> features) throws Exception {
        PredictionInput input = new PredictionInput(features);
        List<PredictionOutput> outputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        return new Prediction(input, outputs.get(0));
    }

    /**
     * Creates a LIME explainer with the given sample count and perturbation context.
     *
     * @param random source of randomness for perturbations (seeded by the caller)
     * @param samples the value passed to {@link LimeConfig#withSamples}
     * @param perturbations second {@link PerturbationContext} argument; presumably the
     *     number of features perturbed per sample (TODO confirm against PerturbationContext)
     * @return a new explainer
     */
    private static LimeExplainer newExplainer(Random random, int samples, int perturbations) {
        LimeConfig limeConfig = new LimeConfig()
                .withSamples(samples)
                .withPerturbationContext(new PerturbationContext(random, perturbations));
        return new LimeExplainer(limeConfig);
    }

    /**
     * Explains {@code prediction}, blocking for at most the configured async timeout.
     *
     * @throws Exception if the asynchronous explanation fails or times out
     */
    private static Map<String, Saliency> explain(LimeExplainer limeExplainer, Prediction prediction,
            PredictionProvider model) throws Exception {
        return limeExplainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }

    /**
     * Checks every saliency in {@code saliencyMap}.
     *
     * <p>For each one, asserts that it is non-null, exposes three top features, and that those
     * features have an impact score of exactly 1.0.
     */
    private static void assertTopThreeFeaturesHaveFullImpact(Map<String, Saliency> saliencyMap,
            PredictionProvider model, Prediction prediction) throws Exception {
        for (Saliency saliency : saliencyMap.values()) {
            Assertions.assertNotNull(saliency);
            List<FeatureImportance> topFeatures = saliency.getTopFeatures(3);
            Assertions.assertEquals(3, topFeatures.size());
            Assertions.assertEquals(1.0, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
        }
    }

    /** Delegates to {@link TestUtils#assertLimeStability} with the shared thresholds of this class. */
    private static void assertStable(PredictionProvider model, Prediction prediction, LimeExplainer limeExplainer)
            throws Exception {
        TestUtils.assertLimeStability(model, prediction, limeExplainer, STABILITY_TOP_K,
                MIN_POSITIVE_STABILITY_RATE, MIN_NEGATIVE_STABILITY_RATE);
    }
}

