/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.shap;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.local.shap.ShapConfig;
import org.kie.kogito.explainability.local.shap.ShapKernelExplainer;
import org.kie.kogito.explainability.local.shap.ShapResults;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.SimplePrediction;

/**
 * Tests for {@link ShapKernelExplainer}.
 *
 * <p>Most cases run a model from {@link TestUtils} on a small matrix of inputs, explain each resulting
 * prediction, and compare the SHAP scores with precomputed expected values. Expected-value arrays are
 * indexed as {@code [prediction][saliency][feature]}.
 */
class ShapKernelExplainerTest {
    /** Upper bound, in seconds, for each awaited asynchronous prediction/explanation. */
    private static final long TIMEOUT_SECONDS = 5L;
    /** Number of repetitions per confidence level in {@link #testErrorBounds(double)}. */
    private static final int ERROR_BOUND_TESTS = 100;

    double[][] backgroundRaw = new double[][]{{1.0, 2.0, 3.0, -4.0, 5.0}, {10.0, 11.0, 12.0, -4.0, 13.0}, {2.0, 3.0, 4.0, -4.0, 6.0}};
    double[][] toExplainRaw = new double[][]{{5.0, 6.0, 7.0, -4.0, 8.0}, {11.0, 12.0, 13.0, -5.0, 14.0}, {0.0, 0.0, 1.0, 4.0, 2.0}};
    double[][] backgroundNoVariance = new double[][]{{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}};
    double[][] toExplainZeroVariance = new double[][]{{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}};
    double[][][] zeroVarianceOneOutputSHAP = new double[][][]{new double[][]{{0.0, 0.0, 0.0}}, new double[][]{{0.0, 0.0, 0.0}}};
    double[][][] zeroVarianceMultiOutputSHAP = new double[][][]{new double[][]{{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}, new double[][]{{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}};
    double[][] toExplainOneVariance = new double[][]{{3.0, 2.0, 3.0}, {1.0, 2.0, 2.0}};
    double[][][] oneVarianceOneOutputSHAP = new double[][][]{new double[][]{{2.0, 0.0, 0.0}}, new double[][]{{0.0, 0.0, -1.0}}};
    double[][][] oneVarianceMultiOutputSHAP = new double[][][]{new double[][]{{2.0, 0.0, 0.0}, {4.0, 0.0, 0.0}}, new double[][]{{0.0, 0.0, -1.0}, {0.0, 0.0, -2.0}}};
    double[][] toExplainLogit = new double[][]{{0.1, 0.12, 0.14, -0.08, 0.16}, {0.22, 0.24, 0.26, -0.1, 0.38}, {-0.1, 0.0, 0.02, 0.1, 0.04}};
    double[][] backgroundLogit = new double[][]{{0.02380952, 0.04761905, 0.07142857, -0.0952381, 0.11904762}, {0.23809524, 0.26190476, 0.28571429, -0.0952381, 0.30952381}, {0.04761905, 0.07142857, 0.11904762, -0.0952381, 0.14285714}};
    double[][][] logitSHAP = new double[][][]{new double[][]{{-0.01420862, 0.0, -0.08377778, 0.06825253, -0.13625127}}, new double[][]{{0.50970797, 0.0, 0.44412765, -0.02169177, 0.80832232}}, new double[][]{{Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN}}};
    double[][][] multiVarianceOneOutputSHAP = new double[][][]{new double[][]{{0.66666667, 0.0, 0.66666667, 0.0, 0.0}}, new double[][]{{6.66666667, 0.0, 6.66666667, -1.0, 6.0}}, new double[][]{{-4.33333333, 0.0, -5.33333333, 8.0, -6.0}}};
    double[][][] multiVarianceMultiOutputSHAP = new double[][][]{new double[][]{{0.66666667, 0.0, 0.66666667, 0.0, 0.0}, {1.333333333, 0.0, 1.33333333, 0.0, 0.0}}, new double[][]{{6.66666667, 0.0, 6.66666667, -1.0, 6.0}, {13.333333333, 0.0, 13.333333333, -2.0, 12.0}}, new double[][]{{-4.33333333, 0.0, -5.33333333, 8.0, -6.0}, {-8.6666666667, 0.0, -10.666666666, 16.0, -12.0}}};
    PerturbationContext pc = new PerturbationContext(new Random(0L), 0);
    ShapConfig.Builder testConfig = ShapConfig.builder().withLink(ShapConfig.LinkType.IDENTITY).withPC(this.pc);
    double[][] backgroundAllZeros = new double[100][6];
    double[][] toExplainAllOnes = new double[][]{{1.0, 1.0, 1.0, 1.0, 1.0, 1.0}};

    // Input and expected scores shared by testLargeBackground and testParallel.
    private final double[][] toExplainLargeBackground = new double[][]{{0.0, 1.0, -2.0, 3.5, -4.1, 5.5, -12.0, 0.8, 0.11, 15.0}};
    private final double[][][] largeBackgroundSHAP = new double[][][]{new double[][]{{-0.495, 0.0, -4.495, 0.005, -8.595, 0.005, -18.495, -6.695, -8.385, 5.505}}};

    /**
     * Converts each row of {@code m} into a {@link PredictionInput} holding one numerical feature
     * (named {@code "f"}) per column.
     */
    private List<PredictionInput> createPIFromMatrix(double[][] m) {
        List<PredictionInput> pis = new ArrayList<>(m.length);
        for (double[] row : m) {
            List<Feature> fs = new ArrayList<>(row.length);
            for (double value : row) {
                fs.add(FeatureFactory.newNumericalFeature("f", value));
            }
            pis.add(new PredictionInput(fs));
        }
        return pis;
    }

    /** Builds a {@code rows x cols} matrix whose entry {@code (i, j)} equals {@code i / rows + j}. */
    private static double[][] gridMatrix(int rows, int cols) {
        double[][] m = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                m[i][j] = (double) i / rows + j;
            }
        }
        return m;
    }

