/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.shap;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.util.CombinatoricsUtils;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.shap.ShapConfig;
import org.kie.kogito.explainability.local.shap.ShapDataCarrier;
import org.kie.kogito.explainability.local.shap.ShapResults;
import org.kie.kogito.explainability.local.shap.ShapStatistics;
import org.kie.kogito.explainability.local.shap.ShapSyntheticDataSample;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.utils.MatrixUtilsExtensions;
import org.kie.kogito.explainability.utils.RandomChoice;
import org.kie.kogito.explainability.utils.WeightedLinearRegression;
import org.kie.kogito.explainability.utils.WeightedLinearRegressionResults;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Local explainer that estimates SHAP values with the kernel method.
 *
 * <p>For a single prediction it:
 * <ol>
 * <li>evaluates the model on the configured background data to obtain the mean ("null") output;</li>
 * <li>finds the features whose value varies across the background rows plus the explained input;</li>
 * <li>builds weighted boolean coalition masks over those varying features, enumerating complete subset
 * sizes while the sample budget allows and drawing the remaining masks at random;</li>
 * <li>evaluates the model on the synthetic data generated for each mask (generation happens in
 * {@code ShapSyntheticDataSample}); and</li>
 * <li>solves one weighted linear regression per model output. The coefficients become the feature
 * importances and the regression confidence intervals become their bounds.</li>
 * </ol>
 */
