/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.lime.optim;

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.optim.RecordingLimeExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.model.Type;
import org.mockito.Mockito;

class RecordingLimeExplainerTest {
    RecordingLimeExplainerTest() {
    }

    @Test
    void testRecordedPredictions() {
        RecordingLimeExplainer recordingLimeExplainer = new RecordingLimeExplainer(10);
        ArrayList<Prediction> allPredictions = new ArrayList<Prediction>();
        PredictionProvider model = (PredictionProvider)Mockito.mock(PredictionProvider.class);
        for (int i = 0; i < 15; ++i) {
            Prediction prediction = (Prediction)Mockito.mock(Prediction.class);
            allPredictions.add(prediction);
            try {
                recordingLimeExplainer.explainAsync(prediction, model).get(5L, Config.DEFAULT_ASYNC_TIMEUNIT);
                continue;
            }
            catch (Exception exception) {
                // empty catch block
            }
        }
        Assertions.assertThat(allPredictions).hasSize(15);
        List recordedPredictions = recordingLimeExplainer.getRecordedPredictions();
        Assertions.assertThat((List)recordedPredictions).hasSize(10);
        Assertions.assertThat(allPredictions.subList(5, 15)).isEqualTo((Object)recordedPredictions);
    }

    @Test
    void testParallel() throws InterruptedException, ExecutionException, TimeoutException {
        int capacity = 10;
        RecordingLimeExplainer recordingLimeExplainer = new RecordingLimeExplainer(capacity);
        PredictionProvider model = (PredictionProvider)Mockito.mock(PredictionProvider.class);
        Callable<Object> callable = () -> {
            for (int i = 0; i < 10000; ++i) {
                Prediction prediction = (Prediction)Mockito.mock(Prediction.class);
                try {
                    recordingLimeExplainer.explainAsync(prediction, model).get(5L, Config.DEFAULT_ASYNC_TIMEUNIT);
                    continue;
                }
                catch (Exception exception) {
                    // empty catch block
                }
            }
            return null;
        };
        ArrayList<Future<Object>> futures = new ArrayList<Future<Object>>();
        ExecutorService executorService = Executors.newCachedThreadPool();
        for (int i = 0; i < 4; ++i) {
            futures.add(executorService.submit(callable));
        }
        for (Future future : futures) {
            future.get(1L, TimeUnit.MINUTES);
        }
        Assertions.assertThat((int)recordingLimeExplainer.getRecordedPredictions().size()).isEqualTo(capacity);
    }

    @Test
    void testQueue() {
        String[] strings;
        RecordingLimeExplainer.FixedSizeConcurrentLinkedDeque queue = new RecordingLimeExplainer.FixedSizeConcurrentLinkedDeque(5);
        for (String s : strings = "a b c d e f g f".split(" ")) {
            queue.offer((Object)s);
        }
        Assertions.assertThat((Collection)queue).containsExactly((Object[])"c d e f g".split(" "));
    }

    @ParameterizedTest
    @ValueSource(longs={0L})
    void testAutomaticConfigOptimization(long seed) throws Exception {
        PredictionProvider model = TestUtils.getSumThresholdModel(10.0, 10.0);
        PerturbationContext pc = new PerturbationContext(Long.valueOf(seed), new Random(), 1);
        LimeConfig config = new LimeConfig().withPerturbationContext(pc);
        RecordingLimeExplainer limeExplainer = new RecordingLimeExplainer(2);
        for (int i = 0; i < 50; ++i) {
            LinkedList<Feature> features = new LinkedList<Feature>();
            features.add(TestUtils.getMockedNumericFeature(Type.NUMBER.randomValue(pc).asNumber()));
            features.add(TestUtils.getMockedNumericFeature(Type.NUMBER.randomValue(pc).asNumber()));
            features.add(TestUtils.getMockedNumericFeature(Type.NUMBER.randomValue(pc).asNumber()));
            PredictionInput input = new PredictionInput(features);
            List outputs = (List)model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            SimplePrediction prediction = new SimplePrediction(input, (PredictionOutput)outputs.get(0));
            Map saliencyMap = (Map)limeExplainer.explainAsync((Prediction)prediction, model).toCompletableFuture().get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            for (Saliency saliency : saliencyMap.values()) {
                org.junit.jupiter.api.Assertions.assertNotNull((Object)saliency);
            }
        }
        LimeConfig optimizedConfig = limeExplainer.getExecutionConfig();
        Assertions.assertThat((Object)optimizedConfig).isNotEqualTo((Object)config);
    }

    @Test
    void testEmptyInput() {
        RecordingLimeExplainer recordingLimeExplainer = new RecordingLimeExplainer(10);
        PredictionProvider model = (PredictionProvider)Mockito.mock(PredictionProvider.class);
        Prediction prediction = (Prediction)Mockito.mock(Prediction.class);
        Assertions.assertThatCode(() -> recordingLimeExplainer.explainAsync(prediction, model)).hasMessage("cannot explain a prediction whose input is empty");
    }

    @Test
    void testExplainNonOptimized() throws ExecutionException, InterruptedException, TimeoutException {
        RecordingLimeExplainer limeExplainer = new RecordingLimeExplainer(10);
        ArrayList<Feature> features = new ArrayList<Feature>();
        for (int i = 0; i < 4; ++i) {
            features.add(TestUtils.getMockedNumericFeature(i));
        }
        PredictionInput input = new PredictionInput(features);
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        PredictionOutput output = (PredictionOutput)((List)model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())).get(0);
        SimplePrediction prediction = new SimplePrediction(input, output);
        Map saliencyMap = (Map)limeExplainer.explainAsync((Prediction)prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        org.junit.jupiter.api.Assertions.assertNotNull((Object)saliencyMap);
    }

    @Test
    void testEquals() {
        RecordingLimeExplainer o1 = new RecordingLimeExplainer(10);
        RecordingLimeExplainer o2 = new RecordingLimeExplainer(10);
        Assertions.assertThat((Object)o1).isNotEqualTo((Object)o2);
        LimeConfig config = new LimeConfig();
        RecordingLimeExplainer o3 = new RecordingLimeExplainer(config, 10);
        RecordingLimeExplainer o4 = new RecordingLimeExplainer(config, 10);
        Assertions.assertThat((Object)o3).isEqualTo((Object)o4);
    }
}