    /**
     * Packs the per-feature scores and confidences of {@code saliencies} into matrices.
     *
     * @return a two-element array: index 0 holds scores, index 1 holds confidences. Row {@code i}
     *         corresponds to {@code saliencies[i]} and column {@code j} to its {@code j}-th feature importance.
     */
    private RealMatrix[] saliencyToMatrix(Saliency[] saliencies) {
        int nFeatures = saliencies[0].getPerFeatureImportance().size();
        RealMatrix scores = MatrixUtils.createRealMatrix(new double[saliencies.length][nFeatures]);
        RealMatrix confidences = MatrixUtils.createRealMatrix(new double[saliencies.length][nFeatures]);
        for (int i = 0; i < saliencies.length; i++) {
            List<FeatureImportance> fis = saliencies[i].getPerFeatureImportance();
            for (int j = 0; j < fis.size(); j++) {
                scores.setEntry(i, j, fis.get(j).getScore());
                confidences.setEntry(i, j, fis.get(j).getConfidence());
            }
        }
        return new RealMatrix[]{scores, confidences};
    }

    /**
     * Runs {@code model} on {@code inputs} and pairs each input with its output.
     *
     * @throws TimeoutException if the model does not answer within {@link #TIMEOUT_SECONDS}
     */
    private List<Prediction> predict(PredictionProvider model, List<PredictionInput> inputs)
            throws InterruptedException, ExecutionException, TimeoutException {
        List<PredictionOutput> outputs = model.predictAsync(inputs).get(TIMEOUT_SECONDS, TimeUnit.SECONDS);
        Assertions.assertEquals(inputs.size(), outputs.size(), "model must return one output per input");
        List<Prediction> predictions = new ArrayList<>(outputs.size());
        for (int i = 0; i < outputs.size(); i++) {
            predictions.add(new SimplePrediction(inputs.get(i), outputs.get(i)));
        }
        return predictions;
    }

    /** Explains {@code prediction} with {@code ske} and returns the score matrix (one row per saliency). */
    private RealMatrix explainScores(ShapKernelExplainer ske, Prediction prediction, PredictionProvider model)
            throws InterruptedException, ExecutionException, TimeoutException {
        Saliency[] saliencies = ske.explainAsync(prediction, model)
                .get(TIMEOUT_SECONDS, TimeUnit.SECONDS)
                .getSaliencies();
        return saliencyToMatrix(saliencies)[0];
    }

