/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.counterfactual;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualResult;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualSolution;
import org.kie.kogito.explainability.local.counterfactual.SolverConfigBuilder;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.model.CounterfactualPrediction;
import org.kie.kogito.explainability.model.DataDomain;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.NumericFeatureDistribution;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionFeatureDomain;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.model.domain.CategoricalFeatureDomain;
import org.kie.kogito.explainability.model.domain.EmptyFeatureDomain;
import org.kie.kogito.explainability.model.domain.FeatureDomain;
import org.kie.kogito.explainability.model.domain.NumericalFeatureDomain;
import org.kie.kogito.explainability.utils.DataUtils;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.mockito.verification.VerificationMode;
import org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore;
import org.optaplanner.core.api.solver.SolverJob;
import org.optaplanner.core.api.solver.SolverManager;
import org.optaplanner.core.config.solver.EnvironmentMode;
import org.optaplanner.core.config.solver.SolverConfig;
import org.optaplanner.core.config.solver.termination.TerminationConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class CounterfactualExplainerTest {
    // Shared logger; the tests log counterfactual entities and outputs at debug level.
    private static final Logger logger = LoggerFactory.getLogger(CounterfactualExplainerTest.class);
    // Timeout, expressed in predictionTimeUnit, for awaiting an asynchronous explanation.
    // This is a compile-time constant, so javac inlines it at use sites and decompiled code may
    // show the bare literal instead of a field reference.
    final long predictionTimeOut = 10L;
    // Unit paired with predictionTimeOut.
    final TimeUnit predictionTimeUnit = TimeUnit.MINUTES;
    // Score-calculation budget used as the solver termination limit by runCounterfactualSearch.
    final Long steps = 30000L;
    // NOTE(review): not referenced by the visible tests, which pass goal thresholds as literals;
    // confirm whether it is used further down the file.
    final double DEFAULT_GOAL_THRESHOLD = 0.01;
    // NOTE(review): not referenced by the visible tests; presumably used later in the file.
    private static final Long MAX_RUNNING_TIME_SECONDS = 60L;
    // Mockito mocks assigned by the @BeforeEach method setup().
    // NOTE(review): not used by the tests visible here; presumably consumed further down the file.
    private Function<SolverConfig, SolverManager<CounterfactualSolution, UUID>> solverManagerFactory;
    private SolverManager<CounterfactualSolution, UUID> solverManager;

    // No-arg constructor as emitted by the decompiler; equivalent to the implicit default constructor.
    CounterfactualExplainerTest() {
    }

    /**
     * Creates fresh Mockito mocks for the solver-manager factory and the solver manager before each test.
     *
     * <p>JUnit Jupiter requires lifecycle methods to be non-private, so this method is package-private.
     * {@code Mockito.mock(Class)} cannot express the generic type arguments of the fields, hence the
     * (safe) unchecked conversions.
     */
    @BeforeEach
    @SuppressWarnings("unchecked")
    void setup() {
        this.solverManagerFactory = Mockito.mock(Function.class);
        this.solverManager = Mockito.mock(SolverManager.class);
    }

    /**
     * Runs a reproducible counterfactual search and waits for its final result.
     *
     * <p>The solver terminates after {@link #steps} score calculations and runs in
     * {@link EnvironmentMode#REPRODUCIBLE} mode with the given seed.
     *
     * @param randomSeed    seed handed to the solver configuration
     * @param goal          desired model outputs
     * @param constraints   per-feature flags; the tests expect features flagged {@code true} to stay unchanged
     * @param dataDomain    per-feature search domains
     * @param features      original input features
     * @param model         model to explain
     * @param goalThreshold goal threshold forwarded to {@link CounterfactualConfig#withGoalThreshold}
     * @return the counterfactual result produced by the explainer
     * @throws TimeoutException if the explanation does not complete within
     *                          {@code predictionTimeOut} {@code predictionTimeUnit}
     */
    private CounterfactualResult runCounterfactualSearch(Long randomSeed, List<Output> goal, List<Boolean> constraints,
            DataDomain dataDomain, List<Feature> features, PredictionProvider model, double goalThreshold)
            throws InterruptedException, ExecutionException, TimeoutException {
        TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(steps);
        SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
        solverConfig.setRandomSeed(randomSeed);
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);

        // Use the value returned by the fluent chain rather than discarding it, so this is correct
        // whether the with* methods mutate in place or return a new configuration.
        CounterfactualConfig counterfactualConfig = new CounterfactualConfig()
                .withSolverConfig(solverConfig)
                .withGoalThreshold(goalThreshold);
        CounterfactualExplainer explainer = new CounterfactualExplainer(counterfactualConfig);

        PredictionInput input = new PredictionInput(features);
        PredictionOutput output = new PredictionOutput(goal);
        PredictionFeatureDomain domain = new PredictionFeatureDomain(dataDomain.getFeatureDomains());
        CounterfactualPrediction prediction =
                new CounterfactualPrediction(input, output, domain, constraints, null, UUID.randomUUID(), null);
        return explainer.explainAsync(prediction, model).get(predictionTimeOut, predictionTimeUnit);
    }

    /**
     * Smoke test: with a tiny solver budget the explainer still returns a non-null result
     * with a non-null entity list.
     */
    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2 })
    void testNonEmptyInput(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        List<Output> goal = List.of(new Output("class", Type.NUMBER, new Value(10.0), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            features.add(TestUtils.getMockedNumericFeature(i));
            featureBoundaries.add(NumericalFeatureDomain.create(0.0, 1000.0));
            constraints.add(false);
        }

        // Deliberately tiny budget: this test only checks that a result is produced.
        TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(10L);
        SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
        solverConfig.setRandomSeed((long) seed);
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
        CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverConfig(solverConfig);
        CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);

        PredictionProvider model = TestUtils.getSumSkipModel(0);
        PredictionInput input = new PredictionInput(features);
        PredictionOutput output = new PredictionOutput(goal);
        CounterfactualPrediction prediction = new CounterfactualPrediction(input, output,
                new PredictionFeatureDomain(featureBoundaries), constraints, null, UUID.randomUUID(), null);
        CounterfactualResult counterfactualResult = counterfactualExplainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());

        // Assert before dereferencing the result in the debug logging below.
        Assertions.assertNotNull(counterfactualResult);
        Assertions.assertNotNull(counterfactualResult.getEntities());
        for (CounterfactualEntity entity : counterfactualResult.getEntities()) {
            logger.debug("Entity: {}", entity);
        }
        logger.debug("Outputs: {}", counterfactualResult.getOutput().get(0).getOutputs());
    }

    /**
     * Four unconstrained numeric features must be moved so that their sum lands within
     * {@code center ± epsilon}, which makes the sum-threshold model output {@code true}.
     */
    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2 })
    void testCounterfactualMatch(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final double center = 500.0;
        final double epsilon = 10.0;
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        double[] initialValues = { 100.0, 150.0, 1.0, 2.0 };
        for (int i = 0; i < initialValues.length; i++) {
            features.add(FeatureFactory.newNumericalFeature("f-num" + (i + 1), initialValues[i]));
            constraints.add(false);
            featureBoundaries.add(NumericalFeatureDomain.create(0.0, 1000.0));
        }
        DataDomain dataDomain = new DataDomain(featureBoundaries);

        CounterfactualResult result = runCounterfactualSearch((long) seed, goal, constraints, dataDomain, features,
                TestUtils.getSumThresholdModel(center, epsilon), 0.01);

        double totalSum = 0.0;
        for (CounterfactualEntity entity : result.getEntities()) {
            totalSum += entity.asFeature().getValue().asNumber();
            logger.debug("Entity: {}", entity);
        }
        logger.debug("Outputs: {}", result.getOutput().get(0).getOutputs());
        Assertions.assertTrue(totalSum <= center + epsilon, "sum too high: " + totalSum);
        Assertions.assertTrue(totalSum >= center - epsilon, "sum too low: " + totalSum);
        Assertions.assertTrue(result.isValid());
    }

    /**
     * Like the unconstrained match test, but f-num1 and f-num4 are constrained.
     * They must remain unchanged while the free features reach a sum within {@code center ± epsilon}.
     */
    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2 })
    void testCounterfactualConstrainedMatchUnscaled(int seed)
            throws ExecutionException, InterruptedException, TimeoutException {
        final double center = 500.0;
        final double epsilon = 10.0;
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        boolean[] constrained = { true, false, false, true };
        for (int i = 0; i < constrained.length; i++) {
            features.add(FeatureFactory.newNumericalFeature("f-num" + (i + 1), 100.0));
            constraints.add(constrained[i]);
            if (constrained[i]) {
                featureBoundaries.add(EmptyFeatureDomain.create());
            } else {
                featureBoundaries.add(NumericalFeatureDomain.create(0.0, 1000.0));
            }
        }
        DataDomain dataDomain = new DataDomain(featureBoundaries);

        CounterfactualResult result = runCounterfactualSearch((long) seed, goal, constraints, dataDomain, features,
                TestUtils.getSumThresholdModel(center, epsilon), 0.01);

        List<CounterfactualEntity> counterfactualEntities = result.getEntities();
        double totalSum = 0.0;
        for (CounterfactualEntity entity : counterfactualEntities) {
            totalSum += entity.asFeature().getValue().asNumber();
            logger.debug("Entity: {}", entity);
        }
        Assertions.assertFalse(counterfactualEntities.get(0).isChanged());
        Assertions.assertFalse(counterfactualEntities.get(3).isChanged());
        Assertions.assertTrue(totalSum <= center + epsilon, "sum too high: " + totalSum);
        Assertions.assertTrue(totalSum >= center - epsilon, "sum too low: " + totalSum);
        Assertions.assertTrue(result.isValid());
    }

    /**
     * Constrained match test with per-feature normal distributions attached.
     * Constrained features f-num1 and f-num4 must stay unchanged, and the sum must land within
     * {@code center ± epsilon}.
     */
    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2 })
    void testCounterfactualConstrainedMatchScaled(int seed)
            throws ExecutionException, InterruptedException, TimeoutException {
        final double center = 500.0;
        final double epsilon = 10.0;
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        List<NumericFeatureDistribution> featureDistributions = new ArrayList<>();
        boolean[] constrained = { true, false, false, true };
        double[] distributionMeans = { 500.0, 430.0, 470.0, 2390.0 };
        double[] distributionSds = { 1.1, 1.7, 2.9, 0.3 };
        for (int i = 0; i < constrained.length; i++) {
            Feature feature = FeatureFactory.newNumericalFeature("f-num" + (i + 1), 100.0);
            features.add(feature);
            constraints.add(constrained[i]);
            if (constrained[i]) {
                featureBoundaries.add(EmptyFeatureDomain.create());
            } else {
                featureBoundaries.add(NumericalFeatureDomain.create(0.0, 1000.0));
            }
            featureDistributions.add(new NumericFeatureDistribution(feature,
                    new NormalDistribution(distributionMeans[i], distributionSds[i]).sample(1000)));
        }
        // NOTE(review): featureDistributions is built but never handed to the search.
        // runCounterfactualSearch passes null in the CounterfactualPrediction slot that presumably carries
        // a data distribution, so this currently exercises the same path as the unscaled test. Confirm intent.
        DataDomain dataDomain = new DataDomain(featureBoundaries);

        CounterfactualResult result = runCounterfactualSearch((long) seed, goal, constraints, dataDomain, features,
                TestUtils.getSumThresholdModel(center, epsilon), 0.01);

        List<CounterfactualEntity> counterfactualEntities = result.getEntities();
        double totalSum = 0.0;
        for (CounterfactualEntity entity : counterfactualEntities) {
            totalSum += entity.asFeature().getValue().asNumber();
            logger.debug("Entity: {}", entity);
        }
        Assertions.assertFalse(counterfactualEntities.get(0).isChanged());
        Assertions.assertFalse(counterfactualEntities.get(3).isChanged());
        Assertions.assertTrue(totalSum <= center + epsilon, "sum too high: " + totalSum);
        Assertions.assertTrue(totalSum >= center - epsilon, "sum too low: " + totalSum);
        Assertions.assertTrue(result.isValid());
    }

    /**
     * Mixes four numeric features with one boolean feature. The numeric feature at index 2 is
     * constrained and must stay unchanged, and the summed entity values must land within
     * {@code center ± epsilon}.
     */
    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2 })
    void testCounterfactualBoolean(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final double center = 500.0;
        final double epsilon = 10.0;
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            features.add(TestUtils.getMockedNumericFeature(i));
            featureBoundaries.add(NumericalFeatureDomain.create(0.0, 1000.0));
            // Only the numeric feature at index 2 is constrained.
            constraints.add(i == 2);
        }
        features.add(FeatureFactory.newBooleanFeature("f-bool", true));
        featureBoundaries.add(EmptyFeatureDomain.create());
        constraints.add(false);
        DataDomain dataDomain = new DataDomain(featureBoundaries);

        CounterfactualResult result = runCounterfactualSearch((long) seed, goal, constraints, dataDomain, features,
                TestUtils.getSumThresholdModel(center, epsilon), 0.01);

        List<CounterfactualEntity> counterfactualEntities = result.getEntities();
        // NOTE(review): the sum also includes the boolean feature via Value.asNumber(); confirm that is intended.
        double totalSum = 0.0;
        for (CounterfactualEntity entity : counterfactualEntities) {
            totalSum += entity.asFeature().getValue().asNumber();
            logger.debug("Entity: {}", entity);
        }
        Assertions.assertFalse(counterfactualEntities.get(2).isChanged());
        Assertions.assertTrue(totalSum <= center + epsilon, "sum too high: " + totalSum);
        Assertions.assertTrue(totalSum >= center - epsilon, "sum too low: " + totalSum);
        Assertions.assertTrue(result.isValid());
    }

    /**
     * With a zero goal threshold, the symbolic arithmetic search is expected to be reported as invalid.
     * Even so, the counterfactual's {@code x-1 <operand> x-2} must evaluate to within {@code 25 ± 0.1}.
     */
    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2 })
    void testCounterfactualCategoricalStrictFail(int seed)
            throws ExecutionException, InterruptedException, TimeoutException {
        final double target = 25.0;
        final double epsilon = 0.1;
        List<Output> goal = List.of(new Output("result", Type.NUMBER, new Value(target), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        features.add(FeatureFactory.newNumericalFeature("x-1", 5.0));
        featureBoundaries.add(NumericalFeatureDomain.create(0.0, 100.0));
        constraints.add(false);
        features.add(FeatureFactory.newNumericalFeature("x-2", 40.0));
        featureBoundaries.add(NumericalFeatureDomain.create(0.0, 100.0));
        constraints.add(false);
        features.add(FeatureFactory.newCategoricalFeature("operand", "*"));
        featureBoundaries.add(CategoricalFeatureDomain.create(new String[] { "+", "-", "/", "*" }));
        constraints.add(false);
        DataDomain dataDomain = new DataDomain(featureBoundaries);

        CounterfactualResult result = runCounterfactualSearch((long) seed, goal, constraints, dataDomain, features,
                TestUtils.getSymbolicArithmeticModel(), 0.0);

        List<CounterfactualEntity> counterfactualEntities = result.getEntities();
        String operand = counterfactualEntities.stream()
                .map(CounterfactualEntity::asFeature)
                .filter(feature -> feature.getName().equals("operand"))
                .findFirst()
                .orElseThrow(() -> new AssertionError("counterfactual has no 'operand' feature"))
                .getValue()
                .asString();
        List<Double> operands = counterfactualEntities.stream()
                .map(CounterfactualEntity::asFeature)
                .filter(feature -> !feature.getName().equals("operand"))
                .map(feature -> feature.getValue().asNumber())
                .collect(Collectors.toList());
        Assertions.assertFalse(operands.isEmpty(), "counterfactual has no numerical features");

        // Fold left starting from the first operand. Starting from 0.0 (as before) made "*" and "/" always
        // yield 0 and negated the first value for "-".
        // NOTE(review): assumes the symbolic model evaluates x-1 <op> x-2 left to right; confirm.
        double opResult = operands.get(0);
        for (double value : operands.subList(1, operands.size())) {
            switch (operand) {
                case "+":
                    opResult += value;
                    break;
                case "-":
                    opResult -= value;
                    break;
                case "*":
                    opResult *= value;
                    break;
                case "/":
                    opResult /= value;
                    break;
                default:
                    Assertions.fail("Unexpected operand: " + operand);
            }
        }
        Assertions.assertFalse(result.isValid());
        Assertions.assertTrue(opResult <= target + epsilon, "result too high: " + opResult);
        Assertions.assertTrue(opResult >= target - epsilon, "result too low: " + opResult);
    }

    /**
     * With a 0.01 goal threshold, the symbolic arithmetic search must succeed, and the counterfactual's
     * {@code x-1 <operand> x-2} must evaluate to within {@code 25 ± 0.5}.
     */
    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2 })
    void testCounterfactualCategoricalNotStrict(int seed)
            throws ExecutionException, InterruptedException, TimeoutException {
        final double target = 25.0;
        final double epsilon = 0.5;
        List<Output> goal = List.of(new Output("result", Type.NUMBER, new Value(target), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        features.add(FeatureFactory.newNumericalFeature("x-1", 5.0));
        featureBoundaries.add(NumericalFeatureDomain.create(0.0, 100.0));
        constraints.add(false);
        features.add(FeatureFactory.newNumericalFeature("x-2", 40.0));
        featureBoundaries.add(NumericalFeatureDomain.create(0.0, 100.0));
        constraints.add(false);
        features.add(FeatureFactory.newCategoricalFeature("operand", "*"));
        featureBoundaries.add(CategoricalFeatureDomain.create(new String[] { "+", "-", "/", "*" }));
        constraints.add(false);
        DataDomain dataDomain = new DataDomain(featureBoundaries);

        CounterfactualResult result = runCounterfactualSearch((long) seed, goal, constraints, dataDomain, features,
                TestUtils.getSymbolicArithmeticModel(), 0.01);

        List<CounterfactualEntity> counterfactualEntities = result.getEntities();
        String operand = counterfactualEntities.stream()
                .map(CounterfactualEntity::asFeature)
                .filter(feature -> feature.getName().equals("operand"))
                .findFirst()
                .orElseThrow(() -> new AssertionError("counterfactual has no 'operand' feature"))
                .getValue()
                .asString();
        List<Double> operands = counterfactualEntities.stream()
                .map(CounterfactualEntity::asFeature)
                .filter(feature -> !feature.getName().equals("operand"))
                .map(feature -> feature.getValue().asNumber())
                .collect(Collectors.toList());
        Assertions.assertFalse(operands.isEmpty(), "counterfactual has no numerical features");

        // Fold left starting from the first operand. Starting from 0.0 (as before) made "*" and "/" always
        // yield 0 and negated the first value for "-".
        // NOTE(review): assumes the symbolic model evaluates x-1 <op> x-2 left to right; confirm.
        double opResult = operands.get(0);
        for (double value : operands.subList(1, operands.size())) {
            switch (operand) {
                case "+":
                    opResult += value;
                    break;
                case "-":
                    opResult -= value;
                    break;
                case "*":
                    opResult *= value;
                    break;
                case "/":
                    opResult /= value;
                    break;
                default:
                    Assertions.fail("Unexpected operand: " + operand);
            }
        }
        Assertions.assertTrue(result.isValid());
        Assertions.assertTrue(opResult <= target + epsilon, "result too high: " + opResult);
        Assertions.assertTrue(opResult >= target - epsilon, "result too low: " + opResult);
    }

    /**
     * Sets a goal output score of {@code scoreThreshold}. The counterfactual must satisfy the
     * sum-threshold model, and re-predicting it must yield a score of at least {@code scoreThreshold}.
     */
    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2 })
    void testCounterfactualMatchThreshold(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final double scoreThreshold = 0.9;
        final double center = 500.0;
        final double epsilon = 10.0;
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), scoreThreshold));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        for (int i = 1; i <= 4; i++) {
            features.add(FeatureFactory.newNumericalFeature("f-num" + i, 100.0));
            constraints.add(false);
            featureBoundaries.add(NumericalFeatureDomain.create(0.0, 1000.0));
        }
        DataDomain dataDomain = new DataDomain(featureBoundaries);
        PredictionProvider model = TestUtils.getSumThresholdModel(center, epsilon);

        CounterfactualResult result =
                runCounterfactualSearch((long) seed, goal, constraints, dataDomain, features, model, 0.01);

        List<CounterfactualEntity> counterfactualEntities = result.getEntities();
        double totalSum = 0.0;
        for (CounterfactualEntity entity : counterfactualEntities) {
            totalSum += entity.asFeature().getValue().asNumber();
            logger.debug("Entity: {}", entity);
        }
        Assertions.assertTrue(totalSum <= center + epsilon, "sum too high: " + totalSum);
        Assertions.assertTrue(totalSum >= center - epsilon, "sum too low: " + totalSum);

        // Re-run the model on the counterfactual to inspect the score it actually produces.
        List<Feature> cfFeatures = counterfactualEntities.stream()
                .map(CounterfactualEntity::asFeature)
                .collect(Collectors.toList());
        PredictionInput cfInput = new PredictionInput(cfFeatures);
        PredictionOutput cfOutput = model.predictAsync(List.of(cfInput))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
                .get(0);
        double predictionScore = cfOutput.getOutputs().get(0).getScore();
        logger.debug("Prediction score: {}", predictionScore);
        Assertions.assertTrue(predictionScore >= scoreThreshold, "score below threshold: " + predictionScore);
        Assertions.assertTrue(result.isValid());
    }

    /**
     * Sets a goal output score of {@code scoreThreshold} = 0. The counterfactual must satisfy the
     * sum-threshold model, and the test expects re-predicting it to yield a score below 0.5.
     */
    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2 })
    void testCounterfactualMatchNoThreshold(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final double scoreThreshold = 0.0;
        final double center = 500.0;
        final double epsilon = 10.0;
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), scoreThreshold));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        for (int i = 1; i <= 4; i++) {
            features.add(FeatureFactory.newNumericalFeature("f-num" + i, 100.0));
            constraints.add(false);
            featureBoundaries.add(NumericalFeatureDomain.create(0.0, 1000.0));
        }
        DataDomain dataDomain = new DataDomain(featureBoundaries);
        PredictionProvider model = TestUtils.getSumThresholdModel(center, epsilon);

        CounterfactualResult result =
                runCounterfactualSearch((long) seed, goal, constraints, dataDomain, features, model, 0.01);

        List<CounterfactualEntity> counterfactualEntities = result.getEntities();
        double totalSum = 0.0;
        for (CounterfactualEntity entity : counterfactualEntities) {
            totalSum += entity.asFeature().getValue().asNumber();
            logger.debug("Entity: {}", entity);
        }
        Assertions.assertTrue(totalSum <= center + epsilon, "sum too high: " + totalSum);
        Assertions.assertTrue(totalSum >= center - epsilon, "sum too low: " + totalSum);

        // Re-run the model on the counterfactual to inspect the score it actually produces.
        List<Feature> cfFeatures = counterfactualEntities.stream()
                .map(CounterfactualEntity::asFeature)
                .collect(Collectors.toList());
        PredictionInput cfInput = new PredictionInput(cfFeatures);
        PredictionOutput cfOutput = model.predictAsync(List.of(cfInput))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
                .get(0);
        double predictionScore = cfOutput.getOutputs().get(0).getScore();
        logger.debug("Prediction score: {}", predictionScore);
        Assertions.assertTrue(predictionScore < 0.5, "score unexpectedly high: " + predictionScore);
        Assertions.assertTrue(result.isValid());
    }

    /**
     * f-num1 and f-num4 are fixed near 1 and the free features are confined to [0, 2].
     * The sum can therefore never reach {@code center ± epsilon}, and the result must be invalid.
     *
     * <p>The {@code int} values from {@code @ValueSource} are widened to {@code long} by JUnit.
     */
    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2 })
    void testNoCounterfactualPossible(long seed) throws ExecutionException, InterruptedException, TimeoutException {
        final double center = 500.0;
        final double epsilon = 1.0;
        // Seed the RNG explicitly so the perturbation is reproducible whether or not
        // PerturbationContext reseeds it itself.
        Random random = new Random(seed);
        PerturbationContext perturbationContext = new PerturbationContext(seed, random, 4);
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        boolean[] constrained = { true, false, false, true };
        for (int i = 0; i < constrained.length; i++) {
            features.add(FeatureFactory.newNumericalFeature("f-num" + (i + 1), 1.0));
            constraints.add(constrained[i]);
            if (constrained[i]) {
                featureBoundaries.add(EmptyFeatureDomain.create());
            } else {
                featureBoundaries.add(NumericalFeatureDomain.create(0.0, 2.0));
            }
        }
        DataDomain dataDomain = new DataDomain(featureBoundaries);

        List<Feature> perturbedFeatures = DataUtils.perturbFeatures(features, perturbationContext);
        CounterfactualResult result = runCounterfactualSearch(seed, goal, constraints, dataDomain, perturbedFeatures,
                TestUtils.getSumThresholdModel(center, epsilon), 0.01);
        Assertions.assertFalse(result.isValid());
    }

    /**
     * Verifies that the intermediate-result consumer handed to
     * {@link CounterfactualExplainer#explainAsync} is invoked at least once with a
     * non-null {@link CounterfactualResult} during a seeded, reproducible search.
     *
     * @param seed random seed given to the solver so the search is reproducible
     */
    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2})
    void testConsumers(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
        List<Feature> features = new LinkedList<>();
        List<FeatureDomain> featureBoundaries = new LinkedList<>();
        List<Boolean> constraints = new LinkedList<>();
        // Four unconstrained numerical features, each starting at 100 within [0, 1000].
        for (int i = 1; i <= 4; i++) {
            features.add(FeatureFactory.newNumericalFeature("f-num" + i, 100.0));
            constraints.add(false);
            featureBoundaries.add(NumericalFeatureDomain.create(0.0, 1000.0));
        }

        TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(10_000L);
        SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
        solverConfig.setRandomSeed(Long.valueOf(seed));
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);

        @SuppressWarnings("unchecked")
        Consumer<CounterfactualResult> assertIntermediateCounterfactualNotNull = Mockito.mock(Consumer.class);
        CounterfactualConfig counterfactualConfig = new CounterfactualConfig()
                .withSolverConfig(solverConfig)
                .withGoalThreshold(0.01);
        CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);

        // Use the named parameters rather than repeating the literals, so they cannot drift apart.
        double center = 500.0;
        double epsilon = 10.0;
        PredictionProvider model = TestUtils.getSumThresholdModel(center, epsilon);

        PredictionInput input = new PredictionInput(features);
        PredictionOutput output = new PredictionOutput(goal);
        CounterfactualPrediction prediction = new CounterfactualPrediction(input, output,
                new PredictionFeatureDomain(featureBoundaries), constraints, null, UUID.randomUUID(), null);

        CounterfactualResult counterfactualResult = counterfactualExplainer
                .explainAsync(prediction, model, assertIntermediateCounterfactualNotNull)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        for (CounterfactualEntity entity : counterfactualResult.getEntities()) {
            logger.debug("Entity: {}", entity);
        }
        logger.debug("Outputs: {}", counterfactualResult.getOutput().get(0).getOutputs());

        // any(Class) does not match null (Mockito 2+), so this really checks for non-null intermediate results.
        Mockito.verify(assertIntermediateCounterfactualNotNull, Mockito.atLeast(1))
                .accept(ArgumentMatchers.any(CounterfactualResult.class));
    }

    /**
     * Simulates the solver publishing intermediate best solutions through the listener
     * the explainer registered via {@code solveAndListen}. Checks that the intermediate
     * results and the final result all carry distinct sequence ids.
     *
     * @param numberOfIntermediateSolutions how many intermediate solver callbacks to simulate
     */
    @ParameterizedTest
    @ValueSource(ints = {1, 2, 3, 5, 8})
    void testSequenceIds(int numberOfIntermediateSolutions) throws ExecutionException, InterruptedException, TimeoutException {
        List<Long> observedSequenceIds = new ArrayList<>();
        Consumer<CounterfactualResult> recordSequenceId = result -> observedSequenceIds.add(result.getSequenceId());
        ArgumentCaptor<Consumer<CounterfactualSolution>> listenerCaptor = ArgumentCaptor.forClass(Consumer.class);

        CounterfactualResult finalResult = this.mockExplainerInvocation(recordSequenceId, null);

        // Capture the best-solution listener the explainer passed to the solver manager.
        Mockito.verify(this.solverManager).solveAndListen(
                ArgumentMatchers.any(),
                ArgumentMatchers.any(),
                listenerCaptor.capture(),
                ArgumentMatchers.any());
        Consumer<CounterfactualSolution> bestSolutionListener = listenerCaptor.getValue();

        // Feed the listener a series of mocked intermediate solutions.
        for (int i = 0; i < numberOfIntermediateSolutions; i++) {
            CounterfactualSolution intermediate = Mockito.mock(CounterfactualSolution.class);
            BendableBigDecimalScore zeroScore = BendableBigDecimalScore.zero(0, 0);
            Mockito.when(intermediate.getScore()).thenReturn(zeroScore);
            bestSolutionListener.accept(intermediate);
        }

        observedSequenceIds.add(finalResult.getSequenceId());
        int expectedCount = numberOfIntermediateSolutions + 1;
        Assertions.assertEquals(expectedCount, observedSequenceIds.size());
        Assertions.assertEquals(expectedCount, (int) observedSequenceIds.stream().distinct().count());
    }

    /**
     * Checks that every intermediate result reported during a search has a unique
     * solution id, and that all of them carry the execution id of the submitted prediction.
     *
     * @param seed random seed given to the solver so the search is reproducible
     */
    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2})
    void testIntermediateUniqueIds(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
        List<Feature> features = new LinkedList<>();
        List<FeatureDomain> featureBoundaries = new LinkedList<>();
        List<Boolean> constraints = new LinkedList<>();
        // Four unconstrained numerical features, each starting at 10 within [0, 10000].
        for (int i = 1; i <= 4; i++) {
            features.add(FeatureFactory.newNumericalFeature("f-num" + i, 10.0));
            constraints.add(false);
            featureBoundaries.add(NumericalFeatureDomain.create(0.0, 10000.0));
        }

        double center = 400.0;
        double epsilon = 10.0;
        PredictionProvider model = TestUtils.getSumThresholdModel(center, epsilon);

        TerminationConfig terminationConfig = new TerminationConfig()
                .withBestScoreFeasible(true)
                .withScoreCalculationCountLimit(10_000L);
        SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
        solverConfig.setRandomSeed(Long.valueOf(seed));
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);

        List<UUID> intermediateIds = new ArrayList<>();
        List<UUID> executionIds = new ArrayList<>();
        Consumer<CounterfactualResult> captureIntermediateIds = counterfactual -> intermediateIds.add(counterfactual.getSolutionId());
        Consumer<CounterfactualResult> captureExecutionIds = counterfactual -> executionIds.add(counterfactual.getExecutionId());

        CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverConfig(solverConfig);
        CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);
        PredictionInput input = new PredictionInput(features);
        PredictionOutput output = new PredictionOutput(goal);
        UUID executionId = UUID.randomUUID();
        CounterfactualPrediction prediction = new CounterfactualPrediction(input, output,
                new PredictionFeatureDomain(featureBoundaries), constraints, null, executionId, null);

        CounterfactualResult counterfactualResult = counterfactualExplainer
                .explainAsync(prediction, model, captureIntermediateIds.andThen(captureExecutionIds))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        for (CounterfactualEntity entity : counterfactualResult.getEntities()) {
            logger.debug("Entity: {}", entity);
        }

        // No duplicate solution ids among the intermediate results.
        Assertions.assertEquals((int) intermediateIds.stream().distinct().count(), intermediateIds.size());
        // All intermediate results belong to a single execution: the one submitted above.
        Assertions.assertEquals(1, (int) executionIds.stream().distinct().count());
        Assertions.assertEquals(executionId, executionIds.get(0));
    }

    /**
     * Checks the ids reported during a search that ends in a final result:
     * <ul>
     *   <li>intermediate solution ids are unique and at least one intermediate result is seen;</li>
     *   <li>every intermediate result carries the submitted execution id;</li>
     *   <li>the final result's solution id differs from the last intermediate one.</li>
     * </ul>
     *
     * @param seed random seed given to the solver so the search is reproducible
     */
    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2})
    void testFinalUniqueIds(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.5));
        List<Feature> features = new LinkedList<>();
        List<FeatureDomain> featureBoundaries = new LinkedList<>();
        List<Boolean> constraints = new LinkedList<>();
        // Four unconstrained numerical features, each starting at 10 within [0, 10000].
        for (int i = 1; i <= 4; i++) {
            features.add(FeatureFactory.newNumericalFeature("f-num" + i, 10.0));
            constraints.add(false);
            featureBoundaries.add(NumericalFeatureDomain.create(0.0, 10000.0));
        }

        double center = 400.0;
        double epsilon = 10.0;
        PredictionProvider model = TestUtils.getSumThresholdModel(center, epsilon);

        TerminationConfig terminationConfig = new TerminationConfig()
                .withBestScoreFeasible(true)
                .withScoreCalculationCountLimit(10_000L);
        SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
        solverConfig.setRandomSeed(Long.valueOf(seed));
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);

        List<UUID> intermediateIds = new ArrayList<>();
        List<UUID> executionIds = new ArrayList<>();
        Consumer<CounterfactualResult> captureIntermediateIds = counterfactual -> intermediateIds.add(counterfactual.getSolutionId());
        Consumer<CounterfactualResult> captureExecutionIds = counterfactual -> executionIds.add(counterfactual.getExecutionId());

        CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverConfig(solverConfig);
        CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);
        PredictionInput input = new PredictionInput(features);
        PredictionOutput output = new PredictionOutput(goal);
        UUID executionId = UUID.randomUUID();
        CounterfactualPrediction prediction = new CounterfactualPrediction(input, output,
                new PredictionFeatureDomain(featureBoundaries), constraints, null, executionId, null);

        CounterfactualResult counterfactualResult = counterfactualExplainer
                .explainAsync(prediction, model, captureIntermediateIds.andThen(captureExecutionIds))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        for (CounterfactualEntity entity : counterfactualResult.getEntities()) {
            logger.debug("Entity: {}", entity);
        }

        Assertions.assertEquals((int) intermediateIds.stream().distinct().count(), intermediateIds.size());
        Assertions.assertFalse(intermediateIds.isEmpty());
        Assertions.assertFalse(executionIds.isEmpty());
        Assertions.assertEquals(executionIds.size(), intermediateIds.size());
        Assertions.assertEquals(1, (int) executionIds.stream().distinct().count());
        // The final result must not reuse the solution id of the last intermediate result.
        Assertions.assertNotEquals(intermediateIds.get(intermediateIds.size() - 1), counterfactualResult.getSolutionId());
        Assertions.assertEquals(executionId, executionIds.get(0));
    }

    /**
     * Checks that the search yields a sparse, valid counterfactual.
     * <p>
     * There are two features, f-num1 = 0 and f-num2 = 5, each in [0, 10]. The model is a
     * sum-threshold model centred on 10 with epsilon 0.1. The test expects a valid result
     * in which at most one of the two features was changed.
     *
     * @param seed random seed forwarded to {@code runCounterfactualSearch}
     */
    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2, 3, 4})
    void testSparsity(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();

        features.add(FeatureFactory.newNumericalFeature("f-num1", 0));
        constraints.add(false);
        featureBoundaries.add(NumericalFeatureDomain.create(0.0, 10.0));

        features.add(FeatureFactory.newNumericalFeature("f-num2", 5));
        constraints.add(false);
        featureBoundaries.add(NumericalFeatureDomain.create(0.0, 10.0));

        DataDomain dataDomain = new DataDomain(featureBoundaries);
        double center = 10.0;
        double epsilon = 0.1;
        CounterfactualResult result = this.runCounterfactualSearch(Long.valueOf(seed), goal, constraints,
                dataDomain, features, TestUtils.getSumThresholdModel(center, epsilon), 0.01);

        boolean firstChanged = result.getEntities().get(0).isChanged();
        boolean secondChanged = result.getEntities().get(1).isChanged();
        Assertions.assertTrue(!firstChanged || !secondChanged,
                "Expected a sparse counterfactual with at most one changed feature");
        Assertions.assertTrue(result.isValid());
    }

    /**
     * Checks that a prediction's max running time reaches the solver: the solver config
     * passed to the solver manager factory must carry a spent limit of
     * {@code MAX_RUNNING_TIME_SECONDS} seconds.
     */
    @Test
    void testTerminationSpentLimitWhenDefined() throws ExecutionException, InterruptedException, TimeoutException {
        @SuppressWarnings("unchecked")
        Consumer<CounterfactualResult> ignoredResults = Mockito.mock(Consumer.class);
        ArgumentCaptor<SolverConfig> configCaptor = ArgumentCaptor.forClass(SolverConfig.class);

        this.mockExplainerInvocation(ignoredResults, MAX_RUNNING_TIME_SECONDS);

        Mockito.verify(this.solverManagerFactory).apply(configCaptor.capture());
        long configuredSeconds = configCaptor.getValue().getTerminationConfig().getSpentLimit().getSeconds();
        Assertions.assertEquals(MAX_RUNNING_TIME_SECONDS, configuredSeconds);
    }

    /**
     * Checks that, when the prediction has no max running time, the solver config passed
     * to the solver manager factory has no spent limit.
     * <p>
     * The "defined" case propagates the limit through the Duration-valued
     * {@code getSpentLimit()}. Checking only {@code getSecondsSpentLimit()} could therefore
     * pass regardless of what the explainer does, so both accessors are asserted.
     */
    @Test
    void testTerminationSpentLimitWhenUndefined() throws ExecutionException, InterruptedException, TimeoutException {
        @SuppressWarnings("unchecked")
        Consumer<CounterfactualResult> ignoredResults = Mockito.mock(Consumer.class);
        ArgumentCaptor<SolverConfig> solverConfigArgumentCaptor = ArgumentCaptor.forClass(SolverConfig.class);

        this.mockExplainerInvocation(ignoredResults, null);

        Mockito.verify(this.solverManagerFactory).apply(solverConfigArgumentCaptor.capture());
        SolverConfig solverConfig = solverConfigArgumentCaptor.getValue();
        Assertions.assertNull(solverConfig.getTerminationConfig().getSpentLimit());
        Assertions.assertNull(solverConfig.getTerminationConfig().getSecondsSpentLimit());
    }

    /**
     * Runs {@link CounterfactualExplainer#explainAsync} against mocked solver collaborators.
     * <p>
     * The mocks are wired as follows:
     * <ul>
     *   <li>{@code this.solverManagerFactory} returns {@code this.solverManager};</li>
     *   <li>its {@code solveAndListen} returns a mocked {@code SolverJob};</li>
     *   <li>that job's final best solution is a mock with a zero {@code BendableBigDecimalScore}.</li>
     * </ul>
     * The prediction has empty inputs, outputs, domains and constraints, and the model
     * always answers with an empty output list.
     *
     * @param intermediateResultsConsumer consumer passed to the explainer for intermediate results
     * @param maxRunningTimeSeconds       max running time set on the prediction; may be {@code null}
     * @return the result produced by the explainer
     */
    CounterfactualResult mockExplainerInvocation(Consumer<CounterfactualResult> intermediateResultsConsumer, Long maxRunningTimeSeconds) throws ExecutionException, InterruptedException, TimeoutException {
        @SuppressWarnings("unchecked")
        SolverJob<CounterfactualSolution, UUID> mockedJob = Mockito.mock(SolverJob.class);
        CounterfactualSolution finalSolution = Mockito.mock(CounterfactualSolution.class);
        BendableBigDecimalScore zeroScore = BendableBigDecimalScore.zero(0, 0);

        Mockito.when(this.solverManager.solveAndListen(
                ArgumentMatchers.any(),
                ArgumentMatchers.any(),
                ArgumentMatchers.any(),
                ArgumentMatchers.any())).thenReturn(mockedJob);
        Mockito.when(mockedJob.getFinalBestSolution()).thenReturn(finalSolution);
        Mockito.when(finalSolution.getScore()).thenReturn(zeroScore);
        Mockito.when(this.solverManagerFactory.apply(ArgumentMatchers.any())).thenReturn(this.solverManager);

        CounterfactualExplainer explainer = new CounterfactualExplainer(
                new CounterfactualConfig().withSolverManagerFactory(this.solverManagerFactory));
        CounterfactualPrediction emptyPrediction = new CounterfactualPrediction(
                new PredictionInput(Collections.emptyList()),
                new PredictionOutput(Collections.emptyList()),
                new PredictionFeatureDomain(Collections.emptyList()),
                Collections.emptyList(),
                null,
                UUID.randomUUID(),
                maxRunningTimeSeconds);
        PredictionProvider emptyModel = inputs -> CompletableFuture.completedFuture(Collections.emptyList());

        return explainer.explainAsync(emptyPrediction, emptyModel, intermediateResultsConsumer)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }
}

