/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.global.lime;

import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.global.lime.AggregatedLimeExplainer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.PredictionProviderMetadata;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.DataUtils;

class AggregatedLimeExplainerTest {
    AggregatedLimeExplainerTest() {
    }

    @ParameterizedTest
    @ValueSource(ints={0, 1, 2, 3, 4})
    void testExplainWithMetadata(int seed) throws ExecutionException, InterruptedException {
        final Random random = new Random();
        random.setSeed(seed);
        PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);
        PredictionProviderMetadata metadata = new PredictionProviderMetadata(){

            public DataDistribution getDataDistribution() {
                return DataUtils.generateRandomDataDistribution((int)3, (int)100, (Random)random);
            }

            public PredictionInput getInputShape() {
                LinkedList<Feature> features = new LinkedList<Feature>();
                features.add(FeatureFactory.newNumericalFeature((String)"f0", (Number)0));
                features.add(FeatureFactory.newNumericalFeature((String)"f1", (Number)0));
                features.add(FeatureFactory.newNumericalFeature((String)"f2", (Number)0));
                return new PredictionInput(features);
            }

            public PredictionOutput getOutputShape() {
                LinkedList<Output> outputs = new LinkedList<Output>();
                outputs.add(new Output("sum-but1", Type.BOOLEAN, new Value((Object)false), 0.0));
                return new PredictionOutput(outputs);
            }
        };
        AggregatedLimeExplainer aggregatedLimeExplainer = new AggregatedLimeExplainer();
        Map explain = (Map)aggregatedLimeExplainer.explainFromMetadata(sumSkipModel, metadata).get();
        Assertions.assertNotNull((Object)explain);
        Assertions.assertEquals((int)1, (int)explain.size());
        Assertions.assertTrue((boolean)explain.containsKey("sum-but1"));
        Saliency saliency = (Saliency)explain.get("sum-but1");
        Assertions.assertNotNull((Object)saliency);
        List collect = saliency.getPositiveFeatures(2).stream().map(FeatureImportance::getFeature).map(Feature::getName).collect(Collectors.toList());
        Assertions.assertFalse((boolean)collect.contains("f1"));
    }

    @ParameterizedTest
    @ValueSource(ints={0, 1, 2, 3, 4})
    void testExplainWithPredictions(int seed) throws ExecutionException, InterruptedException {
        Random random = new Random();
        random.setSeed(seed);
        PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);
        DataDistribution dataDistribution = DataUtils.generateRandomDataDistribution((int)3, (int)100, (Random)random);
        List samples = dataDistribution.sample(10);
        List predictionOutputs = (List)sumSkipModel.predictAsync(samples).get();
        List predictions = DataUtils.getPredictions((List)samples, (List)predictionOutputs);
        AggregatedLimeExplainer aggregatedLimeExplainer = new AggregatedLimeExplainer();
        Map explain = (Map)aggregatedLimeExplainer.explainFromPredictions(sumSkipModel, (Collection)predictions).get();
        Assertions.assertNotNull((Object)explain);
        Assertions.assertEquals((int)1, (int)explain.size());
        Assertions.assertTrue((boolean)explain.containsKey("sum-but1"));
        Saliency saliency = (Saliency)explain.get("sum-but1");
        Assertions.assertNotNull((Object)saliency);
        List collect = saliency.getPositiveFeatures(2).stream().map(FeatureImportance::getFeature).map(Feature::getName).collect(Collectors.toList());
        Assertions.assertFalse((boolean)collect.contains("f1"));
    }
}