    /**
     * Asserts that {@code explanations} has exactly one row per expected row and that each row matches
     * within {@code delta}.
     */
    private void assertScores(double[][] expectedRows, RealMatrix explanations, double delta) {
        // Checked explicitly: looping only over the returned rows would let a short result pass.
        Assertions.assertEquals(expectedRows.length, explanations.getRowDimension(), "unexpected number of saliencies");
        for (int j = 0; j < expectedRows.length; j++) {
            Assertions.assertArrayEquals(expectedRows[j], explanations.getRow(j), delta);
        }
    }

    /** Explains every row of {@code toExplainMatrix} with a fresh explainer built from {@code skConfig}. */
    private void shapTestCase(PredictionProvider model, ShapConfig skConfig, double[][] toExplainMatrix, double[][][] expected) throws InterruptedException, TimeoutException, ExecutionException {
        this.shapTestCase(model, new ShapKernelExplainer(skConfig), toExplainMatrix, expected);
    }

    /**
     * Explains every row of {@code toExplainMatrix} with {@code ske} and compares the SHAP scores with
     * {@code expected[row]} to within 1e-6.
     */
    private void shapTestCase(PredictionProvider model, ShapKernelExplainer ske, double[][] toExplainMatrix, double[][][] expected) throws InterruptedException, TimeoutException, ExecutionException {
        List<Prediction> predictions = this.predict(model, this.createPIFromMatrix(toExplainMatrix));
        for (int i = 0; i < predictions.size(); i++) {
            this.assertScores(expected[i], this.explainScores(ske, predictions.get(i), model), 1.0E-6);
        }
    }

    @Test
    void testNoVarianceOneOutput() throws InterruptedException, TimeoutException, ExecutionException {
        PredictionProvider model = TestUtils.getSumSkipModel(1);
        List<PredictionInput> background = this.createPIFromMatrix(this.backgroundNoVariance);
        ShapConfig skConfig = this.testConfig.withBackground(background).withNSamples(100).build();
        this.shapTestCase(model, skConfig, this.toExplainZeroVariance, this.zeroVarianceOneOutputSHAP);
    }

    @Test
    void testOneVarianceOneOutput() throws InterruptedException, TimeoutException, ExecutionException {
        PredictionProvider model = TestUtils.getSumSkipModel(1);
        List<PredictionInput> background = this.createPIFromMatrix(this.backgroundNoVariance);
        ShapConfig skConfig = this.testConfig.withBackground(background).withNSamples(100).build();
        this.shapTestCase(model, skConfig, this.toExplainOneVariance, this.oneVarianceOneOutputSHAP);
    }

    @Test
    void testMultiVarianceOneOutput() throws InterruptedException, TimeoutException, ExecutionException {
        PredictionProvider model = TestUtils.getSumSkipModel(1);
        List<PredictionInput> background = this.createPIFromMatrix(this.backgroundRaw);
        ShapConfig skConfig = this.testConfig.withBackground(background).withNSamples(35).build();
        this.shapTestCase(model, skConfig, this.toExplainRaw, this.multiVarianceOneOutputSHAP);
    }

    @Test
    void testMultiVarianceOneOutputLogit() throws InterruptedException, TimeoutException, ExecutionException {
        PredictionProvider model = TestUtils.getSumSkipModel(1);
        List<PredictionInput> background = this.createPIFromMatrix(this.backgroundLogit);
        ShapConfig skConfig = ShapConfig.builder().withBackground(background).withLink(ShapConfig.LinkType.LOGIT).withNSamples(100).withPC(this.pc).build();
        this.shapTestCase(model, skConfig, this.toExplainLogit, this.logitSHAP);
    }

    @Test
    void testNoVarianceMultiOutput() throws InterruptedException, TimeoutException, ExecutionException {
        PredictionProvider model = TestUtils.getSumSkipTwoOutputModel(1);
        List<PredictionInput> background = this.createPIFromMatrix(this.backgroundNoVariance);
        ShapConfig skConfig = this.testConfig.withBackground(background).build();
        this.shapTestCase(model, skConfig, this.toExplainZeroVariance, this.zeroVarianceMultiOutputSHAP);
    }

