/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.lime;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;
import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionInputsDataDistribution;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;

/**
 * Runs {@link LimeExplainer} against the dummy models provided by {@link TestUtils}.
 *
 * <p>Every test follows the same sequence: build a three-feature input, obtain the model's
 * prediction for it, explain that prediction, check the impact score of the reported
 * features, check explanation stability, and finally check the local saliency
 * precision / recall / F1 over a distribution of 100 mocked inputs. The shared steps live in
 * the private helpers at the bottom of the class; each test keeps only what is specific to it.
 */
class DummyModelsLimeExplainerTest {

    /** Number of features requested from each saliency in the impact-score checks. */
    private static final int TOP_FEATURES = 3;

    /** {@code topK} handed to {@link TestUtils#assertLimeStability}. */
    private static final int STABILITY_TOP_K = 1;

    /** Number of mocked inputs in the distribution used for the saliency metrics. */
    private static final int DISTRIBUTION_SIZE = 100;

    /** {@code k} handed to the local saliency precision / recall / F1 metrics. */
    private static final int METRICS_TOP_K = 2;

    /** {@code chunkSize} handed to the local saliency precision / recall / F1 metrics. */
    private static final int METRICS_CHUNK_SIZE = 10;

    /**
     * Explains a prediction of the model returned by {@code TestUtils.getFeaturePassModel(1)}.
     * Expects impact score 1 for the top three features, stable explanations, and
     * precision 0 / recall 1 / F1 0 for the {@code "feature-1"} decision.
     *
     * @param seed seed handed to the perturbation context
     * @throws Exception if the asynchronous prediction or explanation fails or times out
     */
    @ParameterizedTest
    @ValueSource(longs = {0L})
    void testMapOneFeatureToOutputRegression(long seed) throws Exception {
        int idx = 1;
        PredictionInput input = mockedNumericInput(100.0, 20.0, 0.1);
        PredictionProvider model = TestUtils.getFeaturePassModel(idx);
        Prediction prediction = predict(model, input);
        LimeExplainer limeExplainer = newLimeExplainer(seed, 100, 1);

        assertTopFeaturesHaveFullImpact(limeExplainer, prediction, model);
        TestUtils.assertLimeStability(model, prediction, limeExplainer, STABILITY_TOP_K, 0.5, 0.5);

        // This is the only test that expects zero precision and F1, so its metric
        // assertions stay inline instead of going through the shared helper.
        DataDistribution distribution = randomDistribution();
        String decision = "feature-" + idx;
        double precision = ExplainabilityMetrics.getLocalSaliencyPrecision(
                decision, model, limeExplainer, distribution, METRICS_TOP_K, METRICS_CHUNK_SIZE);
        AssertionsForClassTypes.assertThat(precision).isZero();
        double recall = ExplainabilityMetrics.getLocalSaliencyRecall(
                decision, model, limeExplainer, distribution, METRICS_TOP_K, METRICS_CHUNK_SIZE);
        AssertionsForClassTypes.assertThat(recall).isEqualTo(1.0);
        double f1 = ExplainabilityMetrics.getLocalSaliencyF1(
                decision, model, limeExplainer, distribution, METRICS_TOP_K, METRICS_CHUNK_SIZE);
        AssertionsForClassTypes.assertThat(f1).isZero();
    }

    /**
     * Explains a prediction of the model returned by {@code TestUtils.getSumSkipModel(2)}
     * using only 10 samples. Expects impact score 1 for the top three features, stable
     * explanations, and precision / recall / F1 of 1 for the {@code "sum-but2"} decision.
     *
     * @param seed seed handed to the perturbation context
     * @throws Exception if the asynchronous prediction or explanation fails or times out
     */
    @ParameterizedTest
    @ValueSource(longs = {0L})
    void testUnusedFeatureRegression(long seed) throws Exception {
        int idx = 2;
        PredictionProvider model = TestUtils.getSumSkipModel(idx);
        PredictionInput input = mockedNumericInput(100.0, 20.0, 10.0);
        Prediction prediction = predict(model, input);
        LimeExplainer limeExplainer = newLimeExplainer(seed, 10, 1);

        assertTopFeaturesHaveFullImpact(limeExplainer, prediction, model);
        TestUtils.assertLimeStability(model, prediction, limeExplainer, STABILITY_TOP_K, 0.5, 0.5);
        assertPerfectLocalSaliencyMetrics("sum-but" + idx, model, limeExplainer);
    }

