/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.lime;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.LocalExplanationException;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureDistribution;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.GenericFeatureDistribution;
import org.kie.kogito.explainability.model.IndependentFeaturesDataDistribution;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;

/**
 * Tests for {@link LimeExplainer}.
 *
 * <p>Every test builds a {@link LimeConfig}, runs {@code TestUtils.getSumSkipModel(0)} on a
 * single input and then asks the explainer for saliencies of that prediction.
 */
class LimeExplainerTest {

    /** Default number of perturbations handed to each {@link PerturbationContext}. */
    private static final int DEFAULT_NO_OF_PERTURBATIONS = 1;

    /** Output name whose saliency the tests read back from the explainer's result map. */
    private static final String DECISION_NAME = "sum-but0";

    /**
     * Explaining a prediction whose input has no features must fail.
     *
     * <p>The lambda never calls {@code get()} on the returned future, so the
     * {@link LocalExplanationException} is expected to be thrown by {@code explainAsync} itself.
     */
    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2, 3, 4})
    void testEmptyPrediction(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        Random random = new Random();
        random.setSeed(seed);
        LimeConfig limeConfig = new LimeConfig()
                .withPerturbationContext(new PerturbationContext(random, DEFAULT_NO_OF_PERTURBATIONS))
                .withSamples(10);
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
        PredictionInput input = new PredictionInput(Collections.emptyList());
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        PredictionOutput output = predict(model, input);
        Prediction prediction = new SimplePrediction(input, output);

        Assertions.assertThrows(LocalExplanationException.class,
                () -> limeExplainer.explainAsync(prediction, model));
    }

    /** Explaining a prediction over four mocked numeric features yields a non-null saliency map. */
    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2, 3, 4})
    void testNonEmptyInput(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        Random random = new Random();
        random.setSeed(seed);
        LimeConfig limeConfig = new LimeConfig()
                .withPerturbationContext(new PerturbationContext(random, DEFAULT_NO_OF_PERTURBATIONS))
                .withSamples(10);
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
        PredictionInput input = new PredictionInput(mockedNumericFeatures(4));
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        PredictionOutput output = predict(model, input);
        Prediction prediction = new SimplePrediction(input, output);

        Map<String, Saliency> saliencyMap = explain(limeExplainer, prediction, model);
        Assertions.assertNotNull(saliencyMap);
    }

    /**
     * For 1..3 features, enabling {@code penalizeBalanceSparse} must never increase the absolute
     * importance score of any feature compared to the same explanation without the penalty.
     */
    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2, 3, 4})
    void testSparseBalance(int seed) throws InterruptedException, ExecutionException, TimeoutException {
        Random random = new Random();
        random.setSeed(seed);
        for (int nf = 1; nf < 4; nf++) {
            int noOfSamples = 100;
            LimeConfig limeConfigNoPenalty = new LimeConfig()
                    .withPerturbationContext(new PerturbationContext(random, DEFAULT_NO_OF_PERTURBATIONS))
                    .withSamples(noOfSamples)
                    .withPenalizeBalanceSparse(false);
            LimeExplainer limeExplainerNoPenalty = new LimeExplainer(limeConfigNoPenalty);
            List<Feature> features = mockedNumericFeatures(nf);
            PredictionInput input = new PredictionInput(features);
            PredictionProvider model = TestUtils.getSumSkipModel(0);
            PredictionOutput output = predict(model, input);
            Prediction prediction = new SimplePrediction(input, output);

            Map<String, Saliency> saliencyMapNoPenalty = explain(limeExplainerNoPenalty, prediction, model);
            AssertionsForClassTypes.assertThat(saliencyMapNoPenalty).isNotNull();
            Saliency saliencyNoPenalty = saliencyMapNoPenalty.get(DECISION_NAME);

            // NOTE(review): unlike the no-penalty config above, this config is not given the seeded
            // PerturbationContext, so its perturbations come from LimeConfig's default context.
            // Confirm this is intended; if not, the comparison below may be non-deterministic.
            LimeConfig limeConfig = new LimeConfig()
                    .withSamples(noOfSamples)
                    .withPenalizeBalanceSparse(true);
            LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
            Map<String, Saliency> saliencyMap = explain(limeExplainer, prediction, model);
            AssertionsForClassTypes.assertThat(saliencyMap).isNotNull();
            Saliency saliency = saliencyMap.get(DECISION_NAME);

            List<FeatureImportance> importances = saliency.getPerFeatureImportance();
            List<FeatureImportance> importancesNoPenalty = saliencyNoPenalty.getPerFeatureImportance();
            for (int i = 0; i < features.size(); i++) {
                double score = importances.get(i).getScore();
                double scoreNoPenalty = importancesNoPenalty.get(i).getScore();
                AssertionsForClassTypes.assertThat(Math.abs(score)).isLessThanOrEqualTo(Math.abs(scoreNoPenalty));
            }
        }
    }

    /** With {@code normalizeWeights} enabled, every feature importance score lies in [-1, 1]. */
    @Test
    void testNormalizedWeights() throws InterruptedException, ExecutionException, TimeoutException {
        Random random = new Random();
        random.setSeed(4L);
        LimeConfig limeConfig = new LimeConfig()
                .withNormalizeWeights(true)
                .withPerturbationContext(new PerturbationContext(random, 2))
                .withSamples(10);
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
        PredictionInput input = new PredictionInput(mockedNumericFeatures(4));
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        PredictionOutput output = predict(model, input);
        Prediction prediction = new SimplePrediction(input, output);

        Map<String, Saliency> saliencyMap = explain(limeExplainer, prediction, model);
        AssertionsForClassTypes.assertThat(saliencyMap).isNotNull();
        Saliency saliency = saliencyMap.get(DECISION_NAME);
        // Typed (not raw) list: iterating a raw List with a FeatureImportance variable does not compile.
        List<FeatureImportance> perFeatureImportance = saliency.getPerFeatureImportance();
        for (FeatureImportance featureImportance : perFeatureImportance) {
            AssertionsForClassTypes.assertThat(featureImportance.getScore()).isBetween(-1.0, 1.0);
        }
    }

    /**
     * Explaining with an explicit {@link IndependentFeaturesDataDistribution} (four random
     * {@link Type#NUMBER} values per feature) produces a saliency for the decision output.
     */
    @Test
    void testWithDataDistribution() throws InterruptedException, ExecutionException, TimeoutException {
        Random random = new Random();
        random.setSeed(4L);
        PerturbationContext perturbationContext = new PerturbationContext(random, DEFAULT_NO_OF_PERTURBATIONS);
        List<FeatureDistribution> featureDistributions = new ArrayList<>();
        int nf = 4;
        List<Feature> features = new ArrayList<>(nf);
        for (int i = 0; i < nf; i++) {
            Feature mockedNumericFeature = TestUtils.getMockedNumericFeature(i);
            features.add(mockedNumericFeature);
            List<Value> values = new ArrayList<>();
            for (int r = 0; r < 4; r++) {
                values.add(Type.NUMBER.randomValue(perturbationContext));
            }
            featureDistributions.add(new GenericFeatureDistribution(mockedNumericFeature, values));
        }
        DataDistribution dataDistribution = new IndependentFeaturesDataDistribution(featureDistributions);
        LimeConfig limeConfig = new LimeConfig()
                .withDataDistribution(dataDistribution)
                .withPerturbationContext(perturbationContext)
                .withSamples(10);
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
        PredictionInput input = new PredictionInput(features);
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        PredictionOutput output = predict(model, input);
        Prediction prediction = new SimplePrediction(input, output);

        Map<String, Saliency> saliencyMap = explain(limeExplainer, prediction, model);
        AssertionsForClassTypes.assertThat(saliencyMap).isNotNull();
        Saliency saliency = saliencyMap.get(DECISION_NAME);
        AssertionsForClassTypes.assertThat(saliency).isNotNull();
    }

    /** Runs {@code model} on a single input and returns the first output, bounded by the configured async timeout. */
    private static PredictionOutput predict(PredictionProvider model, PredictionInput input)
            throws InterruptedException, ExecutionException, TimeoutException {
        return model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
                .get(0);
    }

    /** Runs the explainer on {@code prediction} and waits for its saliency map up to the configured async timeout. */
    private static Map<String, Saliency> explain(LimeExplainer explainer, Prediction prediction, PredictionProvider model)
            throws InterruptedException, ExecutionException, TimeoutException {
        return explainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }

    /** Builds {@code count} mocked numeric features, indexed 0..count-1. */
    private static List<Feature> mockedNumericFeatures(int count) {
        List<Feature> features = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            features.add(TestUtils.getMockedNumericFeature(i));
        }
        return features;
    }
}

