/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.lime.optim;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.utils.DataUtils;

/**
 * Tests for {@link LimeConfigOptimizer}.
 *
 * <p>Most scenarios configure an optimizer variant and hand it to
 * {@link #assertConfigOptimized(LimeConfigOptimizer)}, which checks that
 * {@code optimize} returns a non-null {@link LimeConfig} that is a different
 * instance from the initial one.
 *
 * <p>NOTE(review): the prediction lists below are raw {@code List}s, which looks
 * like a decompilation artifact. They should be parameterized with the model
 * element types once those types are imported by this file.
 */
class LimeConfigOptimizerTest {

    /** Value passed to {@code withTimeLimit}; its unit is defined by {@link LimeConfigOptimizer}. */
    private static final long TIME_LIMIT = 10L;

    @Test
    void testImpactOptimization() throws Exception {
        assertConfigOptimized(timeLimitedOptimizer().forImpactScore());
    }

    @Test
    void testImpactOptimizationNoSampling() throws Exception {
        assertConfigOptimized(timeLimitedOptimizer().forImpactScore().withSampling(false));
    }

    @Test
    void testImpactOptimizationNoWeighting() throws Exception {
        assertConfigOptimized(timeLimitedOptimizer().forImpactScore().withWeighting(false));
    }

    @Test
    void testImpactOptimizationNoEncoding() throws Exception {
        assertConfigOptimized(timeLimitedOptimizer().forImpactScore().withEncoding(false));
    }

    @Test
    void testImpactOptimizationNoProximity() throws Exception {
        assertConfigOptimized(timeLimitedOptimizer().forImpactScore().withProximity(false));
    }

    /**
     * With sampling, encoding, weighting and proximity all switched off, running the
     * optimization is expected to end in an {@link AssertionError}.
     */
    @Test
    void testImpactOptimizationNoEntity() {
        LimeConfigOptimizer limeConfigOptimizer = new LimeConfigOptimizer()
                .forImpactScore()
                .withSampling(false)
                .withEncoding(false)
                .withWeighting(false)
                .withProximity(false);
        org.junit.jupiter.api.Assertions.assertThrows(AssertionError.class,
                () -> assertConfigOptimized(limeConfigOptimizer));
    }

    @Test
    void testStabilityOptimization() throws Exception {
        assertConfigOptimized(timeLimitedOptimizer().forStabilityScore());
    }

    @Test
    void testStabilityOptimizationNoSampling() throws Exception {
        assertConfigOptimized(timeLimitedOptimizer().forStabilityScore().withSampling(false));
    }

    @Test
    void testStabilityOptimizationNoWeighting() throws Exception {
        assertConfigOptimized(timeLimitedOptimizer().forStabilityScore().withWeighting(false));
    }

    @Test
    void testStabilityOptimizationNoEncoding() throws Exception {
        assertConfigOptimized(timeLimitedOptimizer().forStabilityScore().withEncoding(false));
    }

    @Test
    void testStabilityOptimizationNoProximity() throws Exception {
        assertConfigOptimized(timeLimitedOptimizer().forStabilityScore().withProximity(false));
    }

    /**
     * Stability-score counterpart of {@link #testImpactOptimizationNoEntity()}: every
     * optional part is disabled and an {@link AssertionError} is expected.
     */
    @Test
    void testStabilityOptimizationNoEntity() {
        LimeConfigOptimizer limeConfigOptimizer = new LimeConfigOptimizer()
                .forStabilityScore()
                .withSampling(false)
                .withEncoding(false)
                .withWeighting(false)
                .withProximity(false);
        org.junit.jupiter.api.Assertions.assertThrows(AssertionError.class,
                () -> assertConfigOptimized(limeConfigOptimizer));
    }

    /** Runs the optimization once for each accepted weight pair; every pair sums to 1. */
    @Test
    void testWeightedStabilityOptimization() throws Exception {
        double[][] acceptedWeights = {
                {0.5, 0.5},
                {0.3, 0.7},
                {0.7, 0.3},
                {1.0, 0.0},
                {0.0, 1.0},
        };
        for (double[] pair : acceptedWeights) {
            assertConfigOptimized(timeLimitedOptimizer().withWeightedStability(pair[0], pair[1]));
        }
    }

