/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.commons.lang3.tuple.Pair;
import org.assertj.core.api.Assertions;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;

class ExplainabilityMetricsTest {
    ExplainabilityMetricsTest() {
    }

    @Test
    void testExplainabilityNoExplanation() {
        double v = ExplainabilityMetrics.quantifyExplainability((int)0, (int)0, (double)0.0);
        org.junit.jupiter.api.Assertions.assertFalse((boolean)Double.isNaN(v));
        org.junit.jupiter.api.Assertions.assertFalse((boolean)Double.isInfinite(v));
        org.junit.jupiter.api.Assertions.assertEquals((double)0.0, (double)v);
    }

    @Test
    void testExplainabilityNoExplanationWithInteraction() {
        double v = ExplainabilityMetrics.quantifyExplainability((int)0, (int)0, (double)1.0);
        org.junit.jupiter.api.Assertions.assertFalse((boolean)Double.isNaN(v));
        org.junit.jupiter.api.Assertions.assertFalse((boolean)Double.isInfinite(v));
        org.junit.jupiter.api.Assertions.assertEquals((double)0.0, (double)v);
    }

    @Test
    void testExplainabilitySameIOChunksNoInteraction() {
        double v = ExplainabilityMetrics.quantifyExplainability((int)10, (int)10, (double)0.0);
        org.junit.jupiter.api.Assertions.assertFalse((boolean)Double.isNaN(v));
        org.junit.jupiter.api.Assertions.assertFalse((boolean)Double.isInfinite(v));
        Assertions.assertThat((double)v).isBetween(Double.valueOf(0.0), Double.valueOf(1.0));
    }

    @Test
    void testExplainabilitySameIOChunksWithInteraction() {
        double v = ExplainabilityMetrics.quantifyExplainability((int)10, (int)10, (double)0.5);
        org.junit.jupiter.api.Assertions.assertEquals((double)0.2331, (double)v, (double)1.0E-5);
    }

    @Test
    void testExplainabilityDifferentIOChunksNoInteraction() {
        double v = ExplainabilityMetrics.quantifyExplainability((int)3, (int)9, (double)0.0);
        org.junit.jupiter.api.Assertions.assertEquals((double)0.481, (double)v, (double)1.0E-5);
    }

    @Test
    void testExplainabilityDifferentIOChunksInteraction() {
        double v = ExplainabilityMetrics.quantifyExplainability((int)3, (int)9, (double)0.5);
        org.junit.jupiter.api.Assertions.assertEquals((double)0.3145, (double)v, (double)1.0E-5);
    }

    @Test
    void testFidelityWithTextClassifier() throws ExecutionException, InterruptedException, TimeoutException {
        LinkedList<Pair> pairs = new LinkedList<Pair>();
        LimeConfig limeConfig = new LimeConfig().withSamples(10);
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
        PredictionProvider model = TestUtils.getDummyTextClassifier();
        LinkedList<Feature> features = new LinkedList<Feature>();
        features.add(FeatureFactory.newFulltextFeature((String)"f-0", (String)"brown fox", s -> Arrays.asList(s.split(" "))));
        features.add(FeatureFactory.newTextFeature((String)"f-1", (String)"money"));
        PredictionInput input = new PredictionInput(features);
        Prediction prediction = new Prediction(input, (PredictionOutput)((List)model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())).get(0));
        Map saliencyMap = (Map)limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        for (Saliency saliency : saliencyMap.values()) {
            pairs.add(Pair.of((Object)saliency, (Object)prediction));
        }
        org.junit.jupiter.api.Assertions.assertDoesNotThrow(() -> ExplainabilityMetrics.classificationFidelity((List)pairs));
    }

    @Test
    void testFidelityWithEvenSumModel() throws ExecutionException, InterruptedException, TimeoutException {
        LinkedList<Pair> pairs = new LinkedList<Pair>();
        LimeConfig limeConfig = new LimeConfig().withSamples(10);
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
        PredictionProvider model = TestUtils.getEvenSumModel(1);
        LinkedList<Feature> features = new LinkedList<Feature>();
        features.add(FeatureFactory.newNumericalFeature((String)"f-1", (Number)1));
        features.add(FeatureFactory.newNumericalFeature((String)"f-2", (Number)2));
        features.add(FeatureFactory.newNumericalFeature((String)"f-3", (Number)3));
        PredictionInput input = new PredictionInput(features);
        Prediction prediction = new Prediction(input, (PredictionOutput)((List)model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())).get(0));
        Map saliencyMap = (Map)limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        for (Saliency saliency : saliencyMap.values()) {
            pairs.add(Pair.of((Object)saliency, (Object)prediction));
        }
        org.junit.jupiter.api.Assertions.assertDoesNotThrow(() -> ExplainabilityMetrics.classificationFidelity((List)pairs));
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    void testBrokenPredict() {
        Config.INSTANCE.setAsyncTimeout(1L);
        Config.INSTANCE.setAsyncTimeUnit(TimeUnit.MILLISECONDS);
        Prediction emptyPrediction = new Prediction(new PredictionInput(Collections.emptyList()), new PredictionOutput(Collections.emptyList()));
        PredictionProvider brokenProvider = inputs -> CompletableFuture.supplyAsync(() -> {
            Awaitility.await().atLeast(1L, TimeUnit.SECONDS).until(() -> false);
            throw new RuntimeException("this should never happen");
        });
        List emptyFeatures = Collections.emptyList();
        try {
            org.junit.jupiter.api.Assertions.assertThrows(IllegalStateException.class, () -> ExplainabilityMetrics.impactScore((PredictionProvider)brokenProvider, (Prediction)emptyPrediction, (List)emptyFeatures));
        }
        finally {
            Config.INSTANCE.setAsyncTimeout(5L);
            Config.INSTANCE.setAsyncTimeUnit(Config.DEFAULT_ASYNC_TIMEUNIT);
        }
    }
}