    /**
     * Explains a prediction of the model returned by {@code TestUtils.getEvenFeatureModel(1)},
     * perturbing two features at a time. Expects impact score 1 for the top three features,
     * stable explanations, and precision / recall / F1 of 1 for the {@code "feature-1"}
     * decision.
     *
     * @param seed seed handed to the perturbation context
     * @throws Exception if the asynchronous prediction or explanation fails or times out
     */
    @ParameterizedTest
    @ValueSource(longs = {0L})
    void testMapOneFeatureToOutputClassification(long seed) throws Exception {
        int idx = 1;
        PredictionInput input = namedNumericInput(1, 1, 3);
        PredictionProvider model = TestUtils.getEvenFeatureModel(idx);
        Prediction prediction = predict(model, input);
        LimeExplainer limeExplainer = newLimeExplainer(seed, 100, 2);

        assertTopFeaturesHaveFullImpact(limeExplainer, prediction, model);
        TestUtils.assertLimeStability(model, prediction, limeExplainer, STABILITY_TOP_K, 0.5, 0.5);
        assertPerfectLocalSaliencyMetrics("feature-" + idx, model, limeExplainer);
    }

    /**
     * Explains a prediction of {@code TestUtils.getDummyTextClassifier()} on three full-text
     * features tokenized on single spaces. Expects the single most positive feature to have
     * impact score 1, stable explanations (with a lower bar of 0.2 for negative features),
     * and precision / recall / F1 of 1 for the {@code "spam"} decision.
     *
     * @param seed seed handed to the perturbation context
     * @throws Exception if the asynchronous prediction or explanation fails or times out
     */
    @ParameterizedTest
    @ValueSource(longs = {0L})
    void testTextSpamClassification(long seed) throws Exception {
        Function<String, List<String>> tokenizer = s -> Arrays.asList(s.split(" ").clone());
        List<Feature> features = new LinkedList<>();
        features.add(FeatureFactory.newFulltextFeature("f1", "we go here and there", tokenizer));
        features.add(FeatureFactory.newFulltextFeature("f2", "please give me some money", tokenizer));
        features.add(FeatureFactory.newFulltextFeature("f3", "dear friend, please reply", tokenizer));
        PredictionInput input = new PredictionInput(features);
        PredictionProvider model = TestUtils.getDummyTextClassifier();
        Prediction prediction = predict(model, input);
        LimeExplainer limeExplainer = newLimeExplainer(seed, 100, 1);

        for (Saliency saliency : explain(limeExplainer, prediction, model).values()) {
            Assertions.assertNotNull(saliency);
            List<FeatureImportance> topFeatures = saliency.getPositiveFeatures(1);
            Assertions.assertEquals(1, topFeatures.size());
            Assertions.assertEquals(1.0, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
        }
        TestUtils.assertLimeStability(model, prediction, limeExplainer, STABILITY_TOP_K, 0.5, 0.2);
        assertPerfectLocalSaliencyMetrics("spam", model, limeExplainer);
    }

    /**
     * Explains a prediction of the model returned by {@code TestUtils.getEvenSumModel(2)}.
     * Expects impact score 1 for the top three features, stable explanations, and
     * precision / recall / F1 of 1 for the {@code "sum-even-but2"} decision.
     *
     * @param seed seed handed to the perturbation context
     * @throws Exception if the asynchronous prediction or explanation fails or times out
     */
    @ParameterizedTest
    @ValueSource(longs = {0L})
    void testUnusedFeatureClassification(long seed) throws Exception {
        int idx = 2;
        PredictionProvider model = TestUtils.getEvenSumModel(idx);
        PredictionInput input = namedNumericInput(6, 3, 5);
        Prediction prediction = predict(model, input);
        LimeExplainer limeExplainer = newLimeExplainer(seed, 100, 1);

        assertTopFeaturesHaveFullImpact(limeExplainer, prediction, model);
        TestUtils.assertLimeStability(model, prediction, limeExplainer, STABILITY_TOP_K, 0.5, 0.5);
        assertPerfectLocalSaliencyMetrics("sum-even-but" + idx, model, limeExplainer);
    }

    /**
     * Explains a prediction of {@code TestUtils.getFixedOutputClassifier()} using only 10
     * samples. Expects impact score 0 for the top three features (the number of returned
     * features is intentionally not asserted), stable explanations, and
     * precision / recall / F1 of 1 for the {@code "class"} decision.
     *
     * @param seed seed handed to the perturbation context
     * @throws Exception if the asynchronous prediction or explanation fails or times out
     */
    @ParameterizedTest
    @ValueSource(longs = {0L})
    void testFixedOutput(long seed) throws Exception {
        PredictionProvider model = TestUtils.getFixedOutputClassifier();
        PredictionInput input = namedNumericInput(6, 3, 5);
        Prediction prediction = predict(model, input);
        LimeExplainer limeExplainer = newLimeExplainer(seed, 10, 1);

        for (Saliency saliency : explain(limeExplainer, prediction, model).values()) {
            Assertions.assertNotNull(saliency);
            List<FeatureImportance> topFeatures = saliency.getTopFeatures(TOP_FEATURES);
            Assertions.assertEquals(0.0, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
        }
        TestUtils.assertLimeStability(model, prediction, limeExplainer, STABILITY_TOP_K, 0.5, 0.5);
        assertPerfectLocalSaliencyMetrics("class", model, limeExplainer);
    }

    /** Builds an input whose features are mocked numeric features holding the given values, in order. */
    private static PredictionInput mockedNumericInput(double... values) {
        List<Feature> features = new LinkedList<>();
        for (double value : values) {
            features.add(TestUtils.getMockedNumericFeature(value));
        }
        return new PredictionInput(features);
    }

    /** Builds an input of numerical features named {@code f1}, {@code f2}, ... holding the given values, in order. */
    private static PredictionInput namedNumericInput(int... values) {
        List<Feature> features = new LinkedList<>();
        for (int i = 0; i < values.length; i++) {
            features.add(FeatureFactory.newNumericalFeature("f" + (i + 1), values[i]));
        }
        return new PredictionInput(features);
    }

    /** Asks the model for its output on {@code input} and pairs the first output with that input. */
    private static Prediction predict(PredictionProvider model, PredictionInput input) throws Exception {
        List<PredictionOutput> outputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        return new SimplePrediction(input, outputs.get(0));
    }

    /** Creates a LIME explainer with the given sample count and a perturbation context built from {@code seed}. */
    private static LimeExplainer newLimeExplainer(long seed, int samples, int perturbations) {
        Random random = new Random();
        LimeConfig limeConfig = new LimeConfig()
                .withSamples(samples)
                .withPerturbationContext(new PerturbationContext(seed, random, perturbations));
        return new LimeExplainer(limeConfig);
    }

    /** Runs the explainer on the prediction and waits for the result using the configured async timeout. */
    private static Map<String, Saliency> explain(LimeExplainer limeExplainer, Prediction prediction,
            PredictionProvider model) throws Exception {
        return limeExplainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }

    /** Asserts that every saliency reports exactly three top features whose impact score is 1. */
    private static void assertTopFeaturesHaveFullImpact(LimeExplainer limeExplainer, Prediction prediction,
            PredictionProvider model) throws Exception {
        for (Saliency saliency : explain(limeExplainer, prediction, model).values()) {
            Assertions.assertNotNull(saliency);
            List<FeatureImportance> topFeatures = saliency.getTopFeatures(TOP_FEATURES);
            Assertions.assertEquals(TOP_FEATURES, topFeatures.size());
            Assertions.assertEquals(1.0, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
        }
    }

    /** Builds a distribution of 100 inputs, each made of three mocked numeric features. */
    private static DataDistribution randomDistribution() {
        List<PredictionInput> inputs = new ArrayList<>(DISTRIBUTION_SIZE);
        for (int i = 0; i < DISTRIBUTION_SIZE; i++) {
            List<Feature> features = new LinkedList<>();
            features.add(TestUtils.getMockedNumericFeature());
            features.add(TestUtils.getMockedNumericFeature());
            features.add(TestUtils.getMockedNumericFeature());
            inputs.add(new PredictionInput(features));
        }
        return new PredictionInputsDataDistribution(inputs);
    }

    /** Asserts that local saliency precision, recall and F1 for {@code decision} all equal 1 over a fresh distribution. */
    private static void assertPerfectLocalSaliencyMetrics(String decision, PredictionProvider model,
            LimeExplainer limeExplainer) throws Exception {
        DataDistribution distribution = randomDistribution();
        double precision = ExplainabilityMetrics.getLocalSaliencyPrecision(
                decision, model, limeExplainer, distribution, METRICS_TOP_K, METRICS_CHUNK_SIZE);
        AssertionsForClassTypes.assertThat(precision).isEqualTo(1.0);
        double recall = ExplainabilityMetrics.getLocalSaliencyRecall(
                decision, model, limeExplainer, distribution, METRICS_TOP_K, METRICS_CHUNK_SIZE);
        AssertionsForClassTypes.assertThat(recall).isEqualTo(1.0);
        double f1 = ExplainabilityMetrics.getLocalSaliencyF1(
                decision, model, limeExplainer, distribution, METRICS_TOP_K, METRICS_CHUNK_SIZE);
        AssertionsForClassTypes.assertThat(f1).isEqualTo(1.0);
    }
}