    /**
     * Each pair below must be rejected with an {@link IllegalArgumentException}.
     * None of them sums to exactly 1 and some contain values outside [0, 1];
     * presumably that is the rule being enforced - confirm against LimeConfigOptimizer.
     */
    @Test
    void testWeightedStabilityWrongParamsOptimization() {
        double[][] rejectedWeights = {
                {0.8, 0.7},
                {0.1, 0.7},
                {0.1, 1.1},
                {2.1, 0.1},
                {-0.1, 0.9},
                {0.1, -0.9},
                {0.1, 0.99},
                {0.009, 0.99},
        };
        for (double[] pair : rejectedWeights) {
            // The message identifies the offending pair, since all cases share one call site.
            org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class,
                    () -> new LimeConfigOptimizer().withWeightedStability(pair[0], pair[1]),
                    "expected rejection of weights " + pair[0] + " and " + pair[1]);
        }
    }

    /**
     * Optimizes twice from identically built initial configurations, with deterministic
     * execution enabled, and checks that both results agree on every compared setting.
     */
    @Test
    void testSameConfig() throws ExecutionException, InterruptedException {
        long seed = 0L;
        PredictionProvider model = TestUtils.getSumSkipModel(1);
        // Seeded so the input data no longer depends on a time-derived seed; with an
        // unseeded Random a failing run could not be replayed on the same data.
        List predictions = samplePredictions(model, new Random(seed), 3);

        List<LimeConfig> optimizedConfigs = new ArrayList<>();
        for (int run = 0; run < 2; run++) {
            // Both runs start from the same seed and a fresh Random.
            LimeConfig initialConfig = new LimeConfig()
                    .withSamples(10)
                    .withPerturbationContext(new PerturbationContext(seed, new Random(), 1));
            LimeConfigOptimizer limeConfigOptimizer = new LimeConfigOptimizer()
                    .withDeterministicExecution(true)
                    .withStepCountLimit(10)
                    .withTimeLimit(TIME_LIMIT);
            optimizedConfigs.add(limeConfigOptimizer.optimize(initialConfig, predictions, model));
        }

        LimeConfig first = optimizedConfigs.get(0);
        LimeConfig second = optimizedConfigs.get(1);
        AssertionsForClassTypes.assertThat(first.getNoOfRetries()).isEqualTo(second.getNoOfRetries());
        AssertionsForClassTypes.assertThat(first.getNoOfSamples()).isEqualTo(second.getNoOfSamples());
        AssertionsForClassTypes.assertThat(first.getProximityFilteredDatasetMinimum())
                .isEqualTo(second.getProximityFilteredDatasetMinimum());
        AssertionsForClassTypes.assertThat(first.getProximityKernelWidth()).isEqualTo(second.getProximityKernelWidth());
        AssertionsForClassTypes.assertThat(first.getProximityThreshold()).isEqualTo(second.getProximityThreshold());
        AssertionsForClassTypes.assertThat(first.isProximityFilter()).isEqualTo(second.isProximityFilter());
        AssertionsForClassTypes.assertThat(first.isAdaptDatasetVariance()).isEqualTo(second.isAdaptDatasetVariance());
        AssertionsForClassTypes.assertThat(first.isPenalizeBalanceSparse()).isEqualTo(second.isPenalizeBalanceSparse());
        AssertionsForClassTypes.assertThat(first.getEncodingParams().getNumericTypeClusterGaussianFilterWidth())
                .isEqualTo(second.getEncodingParams().getNumericTypeClusterGaussianFilterWidth());
        AssertionsForClassTypes.assertThat(first.getEncodingParams().getNumericTypeClusterThreshold())
                .isEqualTo(second.getEncodingParams().getNumericTypeClusterThreshold());
        AssertionsForClassTypes.assertThat(first.getSeparableDatasetRatio()).isEqualTo(second.getSeparableDatasetRatio());
        AssertionsForClassTypes.assertThat(first.getPerturbationContext().getNoOfPerturbations())
                .isEqualTo(second.getPerturbationContext().getNoOfPerturbations());
    }

    /**
     * Runs {@code limeConfigOptimizer} on a fixed-seed data set and asserts that the
     * returned configuration is non-null and not the same instance as the initial one.
     *
     * @param limeConfigOptimizer the optimizer variant under test
     * @throws InterruptedException if waiting for the model predictions is interrupted
     * @throws ExecutionException if the model prediction fails
     */
    private void assertConfigOptimized(LimeConfigOptimizer limeConfigOptimizer) throws InterruptedException, ExecutionException {
        LimeConfig initialConfig = new LimeConfig().withSamples(10);
        PredictionProvider model = TestUtils.getSumSkipModel(1);
        List predictions = samplePredictions(model, new Random(4L), 10);
        LimeConfig optimizedConfig = limeConfigOptimizer.optimize(initialConfig, predictions, model);
        Assertions.assertThat(optimizedConfig).isNotNull().isNotSameAs(initialConfig);
    }

    /** Creates a fresh optimizer with only the shared time limit applied. */
    private static LimeConfigOptimizer timeLimitedOptimizer() {
        return new LimeConfigOptimizer().withTimeLimit(TIME_LIMIT);
    }

    /**
     * Generates a random data distribution, draws {@code sampleCount} inputs from it,
     * asks {@code model} for their outputs and pairs inputs and outputs via
     * {@code DataUtils.getPredictions}.
     *
     * @param model the model queried for outputs
     * @param random the source of randomness handed to the distribution generator
     * @param sampleCount the number of inputs to draw
     * @return the list produced by {@code DataUtils.getPredictions}
     * @throws InterruptedException if waiting for the model predictions is interrupted
     * @throws ExecutionException if the model prediction fails
     */
    private static List samplePredictions(PredictionProvider model, Random random, int sampleCount) throws InterruptedException, ExecutionException {
        DataDistribution dataDistribution = DataUtils.generateRandomDataDistribution(5, 100, random);
        List samples = dataDistribution.sample(sampleCount);
        // The cast is required while 'samples' is raw: the raw argument erases the call's return type.
        List predictionOutputs = (List) model.predictAsync(samples).get();
        return DataUtils.getPredictions(samples, predictionOutputs);
    }
}