    @Test
    void testOneVarianceMultiOutput() throws InterruptedException, TimeoutException, ExecutionException {
        PredictionProvider model = TestUtils.getSumSkipTwoOutputModel(1);
        List<PredictionInput> background = this.createPIFromMatrix(this.backgroundNoVariance);
        ShapConfig skConfig = this.testConfig.withBackground(background).build();
        this.shapTestCase(model, skConfig, this.toExplainOneVariance, this.oneVarianceMultiOutputSHAP);
    }

    @Test
    void testMultiVarianceMultiOutput() throws InterruptedException, TimeoutException, ExecutionException {
        PredictionProvider model = TestUtils.getSumSkipTwoOutputModel(1);
        List<PredictionInput> background = this.createPIFromMatrix(this.backgroundRaw);
        ShapConfig skConfig = this.testConfig.withBackground(background).build();
        this.shapTestCase(model, skConfig, this.toExplainRaw, this.multiVarianceMultiOutputSHAP);
    }

    @Test
    void testLargeBackground() throws InterruptedException, TimeoutException, ExecutionException {
        List<PredictionInput> background = this.createPIFromMatrix(gridMatrix(100, 10));
        List<PredictionInput> toExplain = this.createPIFromMatrix(this.toExplainLargeBackground);
        PredictionProvider model = TestUtils.getSumSkipModel(1);
        ShapConfig skConfig = this.testConfig.withBackground(background).build();
        List<Prediction> predictions = this.predict(model, toExplain);
        ShapKernelExplainer ske = new ShapKernelExplainer(skConfig);
        for (int i = 0; i < predictions.size(); i++) {
            this.assertScores(this.largeBackgroundSHAP[i], this.explainScores(ske, predictions.get(i), model), 0.01);
        }
    }

    @Test
    void testParallel() throws InterruptedException, ExecutionException, TimeoutException {
        List<PredictionInput> background = this.createPIFromMatrix(gridMatrix(100, 10));
        List<PredictionInput> toExplain = this.createPIFromMatrix(this.toExplainLargeBackground);
        PredictionProvider model = TestUtils.getSumSkipModel(1);
        ShapConfig skConfig = this.testConfig.withBackground(background).build();
        List<Prediction> predictions = this.predict(model, toExplain);
        ShapKernelExplainer ske = new ShapKernelExplainer(skConfig);
        CompletableFuture<ShapResults> explanationsCF = ske.explainAsync(predictions.get(0), model);
        ForkJoinPool executor = ForkJoinPool.commonPool();
        // Waiting on the submitted task is essential. Otherwise the test returns before the assertion runs,
        // and an assertion failure inside the task is silently discarded.
        executor.submit(() -> {
            Saliency[] explanationSaliencies = explanationsCF.join().getSaliencies();
            RealMatrix explanations = this.saliencyToMatrix(explanationSaliencies)[0];
            Assertions.assertArrayEquals(this.largeBackgroundSHAP[0][0], explanations.getRow(0), 0.01);
        }).get(TIMEOUT_SECONDS, TimeUnit.SECONDS);
    }

    @Test
    void testTooLargeBackground() throws InterruptedException, TimeoutException, ExecutionException {
        // The background has 10 features while the explained input has only 5.
        List<PredictionInput> background = this.createPIFromMatrix(gridMatrix(10, 10));
        List<PredictionInput> toExplain = this.createPIFromMatrix(new double[][]{{0.0, 1.0, 2.0, 3.0, 4.0}});
        PredictionProvider model = TestUtils.getSumSkipModel(1);
        ShapConfig skConfig = this.testConfig.withBackground(background).build();
        Prediction p = this.predict(model, toExplain).get(0);
        ShapKernelExplainer ske = new ShapKernelExplainer(skConfig);
        Assertions.assertThrows(IllegalArgumentException.class, () -> ske.explainAsync(p, model));
    }

    @Test
    void testPredictionWrongSize() throws InterruptedException, TimeoutException, ExecutionException {
        List<PredictionInput> background = this.createPIFromMatrix(gridMatrix(5, 5));
        List<PredictionInput> toExplain = this.createPIFromMatrix(new double[][]{{0.0, 1.0, 2.0, 3.0, 4.0}});
        // The prediction comes from a two-output model, but SHAP is asked to explain it with a one-output model.
        PredictionProvider modelForPredictions = TestUtils.getSumSkipTwoOutputModel(1);
        PredictionProvider modelForShap = TestUtils.getSumSkipModel(1);
        ShapConfig skConfig = this.testConfig.withBackground(background).build();
        Prediction p = this.predict(modelForPredictions, toExplain).get(0);
        ShapKernelExplainer ske = new ShapKernelExplainer(skConfig);
        Assertions.assertThrows(ExecutionException.class, () -> ske.explainAsync(p, modelForShap).get());
    }

