/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.counterfactual;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualConfigurationFactory;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualResult;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.model.DataDomain;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureDomain;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.IndependentFeaturesDataDistribution;
import org.kie.kogito.explainability.model.NumericFeatureDistribution;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.optaplanner.core.config.solver.SolverConfig;
import org.optaplanner.core.config.solver.termination.TerminationConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tests for {@link CounterfactualExplainer}.
 *
 * <p>Each test builds a small input, runs a counterfactual search against a mock model from
 * {@link TestUtils}, and checks that the returned counterfactual satisfies the goal and the
 * per-feature constraints. Every test is parameterized by a seed, which is handed to the solver
 * configuration so different seeds can drive different search trajectories.
 */
class CounterfactualExplainerTest {
    /** Timeout for the counterfactual explanation future (see {@link #predictionTimeUnit}). */
    final long predictionTimeOut = 10L;
    final TimeUnit predictionTimeUnit = TimeUnit.MINUTES;
    /** Solver budget, as a score-calculation count limit, for the full searches. */
    final Long steps = 200000L;
    private static final Logger logger = LoggerFactory.getLogger(CounterfactualExplainerTest.class);

    /** Default search bounds for numerical features. */
    private static final double NUMERIC_LOWER = 0.0;
    private static final double NUMERIC_UPPER = 1000.0;
    /** Parameters of {@link TestUtils#getSumThresholdModel}: the feature sum must land in center ± epsilon. */
    private static final double SUM_CENTER = 500.0;
    private static final double SUM_EPSILON = 10.0;

    /**
     * Runs a counterfactual search for {@code features} against {@code model}.
     *
     * @param seed random seed handed to the solver configuration
     * @param scoreCalculationLimit solver termination budget (score-calculation count)
     * @param goal desired model outputs
     * @param constraints one flag per feature; the tests use {@code true} for features that must not change
     * @param dataDomain search domain, one entry per feature
     * @param features the original input
     * @param model the model being explained
     * @return the counterfactual found within the budget
     */
    private CounterfactualResult runCounterfactualSearch(long seed, long scoreCalculationLimit, List<Output> goal,
            List<Boolean> constraints, DataDomain dataDomain, List<Feature> features, PredictionProvider model)
            throws InterruptedException, ExecutionException, TimeoutException {
        TerminationConfig terminationConfig =
                new TerminationConfig().withScoreCalculationCountLimit(scoreCalculationLimit);
        SolverConfig solverConfig =
                CounterfactualConfigurationFactory.builder().withTerminationConfig(terminationConfig).build();
        // Forward the test's seed; previously it fed an unused java.util.Random and never reached the solver.
        solverConfig.setRandomSeed(seed);
        CounterfactualExplainer explainer = CounterfactualExplainer.builder(goal, constraints, dataDomain)
                .withSolverConfig(solverConfig)
                .build();
        PredictionInput input = new PredictionInput(features);
        PredictionOutput output = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
                .get(0);
        Prediction prediction = new Prediction(input, output);
        return explainer.explainAsync(prediction, model).get(predictionTimeOut, predictionTimeUnit);
    }

    /** Appends an unconstrained numerical feature searched within [NUMERIC_LOWER, NUMERIC_UPPER]. */
    private static Feature addNumericalFeature(List<Feature> features, List<FeatureDomain> boundaries,
            List<Boolean> constraints, String name, double value) {
        return addNumericalFeature(features, boundaries, constraints, name, value, NUMERIC_LOWER, NUMERIC_UPPER);
    }

    /** Appends an unconstrained numerical feature searched within [lower, upper]; returns the feature. */
    private static Feature addNumericalFeature(List<Feature> features, List<FeatureDomain> boundaries,
            List<Boolean> constraints, String name, double value, double lower, double upper) {
        Feature feature = FeatureFactory.newNumericalFeature(name, value);
        features.add(feature);
        boundaries.add(FeatureDomain.numerical(lower, upper));
        constraints.add(false);
        return feature;
    }