public class ShapKernelExplainer
implements LocalExplainer<ShapResults> {
    private static final Logger LOGGER = LoggerFactory.getLogger(ShapKernelExplainer.class);
    private ShapConfig config;

    /**
     * Creates an explainer with the given configuration.
     *
     * @param shapConfig background data, link function, sample budget, random source and executor to use
     */
    public ShapKernelExplainer(ShapConfig shapConfig) {
        this.config = shapConfig;
    }

    /**
     * Replaces the configuration used by subsequent explanations.
     *
     * @param shapConfig the new configuration
     */
    public void setConfig(ShapConfig shapConfig) {
        this.config = shapConfig;
    }

    /**
     * Builds the per-explanation state.
     *
     * <p>The state holds the background shape, the model's mean output over the background rows
     * ({@code fnull}), its linked form ({@code linkNull}) and the number of coalition samples to use.
     *
     * @param model the model being explained
     * @return a fresh carrier whose output-size, fnull and linkNull futures complete once the model has
     *         been evaluated on the background data
     */
    private ShapDataCarrier initialize(PredictionProvider model) {
        int[] shape = MatrixUtilsExtensions.getShape(this.config.getBackgroundMatrix());
        int rows = shape[0];
        int cols = shape[1];
        if (rows > 100) {
            LOGGER.debug("Warning: Background data sets larger than 100 samples might be slow!");
        }
        CompletableFuture<double[][]> modelNull = model.predictAsync(this.config.getBackground())
                .thenApply(MatrixUtilsExtensions::matrixFromPredictionOutput);
        CompletableFuture<Integer> outputSize = modelNull.thenApply(mn -> MatrixUtilsExtensions.getShape(mn)[1]);
        // Mean model output over the background rows: one entry per model output.
        CompletableFuture<double[]> fnull = modelNull.thenApply(mn -> MatrixUtilsExtensions.sum(
                MatrixUtilsExtensions.matrixMultiply(mn, 1.0 / rows), MatrixUtilsExtensions.Axis.ROW));
        CompletableFuture<double[][]> linkNull = fnull.thenApply(fn -> MatrixUtilsExtensions.rowVector(this.link(fn)));

        int numSamples = this.config.getNSamples().orElseGet(() -> 2048 + 2 * cols);
        // With few features there are only 2^cols - 2 non-trivial coalitions, so never request more.
        if (cols <= 30) {
            int maxSamples = (1 << cols) - 2;
            numSamples = Math.min(numSamples, maxSamples);
        }

        ShapDataCarrier sdc = new ShapDataCarrier();
        sdc.setRows(rows);
        sdc.setCols(cols);
        sdc.setOutputSize(outputSize);
        sdc.setModel(model);
        sdc.setFnull(fnull);
        sdc.setLinkNull(linkNull);
        sdc.setNumSamples(numSamples);
        return sdc;
    }

    /**
     * Applies the configured link function to a single value.
     *
     * @param x value to transform
     * @return {@code x} for the IDENTITY link, otherwise the logit {@code ln(x / (1 - x))}
     */
    private double link(double x) {
        if (this.config.getLink().equals(ShapConfig.LinkType.IDENTITY)) {
            return x;
        }
        return Math.log(x / (1.0 - x));
    }

    /**
     * Applies {@link #link(double)} element-wise.
     *
     * @param v values to transform
     * @return a new array holding the linked values
     */
    private double[] link(double[] v) {
        return Arrays.stream(v).map(this::link).toArray();
    }

    /**
     * Records which feature columns vary, i.e. which take more than one distinct value across the
     * background rows together with the explained input.
     *
     * <p>Only varying features take part in the coalition sampling.
     *
     * @param input the prediction input being explained
     * @param sdc   carrier receiving the varying column indices and their count
     */
    private void setVaryingFeatureGroups(PredictionInput input, ShapDataCarrier sdc) {
        double[][] background = this.config.getBackgroundMatrix();
        double[] inputVector = MatrixUtilsExtensions.matrixFromPredictionInput(input)[0];
        int rows = sdc.getRows();
        // Reused buffer: the background column followed by the input's value for that column.
        double[] columnFeatures = new double[rows + 1];
        List<Integer> varyingFeatureGroups = new ArrayList<>();
        for (int col = 0; col < sdc.getCols(); ++col) {
            System.arraycopy(MatrixUtilsExtensions.getCol(background, col), 0, columnFeatures, 0, rows);
            columnFeatures[rows] = inputVector[col];
            if (Arrays.stream(columnFeatures).distinct().count() > 1L) {
                varyingFeatureGroups.add(col);
            }
        }
        sdc.setVaryingFeatureGroups(varyingFeatureGroups);
        sdc.setNumVarying(varyingFeatureGroups.size());
    }

    /**
     * Scales a non-negative weight vector so that it sums to one.
     *
     * @param v the weights
     * @return the normalized weights, or {@code v} itself (same array) when its sum is zero
     */
    private double[] normalizeWeightVector(double[] v) {
        double[][] expanded = MatrixUtilsExtensions.rowVector(v);
        double sum = MatrixUtilsExtensions.sum(expanded, MatrixUtilsExtensions.Axis.COLUMN)[0];
        if (sum == 0.0) {
            return v;
        }
        return MatrixUtilsExtensions.matrixMultiply(expanded, 1.0 / sum)[0];
    }

    /**
     * Adds the coalition described by {@code combination} to the sample set.
     *
     * <p>If the same mask was already added, that sample's weight is incremented instead.
     *
     * @param pi          the explained input
     * @param combination positions into the varying-feature list (not raw column indices)
     * @param weight      initial weight of a newly created sample
     * @param inverse     when true, add the complement of the combination within the varying features
     * @param fixed       whether the sample belongs to a completely enumerated subset size
     * @param sdc         carrier holding the samples and the mask-key index
     */
    private void addSample(PredictionInput pi, List<Integer> combination, double weight, boolean inverse, boolean fixed, ShapDataCarrier sdc) {
        boolean[] mask = new boolean[sdc.getCols()];
        if (inverse) {
            for (int i = 0; i < sdc.getNumVarying(); ++i) {
                mask[sdc.getVaryingFeatureGroups(i)] = true;
            }
        }
        for (Integer member : combination) {
            mask[sdc.getVaryingFeatureGroups(member)] = !inverse;
        }
        // Keys are only hashes, so confirm the stored mask really matches. On a collision with a
        // different mask, probe the next key (open addressing) instead of merging unrelated samples.
        int key = this.hashMask(mask);
        while (sdc.getMasksUsed().containsKey(key)) {
            ShapSyntheticDataSample previousSample = sdc.getSamplesAdded(sdc.getMasksUsed(key));
            if (Arrays.equals(previousSample.getMask(), mask)) {
                previousSample.incrementWeight();
                return;
            }
            ++key;
        }
        ShapSyntheticDataSample sample = new ShapSyntheticDataSample(pi, mask, this.config.getBackgroundMatrix(), weight, fixed);
        sdc.addMask(key, sdc.getSamplesAddedSize());
        sdc.addSample(sample);
    }

    /**
     * Computes the lookup key of a mask.
     *
     * <p>Masks of up to 32 entries are encoded exactly as a big-endian bit pattern, giving the same
     * value the previous power-of-two sum produced for up to 31 entries. Longer masks cannot fit in
     * an {@code int}; they use {@link Arrays#hashCode(boolean[])}, and collisions are resolved by the
     * caller.
     *
     * @param mask the coalition mask
     * @return its lookup key
     */
    private int hashMask(boolean[] mask) {
        if (mask.length > Integer.SIZE) {
            return Arrays.hashCode(mask);
        }
        int hash = 0;
        for (boolean bit : mask) {
            hash = (hash << 1) | (bit ? 1 : 0);
        }
        return hash;
    }

    /**
     * Converts an importance matrix into one saliency per output.
     *
     * @param m  importances, row {@code i} for output {@code i} and column {@code j} for feature {@code j}
     * @param pi input supplying the features
     * @param po output supplying the outputs
     * @return one {@link Saliency} per row of {@code m}
     */
    public static Saliency[] saliencyFromMatrix(double[][] m, PredictionInput pi, PredictionOutput po) {
        int[] shape = MatrixUtilsExtensions.getShape(m);
        Saliency[] saliencies = new Saliency[shape[0]];
        for (int i = 0; i < shape[0]; ++i) {
            List<FeatureImportance> fis = new ArrayList<>();
            for (int j = 0; j < shape[1]; ++j) {
                fis.add(new FeatureImportance(pi.getFeatures().get(j), m[i][j]));
            }
            saliencies[i] = new Saliency(po.getOutputs().get(i), fis);
        }
        return saliencies;
    }

    /**
     * Converts an importance matrix and matching confidence bounds into one saliency per output.
     *
     * @param m      importances, row {@code i} for output {@code i} and column {@code j} for feature {@code j}
     * @param bounds confidence bounds with the same layout as {@code m}
     * @param pi     input supplying the features
     * @param po     output supplying the outputs
     * @return one {@link Saliency} per row of {@code m}
     */
    public static Saliency[] saliencyFromMatrix(double[][] m, double[][] bounds, PredictionInput pi, PredictionOutput po) {
        int[] shape = MatrixUtilsExtensions.getShape(m);
        Saliency[] saliencies = new Saliency[shape[0]];
        for (int i = 0; i < shape[0]; ++i) {
            List<FeatureImportance> fis = new ArrayList<>();
            for (int j = 0; j < shape[1]; ++j) {
                fis.add(new FeatureImportance(pi.getFeatures().get(j), m[i][j], bounds[i][j]));
            }
            saliencies[i] = new Saliency(po.getOutputs().get(i), fis);
        }
        return saliencies;
    }

    /**
     * Runs a full explanation of one prediction.
     *
     * @param prediction the prediction to explain
     * @param model      the model that produced it
     * @return future completing with the saliencies and the mean background output
     * @throws IllegalArgumentException (synchronously) when the input feature count differs from the
     *         background's; the returned future fails with it when the output counts differ
     */
    private CompletableFuture<ShapResults> explain(Prediction prediction, PredictionProvider model) {
        ShapDataCarrier sdc = this.initialize(model);
        sdc.setSamplesAdded(new ArrayList<>());
        PredictionInput pi = prediction.getInput();
        PredictionOutput po = prediction.getOutput();
        if (pi.getFeatures().size() != sdc.getCols()) {
            throw new IllegalArgumentException(String.format("Prediction input feature count (%d) does not match background data feature count (%d)", pi.getFeatures().size(), sdc.getCols()));
        }
        int cols = sdc.getCols();
        // Validates the output count and yields an all-zero importance matrix. Every branch below
        // chains on it so that a size mismatch fails the returned future.
        CompletableFuture<double[][]> output = sdc.getOutputSize().thenApply(os -> {
            if (po.getOutputs().size() != os) {
                throw new IllegalArgumentException(String.format("Prediction output size (%d) does not match background data output size (%d)", po.getOutputs().size(), os));
            }
            return new double[os][cols];
        });
        double[][] poMatrix = MatrixUtilsExtensions.matrixFromPredictionOutput(po);
        this.setVaryingFeatureGroups(pi, sdc);

        if (sdc.getNumVarying() == 0) {
            // No feature varies, so no feature can have an effect.
            return output.thenApply(o -> saliencyFromMatrix(o, pi, po))
                    .thenCombine(sdc.getFnull(), ShapResults::new);
        }
        if (sdc.getNumVarying() == 1) {
            // The single varying feature carries the whole effect, link(f(x)) - link(fnull). The link
            // is applied to the model output too, as in solve(), so LOGIT links are handled correctly.
            double[][] linkedPo = MatrixUtilsExtensions.rowVector(this.link(poMatrix[0]));
            CompletableFuture<double[]> diff = sdc.getLinkNull()
                    .thenApply(ln -> MatrixUtilsExtensions.matrixDifference(linkedPo, ln)[0]);
            return output.thenCompose(o -> explainSingleVaryingFeature(diff, sdc, cols, pi, po))
                    .thenCombine(sdc.getFnull(), ShapResults::new);
        }

        ShapStatistics shapStats = this.computeSubsetStatistics(sdc);
        this.initializeWeights(shapStats, sdc);
        this.addCompleteSubsets(shapStats, pi, sdc);
        this.renormalizeWeights(shapStats);
        this.addNonCompleteSubsets(shapStats, pi, sdc);
        CompletableFuture<double[][]> expectations = this.runSyntheticData(sdc);
        return output.thenCompose(o -> this.solveSystem(expectations, poMatrix[0], sdc)
                        .thenApply(wo -> saliencyFromMatrix(wo[0], wo[1], pi, po)))
                .thenCombine(sdc.getFnull(), ShapResults::new);
    }

    /**
     * Builds saliencies in which the only varying feature receives the whole output difference.
     *
     * @param diff future of the per-output difference to attribute
     * @param sdc  carrier holding the varying feature and the output size
     * @param cols number of features
     * @param pi   the explained input
     * @param po   the explained output
     * @return future of one saliency per output
     */
    private static CompletableFuture<Saliency[]> explainSingleVaryingFeature(CompletableFuture<double[]> diff, ShapDataCarrier sdc, int cols, PredictionInput pi, PredictionOutput po) {
        return diff.thenCombine(sdc.getOutputSize(), (df, os) -> {
            int featureIdx = sdc.getVaryingFeatureGroups(0);
            double[][] out = new double[os][cols];
            for (int i = 0; i < os; ++i) {
                out[i][featureIdx] = df[i];
            }
            return saliencyFromMatrix(out, pi, po);
        });
    }

    /**
     * Counts the coalitions available at each subset size.
     *
     * <p>Only sizes {@code 1..ceil((M - 1) / 2)} are considered, with {@code M} the number of varying
     * features. Larger sizes are reached as complements of the "paired" sizes
     * {@code 1..largestPairedSubsetSize}.
     *
     * @param sdc carrier providing the varying-feature count and the sample budget
     * @return statistics with per-size counts clamped to {@link Integer#MAX_VALUE}
     */
    private ShapStatistics computeSubsetStatistics(ShapDataCarrier sdc) {
        int numVarying = sdc.getNumVarying();
        int numSubsetSizes = (int) Math.ceil((numVarying - 1) / 2.0);
        int largestPairedSubsetSize = numVarying % 2 == 1 ? numSubsetSizes : numSubsetSizes - 1;
        int[] numSubsetsAtSize = new int[numSubsetSizes + 1];
        for (int i = 1; i <= numSubsetSizes; ++i) {
            long count;
            try {
                count = CombinatoricsUtils.binomialCoefficient(numVarying, i);
            } catch (MathArithmeticException e) {
                // Does not even fit in a long. It can never be enumerated within the sample budget.
                count = Long.MAX_VALUE;
            }
            // Clamp instead of truncating: a plain (int) cast can wrap to a negative count, which
            // would pass the budget check in addCompleteSubsets and start a huge enumeration.
            numSubsetsAtSize[i] = (int) Math.min(count, Integer.MAX_VALUE);
        }
        return new ShapStatistics(numSubsetSizes, largestPairedSubsetSize, numSubsetsAtSize, sdc.getNumSamples());
    }

    /**
     * Computes the normalized weight of each subset size.
     *
     * <p>Size {@code s} gets {@code (M - 1) / (s * (M - s))}, doubled for paired sizes because those
     * also account for their complements. Index 0 is unused and stays 0.
     *
     * @param shapStats statistics receiving the weights and a copy used as remaining weights
     * @param sdc       carrier providing the varying-feature count
     */
    private void initializeWeights(ShapStatistics shapStats, ShapDataCarrier sdc) {
        double[] rawWeights = new double[shapStats.getNumSubsetSizes() + 1];
        for (int subsetSize = 1; subsetSize <= shapStats.getNumSubsetSizes(); ++subsetSize) {
            double weight = ((double) sdc.getNumVarying() - 1.0) / (double) (subsetSize * (sdc.getNumVarying() - subsetSize));
            if (subsetSize <= shapStats.getLargestPairedSubsetSize()) {
                weight *= 2.0;
            }
            rawWeights[subsetSize] = weight;
        }
        double[] weightOfSubsetSize = this.normalizeWeightVector(rawWeights);
        shapStats.setWeightOfSubsetSize(weightOfSubsetSize);
        shapStats.setRemainingWeights(Arrays.copyOf(weightOfSubsetSize, weightOfSubsetSize.length));
    }

    /**
     * Enumerates every coalition of each subset size, smallest first, while the remaining sample
     * budget can cover the whole size.
     *
     * <p>Enumeration stops at the first size the budget cannot cover. Paired sizes also add each
     * coalition's complement.
     *
     * @param shapStats statistics tracking the budget, remaining weights and number of full sizes
     * @param pi        the explained input
     * @param sdc       carrier receiving the samples; its mask index is reset here
     */
    private void addCompleteSubsets(ShapStatistics shapStats, PredictionInput pi, ShapDataCarrier sdc) {
        sdc.setMasksUsed(new HashMap<>());
        for (int subsetSize = 1; subsetSize <= shapStats.getNumSubsetSizes(); ++subsetSize) {
            boolean paired = subsetSize <= shapStats.getLargestPairedSubsetSize();
            // long arithmetic: doubling a count near Integer.MAX_VALUE must not wrap negative.
            long numSubsets = (long) shapStats.getNumSubsetsAtSize()[subsetSize] * (paired ? 2L : 1L);
            double samplingWeight = shapStats.getRemainingWeights()[subsetSize];
            if (!((double) shapStats.getNumSamplesRemaining() * samplingWeight >= (double) numSubsets)) {
                break;
            }
            shapStats.incrementNumFullSubsets();
            // Safe narrowing: numSubsets <= remaining budget here, since normalized weights are <= 1.
            shapStats.decreaseNumSamplesRemainingBy((int) numSubsets);
            double[] remainingWeights = shapStats.getRemainingWeights();
            remainingWeights[subsetSize] = 0.0;
            shapStats.setRemainingWeights(this.normalizeWeightVector(remainingWeights));
            Iterator<int[]> combinations = CombinatoricsUtils.combinationsIterator(sdc.getNumVarying(), subsetSize);
            double individualWeight = shapStats.getWeightOfSubsetSize()[subsetSize] / (double) numSubsets;
            while (combinations.hasNext()) {
                List<Integer> combination = Arrays.stream(combinations.next()).boxed().collect(Collectors.toList());
                this.addSample(pi, combination, individualWeight, false, true, sdc);
                if (paired) {
                    this.addSample(pi, combination, individualWeight, true, true, sdc);
                }
            }
        }
    }

    /**
     * Computes the sampling distribution over the subset sizes that were not fully enumerated.
     *
     * <p>Paired sizes ({@code 1..largestPairedSubsetSize}, the same range initializeWeights doubles)
     * are halved, because addNonCompleteSubsets draws two masks (coalition and complement) per draw.
     *
     * @param shapStats statistics receiving the final remaining weights
     */
    private void renormalizeWeights(ShapStatistics shapStats) {
        double[] weightOfSubsetSize = shapStats.getWeightOfSubsetSize();
        double[] remainingWeights = Arrays.copyOf(weightOfSubsetSize, weightOfSubsetSize.length);
        // Index == subset size, so the paired range is 1..largestPaired inclusive.
        int lastPaired = Math.min(shapStats.getLargestPairedSubsetSize(), remainingWeights.length - 1);
        for (int subsetSize = 1; subsetSize <= lastPaired; ++subsetSize) {
            remainingWeights[subsetSize] /= 2.0;
        }
        double[] nonFullRemainingWeights = Arrays.copyOfRange(remainingWeights, shapStats.getNumFullSubsets() + 1, shapStats.getNumSubsetSizes() + 1);
        shapStats.setFinalRemainingWeights(this.normalizeWeightVector(nonFullRemainingWeights));
    }

    /**
     * Spends the remaining sample budget on random coalitions from the subset sizes that were not
     * fully enumerated, then normalizes their weights.
     *
     * @param shapStats statistics tracking the budget and weights
     * @param pi        the explained input
     * @param sdc       carrier receiving the samples
     */
    private void addNonCompleteSubsets(ShapStatistics shapStats, PredictionInput pi, ShapDataCarrier sdc) {
        if (shapStats.getNumFullSubsets() >= shapStats.getNumSubsetSizes()) {
            return;
        }
        List<Integer> subsetSizesRemaining = IntStream.range(shapStats.getNumFullSubsets() + 1, shapStats.getNumSubsetSizes() + 1)
                .boxed().collect(Collectors.toList());
        List<Double> subsetSizeWeights = Arrays.stream(shapStats.getFinalRemainingWeights()).boxed().collect(Collectors.toList());
        RandomChoice<Integer> subsetSampler = new RandomChoice<>(subsetSizesRemaining, subsetSizeWeights);
        // Each loop iteration consumes at least one unit of budget, so 4x draws is always enough.
        List<Integer> sizeSamples = subsetSampler.sample(shapStats.getNumSamplesRemaining() * 4, this.config.getPC().getRandom());
        List<Integer> featureOrder = IntStream.range(0, sdc.getNumVarying()).boxed().collect(Collectors.toList());
        int sampleIdx = 0;
        while (shapStats.getNumSamplesRemaining() > 0) {
            int subsetSize = sizeSamples.get(sampleIdx);
            ++sampleIdx;
            // Use the configured random source so seeded explanations are reproducible.
            Collections.shuffle(featureOrder, this.config.getPC().getRandom());
            List<Integer> maskIdxs = featureOrder.subList(0, subsetSize);
            this.addSample(pi, maskIdxs, 1.0, false, false, sdc);
            shapStats.decreaseNumSamplesRemainingBy(1);
            if (shapStats.getNumSamplesRemaining() > 0 && subsetSize <= shapStats.getLargestPairedSubsetSize()) {
                this.addSample(pi, maskIdxs, 1.0, true, false, sdc);
                shapStats.decreaseNumSamplesRemainingBy(1);
            }
        }
        this.normalizeSampleWeights(shapStats, sdc);
    }

    /**
     * Rescales the randomly drawn (non-fixed) samples so that their weights add up to the total weight
     * of the subset sizes that were not fully enumerated.
     *
     * @param shapStats statistics providing the subset-size weights
     * @param sdc       carrier holding the samples
     */
    private void normalizeSampleWeights(ShapStatistics shapStats, ShapDataCarrier sdc) {
        double nonFullWeight = 0.0;
        for (int size = shapStats.getNumFullSubsets() + 1; size <= shapStats.getNumSubsetSizes(); ++size) {
            nonFullWeight += shapStats.getWeightOfSubsetSize()[size];
        }
        double nonFixedWeight = 0.0;
        for (int i = 0; i < sdc.getSamplesAddedSize(); ++i) {
            ShapSyntheticDataSample sample = sdc.getSamplesAdded(i);
            if (!sample.isFixed()) {
                nonFixedWeight += sample.getWeight();
            }
        }
        if (nonFixedWeight == 0.0) {
            return;
        }
        for (int i = 0; i < sdc.getSamplesAddedSize(); ++i) {
            ShapSyntheticDataSample sample = sdc.getSamplesAdded(i);
            if (!sample.isFixed()) {
                sample.setWeight(sample.getWeight() * nonFullWeight / nonFixedWeight);
            }
        }
    }

    /**
     * Evaluates the model on every sample's synthetic data.
     *
     * <p>For each sample it averages the predictions over the background rows, applies the link and
     * subtracts {@code linkNull}.
     *
     * @param sdc carrier holding the samples and the model
     * @return future of a matrix with one row per sample and one column per model output
     */
    private CompletableFuture<double[][]> runSyntheticData(ShapDataCarrier sdc) {
        return sdc.getLinkNull().thenCompose(ln -> sdc.getOutputSize().thenCompose(os -> {
            int numSamples = sdc.getSamplesAddedSize();
            int numOutputs = os;
            List<CompletableFuture<double[]>> slices = new ArrayList<>(numSamples);
            for (int i = 0; i < numSamples; ++i) {
                List<PredictionInput> pis = sdc.getSamplesAdded(i).getSyntheticData();
                slices.add(sdc.getModel().predictAsync(pis)
                        .thenApply(MatrixUtilsExtensions::matrixFromPredictionOutput)
                        .thenApply(posMatrix -> MatrixUtilsExtensions.sum(
                                MatrixUtilsExtensions.matrixMultiply(posMatrix, 1.0 / (double) sdc.getRows()),
                                MatrixUtilsExtensions.Axis.ROW))
                        .thenApply(means -> this.link(means))
                        .thenApply(x -> MatrixUtilsExtensions.matrixDifference(MatrixUtilsExtensions.rowVector(x), ln)[0]));
            }
            CompletableFuture<double[][]> expectations = CompletableFuture.supplyAsync(
                    () -> new double[numSamples][numOutputs], this.config.getExecutor());
            for (int i = 0; i < numSamples; ++i) {
                int idx = i;
                CompletableFuture<double[]> slice = slices.get(i);
                expectations = expectations.thenCompose(e -> slice.thenApply(s -> {
                    e[idx] = s;
                    return e;
                }));
            }
            return expectations;
        }));
    }

    /**
     * Solves the weighted regression for one model output.
     *
     * <p>The SHAP values must sum to the output change, so feature {@code dropIdx} is eliminated from
     * the regression. Its contribution is moved onto the response, and it receives the remainder
     * afterwards (see runWLRR).
     *
     * @param expectations per-sample linked expectations minus linkNull
     * @param output       index of the model output to solve for
     * @param poMatrix     the explained prediction's raw outputs
     * @param fnull        mean background output per model output
     * @param dropIdx      column index of the feature eliminated from the regression
     * @param sdc          carrier holding the samples
     * @return {@code [shapValues, bounds]}, each with one entry per feature
     */
    private double[][] solve(double[][] expectations, int output, double[] poMatrix, double[] fnull, int dropIdx, ShapDataCarrier sdc) {
        int numSamples = sdc.getSamplesAddedSize();
        int cols = sdc.getCols();
        double[][] xs = new double[numSamples][cols];
        double[] ws = new double[numSamples];
        double[] ys = new double[numSamples];
        for (int i = 0; i < numSamples; ++i) {
            ShapSyntheticDataSample sample = sdc.getSamplesAdded(i);
            boolean[] mask = sample.getMask();
            for (int j = 0; j < cols; ++j) {
                xs[i][j] = mask[j] ? 1.0 : 0.0;
            }
            ys[i] = expectations[i][output];
            ws[i] = sample.getWeight();
        }
        double outputChange = this.link(poMatrix[output]) - this.link(fnull[output]);
        double[][] dropMask = MatrixUtilsExtensions.rowVector(MatrixUtilsExtensions.getCol(xs, dropIdx));
        double[][] dropEffect = MatrixUtilsExtensions.matrixMultiply(dropMask, outputChange);
        double[] adjY = MatrixUtilsExtensions.matrixDifference(MatrixUtilsExtensions.rowVector(ys), dropEffect)[0];
        List<Integer> included = sdc.getVaryingFeatureGroups().stream()
                .filter(v -> v != dropIdx)
                .collect(Collectors.toList());
        double[][] includeMask = MatrixUtilsExtensions.transpose(MatrixUtilsExtensions.getCols(xs, included));
        double[][] maskDiff = MatrixUtilsExtensions.transpose(MatrixUtilsExtensions.matrixRowDifference(includeMask, dropMask[0]));
        return this.runWLRR(maskDiff, adjY, ws, outputChange, dropIdx, sdc);
    }

    /**
     * Solves every model output asynchronously on the configured executor and gathers the results.
     *
     * @param expectations future of the per-sample expectations
     * @param poMatrix     the explained prediction's raw outputs
     * @param sdc          carrier holding the samples
     * @return future of {@code [shapValues, bounds]}, each indexed {@code [output][feature]}
     */
    private CompletableFuture<double[][][]> solveSystem(CompletableFuture<double[][]> expectations, double[] poMatrix, ShapDataCarrier sdc) {
        int dropIdx = sdc.getVaryingFeatureGroups(sdc.getVaryingFeatureGroups().size() - 1);
        return expectations.thenCompose(exps -> sdc.getFnull().thenCompose(fn -> sdc.getOutputSize().thenCompose(os -> {
            int numOutputs = os;
            List<CompletableFuture<double[][]>> slices = new ArrayList<>(numOutputs);
            for (int output = 0; output < numOutputs; ++output) {
                int finalOutput = output;
                slices.add(CompletableFuture.supplyAsync(
                        () -> this.solve(exps, finalOutput, poMatrix, fn, dropIdx, sdc), this.config.getExecutor()));
            }
            CompletableFuture<double[][][]> shapVals = CompletableFuture.supplyAsync(
                    () -> new double[2][numOutputs][sdc.getCols()], this.config.getExecutor());
            for (int output = 0; output < numOutputs; ++output) {
                int idx = output;
                CompletableFuture<double[][]> slice = slices.get(output);
                shapVals = shapVals.thenCompose(e -> slice.thenApply(s -> {
                    e[0][idx] = s[0];
                    e[1][idx] = s[1];
                    return e;
                }));
            }
            return shapVals;
        })));
    }

    /**
     * Fits the weighted regression for one output and maps its coefficients back to feature columns.
     *
     * <p>The eliminated feature {@code dropIdx} receives {@code outputChange} minus the sum of the
     * other coefficients. Its bound is the root-sum-square of the other bounds.
     *
     * @param maskDiff     regressors (masks of the included features minus the dropped feature's mask)
     * @param adjY         response adjusted for the dropped feature
     * @param ws           sample weights
     * @param outputChange total change to attribute
     * @param dropIdx      column index of the eliminated feature
     * @param sdc          carrier providing the varying features
     * @return {@code [shapValues, bounds]}, each with one entry per feature
     */
    private double[][] runWLRR(double[][] maskDiff, double[] adjY, double[] ws, double outputChange, int dropIdx, ShapDataCarrier sdc) {
        WeightedLinearRegressionResults wlrr = WeightedLinearRegression.fit(maskDiff, adjY, ws, false, this.config.getPC().getRandom());
        double[] coeffs = wlrr.getCoefficients();
        double[] bounds = wlrr.getConf(1.0 - this.config.getConfidence());
        int usedCoefs = 0;
        double[] shapSlice = new double[sdc.getCols()];
        double[] boundsReg = new double[sdc.getCols()];
        for (int i = 0; i < sdc.getVaryingFeatureGroups().size(); ++i) {
            int idx = sdc.getVaryingFeatureGroups(i);
            if (idx == dropIdx) {
                continue;
            }
            shapSlice[idx] = coeffs[usedCoefs];
            boundsReg[idx] = bounds[usedCoefs];
            ++usedCoefs;
        }
        shapSlice[dropIdx] = outputChange - Arrays.stream(coeffs).sum();
        boundsReg[dropIdx] = Math.sqrt(Arrays.stream(bounds).map(x -> x * x).sum());
        double[][] wlrrOutput = new double[2][];
        wlrrOutput[0] = shapSlice;
        wlrrOutput[1] = boundsReg;
        return wlrrOutput;
    }

    @Override
    public CompletableFuture<ShapResults> explainAsync(Prediction prediction, PredictionProvider model) {
        return this.explainAsync(prediction, model, (Consumer<ShapResults>) null);
    }

    /**
     * {@inheritDoc}
     *
     * <p>This explainer produces no intermediate results, so {@code intermediateResultsConsumer} is
     * ignored.
     */
    @Override
    public CompletableFuture<ShapResults> explainAsync(Prediction prediction, PredictionProvider model, Consumer<ShapResults> intermediateResultsConsumer) {
        return this.explain(prediction, model);
    }
}