    @Test
    void testStateless() throws InterruptedException, TimeoutException, ExecutionException {
        PredictionProvider model = TestUtils.getSumSkipModel(1);
        ShapConfig skConfig1 = this.testConfig.withBackground(this.createPIFromMatrix(this.backgroundNoVariance)).withNSamples(100).build();
        ShapConfig skConfig2 = this.testConfig.withBackground(this.createPIFromMatrix(this.backgroundRaw)).withNSamples(35).build();
        ShapConfig skConfig3 = ShapConfig.builder().withBackground(this.createPIFromMatrix(this.backgroundLogit)).withLink(ShapConfig.LinkType.LOGIT).withNSamples(100).withPC(this.pc).build();
        // Reuse one explainer and swap its configuration repeatedly. Results must not depend on earlier runs.
        ShapKernelExplainer ske = new ShapKernelExplainer(skConfig1);
        for (int i = 0; i < 10; ++i) {
            this.shapTestCase(model, ske, this.toExplainOneVariance, this.oneVarianceOneOutputSHAP);
            ske.setConfig(skConfig2);
            this.shapTestCase(model, ske, this.toExplainRaw, this.multiVarianceOneOutputSHAP);
            ske.setConfig(skConfig3);
            this.shapTestCase(model, ske, this.toExplainLogit, this.logitSHAP);
            ske.setConfig(skConfig1);
        }
    }

    @ParameterizedTest
    @ValueSource(doubles = {0.001, 0.1, 0.25, 0.5})
    void testErrorBounds(double noise) throws InterruptedException, ExecutionException {
        int nFeatures = this.toExplainAllOnes[0].length;
        for (double interval : new double[]{0.95, 0.975, 0.99}) {
            // Slot test * nFeatures + j is 1 when feature j's interval (score +/- confidence) strictly contains 1.0.
            int[] testResults = new int[ERROR_BOUND_TESTS * nFeatures];
            for (int test = 0; test < ERROR_BOUND_TESTS; ++test) {
                PredictionProvider model = TestUtils.getNoisySumModel(this.pc.getRandom(), noise);
                ShapConfig skConfig = this.testConfig.withBackground(this.createPIFromMatrix(this.backgroundAllZeros)).withConfidence(interval).build();
                List<PredictionInput> toExplain = this.createPIFromMatrix(this.toExplainAllOnes);
                ShapKernelExplainer ske = new ShapKernelExplainer(skConfig);
                List<PredictionOutput> predictionOutputs = model.predictAsync(toExplain).get();
                Prediction p = new SimplePrediction(toExplain.get(0), predictionOutputs.get(0));
                Saliency[] saliencies = ske.explainAsync(p, model).get().getSaliencies();
                RealMatrix[] explanationsAndConfs = this.saliencyToMatrix(saliencies);
                RealMatrix explanations = explanationsAndConfs[0];
                RealMatrix confidence = explanationsAndConfs[1];
                // NOTE(review): the slot index ignores row i, which is only correct while the model has a
                // single output (one saliency row).
                for (int i = 0; i < explanations.getRowDimension(); ++i) {
                    for (int j = 0; j < explanations.getColumnDimension(); ++j) {
                        double conf = confidence.getEntry(i, j);
                        double exp = explanations.getEntry(i, j);
                        // 1.0 is the reference value checked for every feature: all-ones input against an
                        // all-zeros background, presumably each feature's true contribution to a sum model.
                        testResults[test * nFeatures + j] = (exp - conf < 1.0 && 1.0 < exp + conf) ? 1 : 0;
                    }
                }
            }
            // The empirical coverage rate should be close to the requested confidence level.
            double score = (double) Arrays.stream(testResults).sum() / testResults.length;
            Assertions.assertEquals(interval, score, 0.05);
        }
    }
}