    /** Sums {@code asNumber()} over every counterfactual entity, logging each entity. */
    private static double sumEntityValues(List<CounterfactualEntity> entities) {
        double total = 0.0;
        for (CounterfactualEntity entity : entities) {
            total += entity.asFeature().getValue().asNumber();
            logger.debug("Entity: {}", entity);
        }
        return total;
    }

    /** Re-predicts the counterfactual with {@code model} and returns the score of its first output. */
    private static double scoreCounterfactual(PredictionProvider model, List<CounterfactualEntity> entities)
            throws InterruptedException, ExecutionException, TimeoutException {
        List<Feature> cfFeatures = entities.stream()
                .map(CounterfactualEntity::asFeature)
                .collect(Collectors.toList());
        PredictionOutput cfOutput = model.predictAsync(List.of(new PredictionInput(cfFeatures)))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
                .get(0);
        double predictionScore = cfOutput.getOutputs().get(0).getScore();
        logger.debug("Prediction score: {}", predictionScore);
        return predictionScore;
    }

    /**
     * Folds {@code operands} left-to-right with the binary operator {@code operand}
     * ({@code ((o0 op o1) op o2) ...}). The accumulator starts from the first operand, not 0,
     * so "-", "*" and "/" evaluate as x1 - x2, x1 * x2 and x1 / x2.
     *
     * @throws AssertionError if there are no operands or the operator is unknown
     */
    private static double applyOperand(String operand, List<Double> operands) {
        if (operands.isEmpty()) {
            throw new AssertionError("Counterfactual has no numerical operands");
        }
        double accumulator = operands.get(0);
        for (int i = 1; i < operands.size(); i++) {
            double value = operands.get(i);
            switch (operand) {
                case "+":
                    accumulator += value;
                    break;
                case "-":
                    accumulator -= value;
                    break;
                case "*":
                    accumulator *= value;
                    break;
                case "/":
                    accumulator /= value;
                    break;
                default:
                    throw new AssertionError("Unknown operand: " + operand);
            }
        }
        return accumulator;
    }

    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2, 3, 4 })
    void testNonEmptyInput(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        List<Output> goal = List.of(new Output("class", Type.BOOLEAN, new Value(false), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            features.add(TestUtils.getMockedNumericFeature(i));
            featureBoundaries.add(FeatureDomain.numerical(NUMERIC_LOWER, NUMERIC_UPPER));
            constraints.add(false);
        }
        // A tiny budget: this test only checks that some result is produced.
        CounterfactualResult counterfactualResult = runCounterfactualSearch(seed, 10L, goal, constraints,
                new DataDomain(featureBoundaries), features, TestUtils.getSumSkipModel(0));
        // Null-check before the result is dereferenced for logging.
        Assertions.assertNotNull(counterfactualResult);
        Assertions.assertNotNull(counterfactualResult.getEntities());
        for (CounterfactualEntity entity : counterfactualResult.getEntities()) {
            logger.debug("Entity: {}", entity);
        }
        logger.debug("Outputs: {}", counterfactualResult.getOutput().get(0).getOutputs());
    }

    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2, 3, 4 })
    void testCounterfactualMatch(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        addNumericalFeature(features, featureBoundaries, constraints, "f-num1", 100.0);
        addNumericalFeature(features, featureBoundaries, constraints, "f-num2", 150.0);
        addNumericalFeature(features, featureBoundaries, constraints, "f-num3", 1.0);
        addNumericalFeature(features, featureBoundaries, constraints, "f-num4", 2.0);

        CounterfactualResult result = runCounterfactualSearch(seed, steps, goal, constraints,
                new DataDomain(featureBoundaries), features, TestUtils.getSumThresholdModel(SUM_CENTER, SUM_EPSILON));

        double totalSum = sumEntityValues(result.getEntities());
        logger.debug("Outputs: {}", result.getOutput().get(0).getOutputs());
        Assertions.assertEquals(SUM_CENTER, totalSum, SUM_EPSILON);
    }

    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2, 3, 4 })
    void testCounterfactualConstrainedMatchUnscaled(int seed)
            throws ExecutionException, InterruptedException, TimeoutException {
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        for (int i = 1; i <= 4; i++) {
            addNumericalFeature(features, featureBoundaries, constraints, "f-num" + i, 100.0);
        }
        // Pin the first and last features; only f-num2 and f-num3 may move.
        constraints.set(0, true);
        constraints.set(3, true);

        CounterfactualResult result = runCounterfactualSearch(seed, steps, goal, constraints,
                new DataDomain(featureBoundaries), features, TestUtils.getSumThresholdModel(SUM_CENTER, SUM_EPSILON));

        List<CounterfactualEntity> entities = result.getEntities();
        double totalSum = sumEntityValues(entities);
        Assertions.assertFalse(entities.get(0).isChanged());
        Assertions.assertFalse(entities.get(3).isChanged());
        Assertions.assertEquals(SUM_CENTER, totalSum, SUM_EPSILON);
    }

    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2, 3, 4 })
    void testCounterfactualConstrainedMatchScaled(int seed)
            throws ExecutionException, InterruptedException, TimeoutException {
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        LinkedList<NumericFeatureDistribution> featureDistributions = new LinkedList<NumericFeatureDistribution>();

        Feature fnum1 = addNumericalFeature(features, featureBoundaries, constraints, "f-num1", 100.0);
        featureDistributions.add(new NumericFeatureDistribution(fnum1, new NormalDistribution(500.0, 1.1).sample(1000)));
        Feature fnum2 = addNumericalFeature(features, featureBoundaries, constraints, "f-num2", 100.0);
        featureDistributions.add(new NumericFeatureDistribution(fnum2, new NormalDistribution(430.0, 1.7).sample(1000)));
        Feature fnum3 = addNumericalFeature(features, featureBoundaries, constraints, "f-num3", 100.0);
        featureDistributions.add(new NumericFeatureDistribution(fnum3, new NormalDistribution(470.0, 2.9).sample(1000)));
        Feature fnum4 = addNumericalFeature(features, featureBoundaries, constraints, "f-num4", 100.0);
        featureDistributions.add(new NumericFeatureDistribution(fnum4, new NormalDistribution(2390.0, 0.3).sample(1000)));
        // NOTE(review): dataDistribution is never handed to the explainer, so this search is currently the
        // same as the unscaled test (and the samples above are not seeded). Confirm how it should be wired in.
        IndependentFeaturesDataDistribution dataDistribution = new IndependentFeaturesDataDistribution(featureDistributions);

        // Pin the first and last features; only f-num2 and f-num3 may move.
        constraints.set(0, true);
        constraints.set(3, true);

        CounterfactualResult result = runCounterfactualSearch(seed, steps, goal, constraints,
                new DataDomain(featureBoundaries), features, TestUtils.getSumThresholdModel(SUM_CENTER, SUM_EPSILON));

        List<CounterfactualEntity> entities = result.getEntities();
        double totalSum = sumEntityValues(entities);
        Assertions.assertFalse(entities.get(0).isChanged());
        Assertions.assertFalse(entities.get(3).isChanged());
        Assertions.assertEquals(SUM_CENTER, totalSum, SUM_EPSILON);
    }

    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2, 3, 4 })
    void testCounterfactualBoolean(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            features.add(TestUtils.getMockedNumericFeature(i));
            featureBoundaries.add(FeatureDomain.numerical(NUMERIC_LOWER, NUMERIC_UPPER));
            constraints.add(false);
        }
        features.add(FeatureFactory.newBooleanFeature("f-bool", true));
        // The boolean feature is given no explicit domain (null entry keeps indices aligned with features).
        featureBoundaries.add(null);
        constraints.add(false);
        constraints.set(2, true);

        CounterfactualResult result = runCounterfactualSearch(seed, steps, goal, constraints,
                new DataDomain(featureBoundaries), features, TestUtils.getSumThresholdModel(SUM_CENTER, SUM_EPSILON));

        List<CounterfactualEntity> entities = result.getEntities();
        // Includes the boolean entity's asNumber() value, as the original sum did.
        double totalSum = sumEntityValues(entities);
        Assertions.assertFalse(entities.get(2).isChanged());
        Assertions.assertEquals(SUM_CENTER, totalSum, SUM_EPSILON);
    }

    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2, 3, 4 })
    void testCounterfactualCategorical(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        final double target = 25.0;
        final double tolerance = 0.01;
        List<Output> goal = List.of(new Output("result", Type.NUMBER, new Value(target), 0.0));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        addNumericalFeature(features, featureBoundaries, constraints, "x-1", 5.0, 0.0, 100.0);
        addNumericalFeature(features, featureBoundaries, constraints, "x-2", 40.0, 0.0, 100.0);
        features.add(FeatureFactory.newCategoricalFeature("operand", "*"));
        featureBoundaries.add(FeatureDomain.categorical(new String[] { "+", "-", "/", "*" }));
        constraints.add(false);

        CounterfactualResult result = runCounterfactualSearch(seed, steps, goal, constraints,
                new DataDomain(featureBoundaries), features, TestUtils.getSymbolicArithmeticModel());

        List<Feature> cfFeatures = result.getEntities().stream()
                .map(CounterfactualEntity::asFeature)
                .collect(Collectors.toList());
        String operand = cfFeatures.stream()
                .filter(feature -> "operand".equals(feature.getName()))
                .findFirst()
                .orElseThrow(() -> new AssertionError("Counterfactual has no 'operand' feature"))
                .getValue()
                .asString();
        // Operands are taken in entity order (x-1, x-2); assumes the model evaluates "x-1 <op> x-2" — confirm.
        List<Double> operands = cfFeatures.stream()
                .filter(feature -> !"operand".equals(feature.getName()))
                .map(feature -> feature.getValue().asNumber())
                .collect(Collectors.toList());

        double opResult = applyOperand(operand, operands);
        Assertions.assertEquals(target, opResult, tolerance);
    }

    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2, 3, 4 })
    void testCounterfactualMatchThreshold(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        double scoreThreshold = 0.9;
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), scoreThreshold));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        for (int i = 1; i <= 4; i++) {
            addNumericalFeature(features, featureBoundaries, constraints, "f-num" + i, 100.0);
        }
        PredictionProvider model = TestUtils.getSumThresholdModel(SUM_CENTER, SUM_EPSILON);

        CounterfactualResult result = runCounterfactualSearch(seed, steps, goal, constraints,
                new DataDomain(featureBoundaries), features, model);

        List<CounterfactualEntity> entities = result.getEntities();
        double totalSum = sumEntityValues(entities);
        Assertions.assertEquals(SUM_CENTER, totalSum, SUM_EPSILON);
        double predictionScore = scoreCounterfactual(model, entities);
        Assertions.assertTrue(predictionScore >= scoreThreshold,
                "prediction score " + predictionScore + " is below threshold " + scoreThreshold);
    }

    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2, 3, 4 })
    void testCounterfactualMatchNoThreshold(int seed) throws ExecutionException, InterruptedException, TimeoutException {
        double scoreThreshold = 0.0;
        List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), scoreThreshold));
        List<Feature> features = new ArrayList<>();
        List<FeatureDomain> featureBoundaries = new ArrayList<>();
        List<Boolean> constraints = new ArrayList<>();
        for (int i = 1; i <= 4; i++) {
            addNumericalFeature(features, featureBoundaries, constraints, "f-num" + i, 100.0);
        }
        PredictionProvider model = TestUtils.getSumThresholdModel(SUM_CENTER, SUM_EPSILON);

        CounterfactualResult result = runCounterfactualSearch(seed, steps, goal, constraints,
                new DataDomain(featureBoundaries), features, model);

        List<CounterfactualEntity> entities = result.getEntities();
        double totalSum = sumEntityValues(entities);
        Assertions.assertEquals(SUM_CENTER, totalSum, SUM_EPSILON);
        double predictionScore = scoreCounterfactual(model, entities);
        Assertions.assertTrue(predictionScore < 0.1, "unexpected prediction score " + predictionScore);
    }
}

