/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.shap;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.linear.AnyMatrix;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.CombinatoricsUtils;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.shap.ShapConfig;
import org.kie.kogito.explainability.local.shap.ShapDataCarrier;
import org.kie.kogito.explainability.local.shap.ShapResults;
import org.kie.kogito.explainability.local.shap.ShapStatistics;
import org.kie.kogito.explainability.local.shap.ShapSyntheticDataSample;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.utils.LarsPath;
import org.kie.kogito.explainability.utils.LassoLarsIC;
import org.kie.kogito.explainability.utils.MatrixUtilsExtensions;
import org.kie.kogito.explainability.utils.RandomChoice;
import org.kie.kogito.explainability.utils.WeightedLinearRegression;
import org.kie.kogito.explainability.utils.WeightedLinearRegressionResults;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Local explainer that estimates Shapley values with a Kernel SHAP style procedure.
 *
 * <p>For one prediction it enumerates or samples coalition masks over the columns whose value
 * differs between the (one-hot encoded) input and the background data. It evaluates the model
 * on synthetic data built from the encoded input, each mask and the background matrix. It then
 * fits a weighted linear regression from masks to link-transformed expected outputs.
 *
 * <p>The result holds one value and one confidence bound per feature and model output,
 * expressed in the space selected by {@link ShapConfig#getLink()}. The link-transformed mean
 * background output is used as the base value.
 *
 * <p>NOTE(review): the configuration is read lazily inside asynchronous stages, so calling
 * {@link #setConfig(ShapConfig)} while an explanation is in flight may affect that explanation.
 */
public class ShapKernelExplainer implements LocalExplainer<ShapResults> {
    private static final Logger LOGGER = LoggerFactory.getLogger(ShapKernelExplainer.class);

    /** Background data sets with more rows than this are logged as potentially slow. */
    private static final int LARGE_BACKGROUND_ROWS = 100;

    /** Fixed part of the default sample budget; {@code 2 * cols} is added on top of it. */
    private static final int DEFAULT_BASE_SAMPLES = 2048;

    /** Up to this many columns the sample budget is capped at the number of non-trivial coalitions. */
    private static final int MAX_ENUMERABLE_COLUMNS = 30;

    /** Masks up to this length are hashed by their exact binary value, which fits in a non-negative int. */
    private static final int MAX_EXACT_MASK_BITS = 31;

    /** Under AUTO regularization, regularize when fewer than this fraction of all coalitions was sampled. */
    private static final double AUTO_REGULARIZATION_FRACTION = 0.2;

    private ShapConfig config;

    /**
     * Creates an explainer.
     *
     * @param shapConfig configuration providing background data, link, sampling and regularization settings
     */
    public ShapKernelExplainer(ShapConfig shapConfig) {
        this.config = shapConfig;
    }

    /**
     * Replaces the configuration used by subsequent explanations.
     *
     * @param shapConfig the new configuration
     */
    public void setConfig(ShapConfig shapConfig) {
        this.config = shapConfig;
    }

    /**
     * Builds the per-explanation data carrier.
     *
     * <p>The carrier holds the background dimensions, the sample budget and the asynchronously
     * computed model output on the background data.
     *
     * @param model the model being explained
     * @return a carrier whose output size, mean background output ({@code fnull}) and its link
     *         transform ({@code linkNull}) complete once the background predictions are available
     */
    private ShapDataCarrier initialize(PredictionProvider model) {
        RealMatrix background = this.config.getBackgroundMatrix();
        int rows = background.getRowDimension();
        int cols = background.getColumnDimension();
        if (rows > LARGE_BACKGROUND_ROWS) {
            LOGGER.debug("Warning: Background data sets larger than 100 samples might be slow!");
        }
        CompletableFuture<RealMatrix> modelNull = model.predictAsync(this.config.getBackground())
                .thenApply(MatrixUtilsExtensions::matrixFromPredictionOutput);
        CompletableFuture<Integer> outputSize = modelNull.thenApply(RealMatrix::getColumnDimension);
        // Average of the background predictions, i.e. the model output "with no feature known".
        CompletableFuture<RealVector> fnull = modelNull.thenApply(mn -> MatrixUtilsExtensions.rowSum(mn).mapDivide(rows));
        CompletableFuture<RealVector> linkNull = fnull.thenApply(this::link);

        int numSamples = this.config.getNSamples().orElseGet(() -> DEFAULT_BASE_SAMPLES + 2 * cols);
        if (cols <= MAX_ENUMERABLE_COLUMNS) {
            // Only 2^cols - 2 coalitions exist besides the empty and the full one.
            numSamples = Math.min(numSamples, (1 << cols) - 2);
        }

        ShapDataCarrier sdc = new ShapDataCarrier();
        sdc.setRows(rows);
        sdc.setCols(cols);
        sdc.setOutputSize(outputSize);
        sdc.setModel(model);
        sdc.setFnull(fnull);
        sdc.setLinkNull(linkNull);
        sdc.setNumSamples(numSamples);
        return sdc;
    }

    /**
     * Applies the configured link function to a scalar.
     *
     * @param x value to transform
     * @return {@code x} for the IDENTITY link, otherwise the logit {@code log(x / (1 - x))}
     */
    private double link(double x) {
        if (this.config.getLink().equals(ShapConfig.LinkType.IDENTITY)) {
            return x;
        }
        return Math.log(x / (1.0 - x));
    }

    /**
     * Applies the configured link function element-wise.
     *
     * @param v vector to transform
     * @return a new vector with {@link #link(double)} applied to each entry
     */
    private RealVector link(RealVector v) {
        return v.map(this::link);
    }

    /**
     * Stores in {@code sdc} the columns whose value is not constant across the background rows
     * plus the encoded input; only these columns take part in coalitions.
     *
     * @param input the prediction input being explained
     * @param sdc   carrier receiving the varying column indices and their count
     */
    private void setVaryingFeatureGroups(PredictionInput input, ShapDataCarrier sdc) {
        RealMatrix background = this.config.getBackgroundMatrix();
        RealVector inputVector = MatrixUtilsExtensions.vectorFromPredictionInput(this.config.getOneHotter().oneHotEncode(input, true));
        int rows = sdc.getRows();
        ArrayList<Integer> varyingFeatureGroups = new ArrayList<>();
        // Scratch vector: the background column followed by the input value for that column.
        RealVector columnValues = MatrixUtils.createRealVector(new double[rows + 1]);
        for (int col = 0; col < sdc.getCols(); ++col) {
            columnValues.setSubVector(0, background.getColumnVector(col));
            columnValues.setEntry(rows, inputVector.getEntry(col));
            if (Arrays.stream(columnValues.toArray()).distinct().count() > 1L) {
                varyingFeatureGroups.add(col);
            }
        }
        sdc.setVaryingFeatureGroups(varyingFeatureGroups);
        sdc.setNumVarying(varyingFeatureGroups.size());
    }

    /**
     * Scales a weight vector so that its entries sum to one.
     *
     * @param v non-negative weights
     * @return the normalized vector, or {@code v} itself when its sum is zero (nothing to normalize)
     */
    private RealVector normalizeWeightVector(RealVector v) {
        try {
            double total = MatrixUtilsExtensions.sum(v);
            if (total == 0.0) {
                return v;
            }
            return v.mapDivide(total);
        } catch (MathArithmeticException ignored) {
            // Best effort: fall back to the un-normalized weights.
            return v;
        }
    }

    /**
     * Adds a synthetic sample for a coalition, or bumps the weight of an identical existing one.
     *
     * @param pi          the prediction input being explained
     * @param combination indices into the varying-feature list that form the coalition
     * @param weight      kernel weight of a new sample
     * @param inverse     if true, the coalition is the complement of {@code combination} within the varying features
     * @param fixed       whether the sample comes from full subset enumeration (true) or random sampling (false)
     * @param sdc         carrier holding the samples and the mask index
     * @return true if a new sample was added, false if an identical mask already existed
     */
    private boolean addSample(PredictionInput pi, List<Integer> combination, double weight, boolean inverse, boolean fixed, ShapDataCarrier sdc) {
        boolean[] mask = new boolean[sdc.getCols()];
        if (inverse) {
            for (int v = 0; v < sdc.getNumVarying(); ++v) {
                mask[sdc.getVaryingFeatureGroups(v)] = true;
            }
        }
        for (int varyingIdx : combination) {
            mask[sdc.getVaryingFeatureGroups(varyingIdx)] = !inverse;
        }
        int maskHash = this.hashMask(mask);
        boolean hashKnown = sdc.getMasksUsed().containsKey(maskHash);
        if (hashKnown) {
            ShapSyntheticDataSample previousSample = sdc.getSamplesAdded(sdc.getMasksUsed(maskHash));
            // Only merge when the masks really are equal; long masks use a non-injective hash.
            if (Arrays.equals(previousSample.getMask(), mask)) {
                previousSample.incrementWeight();
                return false;
            }
        }
        ShapSyntheticDataSample sample = new ShapSyntheticDataSample(this.config.getOneHotter().oneHotEncode(pi, true), mask, this.config.getBackgroundMatrix(), weight, fixed);
        if (!hashKnown) {
            sdc.addMask(maskHash, sdc.getSamplesAddedSize());
        }
        sdc.addSample(sample);
        return true;
    }

    /**
     * Hashes a coalition mask.
     *
     * <p>Masks of up to {@value #MAX_EXACT_MASK_BITS} entries map to their exact binary value
     * (first entry = most significant bit), so they never collide. Longer masks fall back to
     * {@link Arrays#hashCode(boolean[])}. Callers must confirm equality on a hash hit.
     *
     * @param mask coalition mask
     * @return the mask hash
     */
    private int hashMask(boolean[] mask) {
        if (mask.length > MAX_EXACT_MASK_BITS) {
            return Arrays.hashCode(mask);
        }
        int hash = 0;
        for (boolean bit : mask) {
            hash = (hash << 1) | (bit ? 1 : 0);
        }
        return hash;
    }

    /**
     * Converts a matrix of importances (one row per output, one column per feature) into saliencies.
     *
     * @param m  importances, rows indexed by output and columns by feature
     * @param pi prediction input supplying the features
     * @param po prediction output supplying the outputs
     * @return one saliency per row of {@code m}
     */
    public static Saliency[] saliencyFromMatrix(RealMatrix m, PredictionInput pi, PredictionOutput po) {
        Saliency[] saliencies = new Saliency[m.getRowDimension()];
        for (int i = 0; i < m.getRowDimension(); ++i) {
            List<FeatureImportance> fis = new ArrayList<>();
            for (int j = 0; j < m.getColumnDimension(); ++j) {
                fis.add(new FeatureImportance(pi.getFeatures().get(j), m.getEntry(i, j)));
            }
            saliencies[i] = new Saliency(po.getOutputs().get(i), fis);
        }
        return saliencies;
    }

    /**
     * Converts matrices of importances and their confidence bounds into saliencies.
     *
     * @param m      importances, rows indexed by output and columns by feature
     * @param bounds confidence bounds with the same shape as {@code m}
     * @param pi     prediction input supplying the features
     * @param po     prediction output supplying the outputs
     * @return one saliency per row of {@code m}
     */
    public static Saliency[] saliencyFromMatrix(RealMatrix m, RealMatrix bounds, PredictionInput pi, PredictionOutput po) {
        Saliency[] saliencies = new Saliency[m.getRowDimension()];
        for (int i = 0; i < m.getRowDimension(); ++i) {
            List<FeatureImportance> fis = new ArrayList<>();
            for (int j = 0; j < m.getColumnDimension(); ++j) {
                fis.add(new FeatureImportance(pi.getFeatures().get(j), m.getEntry(i, j), bounds.getEntry(i, j)));
            }
            saliencies[i] = new Saliency(po.getOutputs().get(i), fis);
        }
        return saliencies;
    }

    /**
     * Explains one prediction.
     *
     * @param prediction the prediction to explain
     * @param model      the model that produced it
     * @return future SHAP results; the base value is the link-transformed mean background output
     * @throws IllegalArgumentException if the input feature count differs from the background column
     *         count (thrown directly), or — wrapped in the returned future — if the output count differs
     *         from the background output count
     */
    private CompletableFuture<ShapResults> explain(Prediction prediction, PredictionProvider model) {
        ShapDataCarrier sdc = this.initialize(model);
        sdc.setSamplesAdded(new ArrayList<ShapSyntheticDataSample>());
        PredictionInput pi = prediction.getInput();
        PredictionOutput po = prediction.getOutput();
        int cols = sdc.getCols();
        if (pi.getFeatures().size() != cols) {
            throw new IllegalArgumentException(String.format("Prediction input feature count (%d) does not match background data feature count (%d)", pi.getFeatures().size(), cols));
        }
        // Also acts as a gate: every result path waits on this output-size validation.
        CompletableFuture<RealMatrix> output = sdc.getOutputSize().thenApply(os -> {
            if (po.getOutputs().size() != os.intValue()) {
                throw new IllegalArgumentException(String.format("Prediction output size (%d) does not match background data output size (%d)", po.getOutputs().size(), os));
            }
            return MatrixUtils.createRealMatrix(new double[os.intValue()][cols]);
        });
        RealVector poVector = MatrixUtilsExtensions.vectorFromPredictionOutput(po);
        this.setVaryingFeatureGroups(pi, sdc);

        if (sdc.getNumVarying() == 0) {
            // Nothing differs from the background: every importance is zero.
            return output.thenApply(o -> saliencyFromMatrix(o, pi, po))
                    .thenCombine(sdc.getLinkNull(), ShapResults::new);
        }
        if (sdc.getNumVarying() == 1) {
            return output.thenCompose(o -> this.explainSingleVaryingFeature(poVector, sdc, pi, po))
                    .thenCombine(sdc.getLinkNull(), ShapResults::new);
        }

        ShapStatistics shapStats = this.computeSubsetStatistics(sdc);
        this.initializeWeights(shapStats, sdc);
        this.addCompleteSubsets(shapStats, pi, sdc);
        this.renormalizeWeights(shapStats);
        this.addNonCompleteSubsets(shapStats, pi, sdc);
        CompletableFuture<RealMatrix> expectations = this.runSyntheticData(sdc);
        return output.thenCompose(o -> this.solveSystem(expectations, poVector, sdc))
                .thenApply(wo -> saliencyFromMatrix(wo[0], wo[1], pi, po))
                .thenCombine(sdc.getLinkNull(), ShapResults::new);
    }

    /**
     * Handles the case of exactly one varying column: it receives the whole output change
     * {@code link(f(x)) - link(fnull)}, consistent with the general regression path.
     */
    private CompletableFuture<Saliency[]> explainSingleVaryingFeature(RealVector poVector, ShapDataCarrier sdc, PredictionInput pi, PredictionOutput po) {
        int varyingCol = sdc.getVaryingFeatureGroups(0);
        RealVector linkedOutput = this.link(poVector);
        return sdc.getLinkNull().thenCombine(sdc.getOutputSize(), (ln, os) -> {
            RealVector diff = linkedOutput.subtract(ln);
            int outputs = os;
            RealMatrix out = MatrixUtils.createRealMatrix(new double[outputs][sdc.getCols()]);
            for (int i = 0; i < outputs; ++i) {
                out.setEntry(i, varyingCol, diff.getEntry(i));
            }
            return saliencyFromMatrix(out, pi, po);
        });
    }

    /**
     * Computes the subset sizes to consider (sizes above half are covered by complements) and
     * the number of subsets of each size.
     *
     * @param sdc carrier with the varying count and the sample budget
     * @return statistics where index {@code i} of the counts array holds the count for subset size {@code i}
     */
    private ShapStatistics computeSubsetStatistics(ShapDataCarrier sdc) {
        int numVarying = sdc.getNumVarying();
        int numSubsetSizes = (int) Math.ceil((numVarying - 1) / 2.0);
        int largestPairedSubsetSize = numVarying % 2 == 1 ? numSubsetSizes : numSubsetSizes - 1;
        int[] numSubsetsAtSize = new int[numSubsetSizes + 1];
        for (int size = 1; size <= numSubsetSizes; ++size) {
            numSubsetsAtSize[size] = countSubsets(numVarying, size);
        }
        return new ShapStatistics(numSubsetSizes, largestPairedSubsetSize, numSubsetsAtSize, sdc.getNumSamples());
    }

    /** Returns C(n, k), saturated at {@link Integer#MAX_VALUE} so huge counts can never look affordable. */
    private static int countSubsets(int n, int k) {
        try {
            return (int) Math.min(CombinatoricsUtils.binomialCoefficient(n, k), Integer.MAX_VALUE);
        } catch (MathArithmeticException ignored) {
            // The count does not even fit in a long: certainly too many to enumerate.
            return Integer.MAX_VALUE;
        }
    }

    /**
     * Sets the Shapley kernel weight of each subset size (index = size), normalized to sum to one.
     * Paired sizes (whose complements are sampled too) get double weight.
     */
    private void initializeWeights(ShapStatistics shapStats, ShapDataCarrier sdc) {
        int numVarying = sdc.getNumVarying();
        double[] rawWeights = new double[shapStats.getNumSubsetSizes() + 1];
        for (int subsetSize = 1; subsetSize <= shapStats.getNumSubsetSizes(); ++subsetSize) {
            double weight = (numVarying - 1.0) / ((double) subsetSize * (numVarying - subsetSize));
            if (subsetSize <= shapStats.getLargestPairedSubsetSize()) {
                weight *= 2.0;
            }
            rawWeights[subsetSize] = weight;
        }
        RealVector weightOfSubsetSize = this.normalizeWeightVector(MatrixUtils.createRealVector(rawWeights));
        shapStats.setWeightOfSubsetSize(weightOfSubsetSize);
        shapStats.setRemainingWeights(weightOfSubsetSize.copy());
    }

    /**
     * Fully enumerates subset sizes, smallest first, while the remaining budget covers all subsets
     * of the size in proportion to its weight. Stops at the first size that does not fit.
     */
    private void addCompleteSubsets(ShapStatistics shapStats, PredictionInput pi, ShapDataCarrier sdc) {
        sdc.setMasksUsed(new HashMap<Integer, Integer>());
        for (int subsetSize = 1; subsetSize <= shapStats.getNumSubsetSizes(); ++subsetSize) {
            boolean paired = subsetSize <= shapStats.getLargestPairedSubsetSize();
            // long: a saturated count doubled for pairing must not overflow into a negative number.
            long numSubsets = (long) shapStats.getNumSubsetsAtSize()[subsetSize] * (paired ? 2 : 1);
            double samplingWeight = shapStats.getRemainingWeights().getEntry(subsetSize);
            // Negated >= so that a NaN weight also stops the enumeration.
            if (!(shapStats.getNumSamplesRemaining() * samplingWeight >= numSubsets)) {
                break;
            }
            shapStats.incrementNumFullSubsets();
            // Safe cast: numSubsets <= remaining samples, which is an int.
            shapStats.decreaseNumSamplesRemainingBy((int) numSubsets);
            RealVector remainingWeights = shapStats.getRemainingWeights();
            remainingWeights.setEntry(subsetSize, 0.0);
            shapStats.setRemainingWeights(this.normalizeWeightVector(remainingWeights));

            double individualWeight = shapStats.getWeightOfSubsetSize().getEntry(subsetSize) / numSubsets;
            Iterator<int[]> combinations = CombinatoricsUtils.combinationsIterator(sdc.getNumVarying(), subsetSize);
            while (combinations.hasNext()) {
                List<Integer> combination = Arrays.stream(combinations.next()).boxed().collect(Collectors.toList());
                this.addSample(pi, combination, individualWeight, false, true, sdc);
                if (paired) {
                    this.addSample(pi, combination, individualWeight, true, true, sdc);
                }
            }
        }
    }

    /**
     * Computes the sampling distribution over the subset sizes that were not fully enumerated.
     * Entry {@code i} of the weight vector is subset size {@code i}. A paired size yields two samples
     * per draw (the subset and its complement), so its per-draw weight is halved before normalizing.
     */
    private void renormalizeWeights(ShapStatistics shapStats) {
        RealVector weightOfSubsetSize = shapStats.getWeightOfSubsetSize();
        int largestPaired = shapStats.getLargestPairedSubsetSize();
        double[] divisor = IntStream.range(0, weightOfSubsetSize.getDimension())
                .mapToDouble(size -> size <= largestPaired ? 2.0 : 1.0)
                .toArray();
        // ebeDivide returns a new vector; its result must be kept.
        RealVector perDrawWeights = weightOfSubsetSize.ebeDivide(MatrixUtils.createRealVector(divisor));
        int numFull = shapStats.getNumFullSubsets();
        RealVector nonFullWeights = perDrawWeights.getSubVector(numFull + 1, shapStats.getNumSubsetSizes() - numFull);
        shapStats.setFinalRemainingWeights(this.normalizeWeightVector(nonFullWeights));
    }

    /**
     * Spends the remaining sample budget on random coalitions drawn from the non-enumerated subset
     * sizes. Uses the configured {@code Random} throughout, so seeded runs are reproducible.
     */
    private void addNonCompleteSubsets(ShapStatistics shapStats, PredictionInput pi, ShapDataCarrier sdc) {
        if (shapStats.getNumFullSubsets() >= shapStats.getNumSubsetSizes()) {
            return;
        }
        List<Integer> subsetSizesRemaining = IntStream.rangeClosed(shapStats.getNumFullSubsets() + 1, shapStats.getNumSubsetSizes())
                .boxed()
                .collect(Collectors.toList());
        List<Double> subsetSizeWeights = Arrays.stream(shapStats.getFinalRemainingWeights().toArray()).boxed().collect(Collectors.toList());
        RandomChoice<Integer> subsetSampler = new RandomChoice<>(subsetSizesRemaining, subsetSizeWeights);
        List<Integer> sizeSamples = subsetSampler.sample(shapStats.getNumSamplesRemaining() * 4, this.config.getPC().getRandom());
        List<Integer> varyingIdxs = IntStream.range(0, sdc.getNumVarying()).boxed().collect(Collectors.toList());
        int sampleIdx = 0;
        while (shapStats.getNumSamplesRemaining() > 0) {
            if (sampleIdx >= sizeSamples.size()) {
                sizeSamples = subsetSampler.sample(shapStats.getNumSamplesRemaining() * 4, this.config.getPC().getRandom());
                sampleIdx = 0;
            }
            int subsetSize = sizeSamples.get(sampleIdx++);
            Collections.shuffle(varyingIdxs, this.config.getPC().getRandom());
            List<Integer> maskIdxs = varyingIdxs.subList(0, subsetSize);
            if (this.addSample(pi, maskIdxs, 1.0, false, false, sdc)) {
                shapStats.decreaseNumSamplesRemainingBy(1);
            }
            boolean paired = subsetSize <= shapStats.getLargestPairedSubsetSize();
            if (shapStats.getNumSamplesRemaining() > 0 && paired && this.addSample(pi, maskIdxs, 1.0, true, false, sdc)) {
                shapStats.decreaseNumSamplesRemainingBy(1);
            }
        }
        this.normalizeSampleWeights(shapStats, sdc);
    }

    /**
     * Rescales the randomly sampled (non-fixed) sample weights so they sum to the kernel weight of
     * the subset sizes that were not fully enumerated (indices {@code numFull+1..numSubsetSizes}).
     */
    private void normalizeSampleWeights(ShapStatistics shapStats, ShapDataCarrier sdc) {
        int numFull = shapStats.getNumFullSubsets();
        double nonFullWeight = MatrixUtilsExtensions.sum(
                shapStats.getWeightOfSubsetSize().getSubVector(numFull + 1, shapStats.getNumSubsetSizes() - numFull));
        int numSamples = sdc.getSamplesAddedSize();
        double nonFixedWeight = 0.0;
        for (int i = 0; i < numSamples; ++i) {
            ShapSyntheticDataSample sample = sdc.getSamplesAdded(i);
            if (!sample.isFixed()) {
                nonFixedWeight += sample.getWeight();
            }
        }
        if (nonFixedWeight == 0.0) {
            return;
        }
        for (int i = 0; i < numSamples; ++i) {
            ShapSyntheticDataSample sample = sdc.getSamplesAdded(i);
            if (!sample.isFixed()) {
                sample.setWeight(sample.getWeight() * nonFullWeight / nonFixedWeight);
            }
        }
    }

    /**
     * Evaluates the model on every synthetic sample.
     *
     * @return future matrix (samples x outputs) of {@code link(mean prediction) - linkNull}
     */
    private CompletableFuture<RealMatrix> runSyntheticData(ShapDataCarrier sdc) {
        int batchSize = this.config.getBatchSize();
        return sdc.getLinkNull().thenCompose(ln -> sdc.getOutputSize().thenCompose(os -> batchSize > 1
                ? this.runSyntheticDataBatched(sdc, ln, os, batchSize)
                : this.runSyntheticDataPerSample(sdc, ln, os)));
    }

    /** Evaluates the synthetic samples {@code batchSize} at a time, one model call per batch. */
    private CompletableFuture<RealMatrix> runSyntheticDataBatched(ShapDataCarrier sdc, RealVector linkNull, int outputSize, int batchSize) {
        int numSamples = sdc.getSamplesAddedSize();
        List<CompletableFuture<RealMatrix>> batchMeans = new ArrayList<>();
        for (int start = 0; start < numSamples; start += batchSize) {
            int end = Math.min(numSamples, start + batchSize);
            List<PredictionInput> batch = new ArrayList<>();
            for (int s = start; s < end; ++s) {
                batch.addAll(sdc.getSamplesAdded(s).getSyntheticData());
            }
            batchMeans.add(sdc.getModel().predictAsync(this.config.getOneHotter().oneHotDecode(batch, true))
                    .thenApply(MatrixUtilsExtensions::matrixFromPredictionOutput)
                    // One row per sample: the mean over its background-sized block of predictions.
                    .thenApply(ops -> MatrixUtilsExtensions.batchRowMean(ops, sdc.getRows())));
        }
        return CompletableFuture.allOf(batchMeans.toArray(new CompletableFuture<?>[0])).thenApply(ignored -> {
            RealMatrix expectations = MatrixUtils.createRealMatrix(new double[numSamples][outputSize]);
            for (int b = 0; b < batchMeans.size(); ++b) {
                RealMatrix means = batchMeans.get(b).join();
                int offset = b * batchSize;
                for (int row = 0; row < means.getRowDimension(); ++row) {
                    expectations.setRowVector(offset + row, means.getRowVector(row).map(this::link).subtract(linkNull));
                }
            }
            return expectations;
        });
    }

    /** Evaluates each synthetic sample with its own model call. */
    private CompletableFuture<RealMatrix> runSyntheticDataPerSample(ShapDataCarrier sdc, RealVector linkNull, int outputSize) {
        int numSamples = sdc.getSamplesAddedSize();
        List<CompletableFuture<RealVector>> sampleRows = new ArrayList<>();
        for (int i = 0; i < numSamples; ++i) {
            List<PredictionInput> pis = this.config.getOneHotter().oneHotDecode(sdc.getSamplesAdded(i).getSyntheticData(), true);
            sampleRows.add(sdc.getModel().predictAsync(pis)
                    .thenApply(MatrixUtilsExtensions::matrixFromPredictionOutput)
                    .thenApply(posMatrix -> MatrixUtilsExtensions.rowSum(posMatrix).mapDivide(posMatrix.getRowDimension()))
                    .thenApply(mean -> this.link(mean).subtract(linkNull)));
        }
        return CompletableFuture.allOf(sampleRows.toArray(new CompletableFuture<?>[0])).thenApply(ignored -> {
            RealMatrix expectations = MatrixUtils.createRealMatrix(new double[numSamples][outputSize]);
            for (int i = 0; i < numSamples; ++i) {
                expectations.setRowVector(i, sampleRows.get(i).join());
            }
            return expectations;
        });
    }

    /**
     * Selects features using the configured regularizer.
     *
     * @return indices of the columns of {@code augX} that the regularizer keeps
     * @throws IllegalArgumentException for {@code RegularizerType.NONE}, which callers never pass here
     */
    private List<Integer> getRegularizationIndexes(RealMatrix augX, RealVector augY) {
        List<Integer> nonzeros = List.of();
        switch (this.config.getRegularizerType()) {
            case AUTO:
            case AIC: {
                nonzeros = MatrixUtilsExtensions.nonzero(LassoLarsIC.fit(augX, augY, LassoLarsIC.Criterion.AIC).getCoefs());
                break;
            }
            case BIC: {
                nonzeros = MatrixUtilsExtensions.nonzero(LassoLarsIC.fit(augX, augY, LassoLarsIC.Criterion.BIC).getCoefs());
                break;
            }
            case TOP_N_FEATURES: {
                nonzeros = LarsPath.fit(augX, augY, this.config.getNRegularizationFeatures(), false).getActive();
                break;
            }
            case NONE: {
                throw new IllegalArgumentException("RegularizerType=NONE will never be able enter the switch statement");
            }
        }
        return nonzeros;
    }

    /**
     * Runs the regularizer on the efficiency-augmented regression problem.
     *
     * <p>Only the varying columns are included; non-varying columns are all-zero in every mask, and
     * must never receive importance. The returned indices are mapped back to full column indices,
     * preserving the regularizer's order.
     */
    private List<Integer> selectRegularizedFeatures(RealMatrix xs, RealVector ys, RealVector ws, double outputChange, ShapDataCarrier sdc) {
        List<Integer> varying = sdc.getVaryingFeatureGroups();
        int numVarying = sdc.getNumVarying();
        RealMatrix varyingXs = MatrixUtilsExtensions.getCols(xs, varying);
        RealVector maskSum = MatrixUtilsExtensions.colSum(xs);
        int n = ws.getDimension();

        RealVector augWeights = MatrixUtils.createRealVector(new double[2 * n]);
        augWeights.setSubVector(0, ws.ebeMultiply(maskSum.map(x -> numVarying - x)));
        augWeights.setSubVector(n, ws.ebeMultiply(maskSum));
        RealVector sqrtAugWeights = augWeights.map(Math::sqrt);

        RealVector augYs = MatrixUtils.createRealVector(new double[2 * ys.getDimension()]);
        augYs.setSubVector(0, ys);
        augYs.setSubVector(ys.getDimension(), ys.mapSubtract(outputChange));
        augYs = augYs.ebeMultiply(sqrtAugWeights);

        RealMatrix augXsRaw = MatrixUtils.createRealMatrix(2 * varyingXs.getRowDimension(), varyingXs.getColumnDimension());
        augXsRaw.setSubMatrix(varyingXs.getData(), 0, 0);
        augXsRaw.setSubMatrix(MatrixUtilsExtensions.map(varyingXs, x -> x - 1.0).getData(), varyingXs.getRowDimension(), 0);
        RealMatrix augXs = MatrixUtilsExtensions.vectorRowProduct(augXsRaw.transpose(), sqrtAugWeights).transpose();

        List<Integer> selected = this.getRegularizationIndexes(augXs, augYs);
        return selected.stream().map(idx -> varying.get(idx)).collect(Collectors.toList());
    }

    /**
     * Estimates the importances of one model output.
     *
     * @return {shap values, confidence bounds}, each with one entry per column
     */
    private RealVector[] solve(RealMatrix expectations, int output, RealVector poVector, RealVector fnull, ShapDataCarrier sdc) {
        int numSamples = sdc.getSamplesAddedSize();
        int cols = sdc.getCols();
        RealMatrix xs = MatrixUtils.createRealMatrix(new double[numSamples][cols]);
        RealVector ws = MatrixUtils.createRealVector(new double[numSamples]);
        RealVector ys = MatrixUtils.createRealVector(new double[numSamples]);
        for (int i = 0; i < numSamples; ++i) {
            ShapSyntheticDataSample sample = sdc.getSamplesAdded(i);
            boolean[] mask = sample.getMask();
            for (int j = 0; j < cols; ++j) {
                xs.setEntry(i, j, mask[j] ? 1.0 : 0.0);
            }
            ys.setEntry(i, expectations.getEntry(i, output));
            ws.setEntry(i, sample.getWeight());
        }
        double sampleFraction = numSamples / Math.pow(2.0, cols);
        double outputChange = this.link(poVector.getEntry(output)) - this.link(fnull.getEntry(output));
        ShapConfig.RegularizerType regularizer = this.config.getRegularizerType();
        boolean autoRegularize = sampleFraction < AUTO_REGULARIZATION_FRACTION && regularizer == ShapConfig.RegularizerType.AUTO;
        boolean specificRegularize = regularizer != ShapConfig.RegularizerType.NONE && regularizer != ShapConfig.RegularizerType.AUTO;
        List<Integer> nonzeros = autoRegularize || specificRegularize
                ? this.selectRegularizedFeatures(xs, ys, ws, outputChange, sdc)
                : sdc.getVaryingFeatureGroups();

        RealVector zeros = MatrixUtils.createRealVector(new double[cols]);
        if (nonzeros.isEmpty()) {
            // The regularizer kept no feature: every importance for this output is zero.
            return new RealVector[]{zeros, zeros.copy()};
        }
        int dropIdx = nonzeros.get(nonzeros.size() - 1);
        if (nonzeros.size() == 1) {
            // A single kept feature takes the whole output change (efficiency); nothing left to regress.
            RealVector shapSlice = zeros.copy();
            shapSlice.setEntry(dropIdx, outputChange);
            return new RealVector[]{shapSlice, zeros};
        }
        // Eliminate the last feature through the efficiency constraint (importances sum to outputChange).
        RealVector dropMask = xs.getColumnVector(dropIdx);
        RealVector adjY = ys.subtract(dropMask.mapMultiply(outputChange));
        List<Integer> allNZButLast = nonzeros.subList(0, nonzeros.size() - 1);
        RealMatrix xsAdj = MatrixUtilsExtensions.vectorDifference(MatrixUtilsExtensions.getCols(xs, allNZButLast), dropMask, MatrixUtilsExtensions.Axis.COLUMN);
        return this.runWLRR(xsAdj, adjY, ws, outputChange, dropIdx, nonzeros, sdc);
    }

    /**
     * Solves every model output in parallel on the configured executor.
     *
     * @return future {shap values, bounds}, each an outputs x columns matrix
     */
    private CompletableFuture<RealMatrix[]> solveSystem(CompletableFuture<RealMatrix> expectations, RealVector poVector, ShapDataCarrier sdc) {
        return expectations.thenCompose(exps -> sdc.getFnull().thenCompose(fn -> sdc.getOutputSize().thenCompose(os -> {
            int outputs = os;
            List<CompletableFuture<RealVector[]>> shapSlices = new ArrayList<>();
            for (int output = 0; output < outputs; ++output) {
                int finalOutput = output;
                shapSlices.add(CompletableFuture.supplyAsync(() -> this.solve(exps, finalOutput, poVector, fn, sdc), this.config.getExecutor()));
            }
            return CompletableFuture.allOf(shapSlices.toArray(new CompletableFuture<?>[0])).thenApply(ignored -> {
                RealMatrix shapValues = MatrixUtils.createRealMatrix(new double[outputs][sdc.getCols()]);
                RealMatrix bounds = shapValues.copy();
                for (int output = 0; output < outputs; ++output) {
                    RealVector[] slice = shapSlices.get(output).join();
                    shapValues.setRowVector(output, slice[0]);
                    bounds.setRowVector(output, slice[1]);
                }
                return new RealMatrix[]{shapValues, bounds};
            });
        })));
    }

    /**
     * Fits the weighted linear regression of the reduced system and scatters the coefficients back
     * to column positions.
     *
     * <p>The dropped feature gets {@code outputChange - sum(coefficients)}. Its bound is the root
     * sum of squares of the other bounds.
     *
     * @return {shap values, confidence bounds}, each with one entry per column
     */
    private RealVector[] runWLRR(RealMatrix maskDiff, RealVector adjY, RealVector ws, double outputChange, int dropIdx, List<Integer> nonzeros, ShapDataCarrier sdc) {
        WeightedLinearRegressionResults wlrr = WeightedLinearRegression.fit(maskDiff, adjY, ws, false);
        RealVector coeffs = wlrr.getCoefficients();
        RealVector bounds = wlrr.getConf(1.0 - this.config.getConfidence());
        RealVector shapSlice = MatrixUtils.createRealVector(new double[sdc.getCols()]);
        RealVector boundsReg = shapSlice.copy();
        int coefIdx = 0;
        for (int idx : nonzeros) {
            if (idx == dropIdx) {
                continue;
            }
            shapSlice.setEntry(idx, coeffs.getEntry(coefIdx));
            boundsReg.setEntry(idx, bounds.getEntry(coefIdx));
            ++coefIdx;
        }
        shapSlice.setEntry(dropIdx, outputChange - MatrixUtilsExtensions.sum(coeffs));
        boundsReg.setEntry(dropIdx, Math.sqrt(MatrixUtilsExtensions.sum(bounds.map(x -> x * x))));
        return new RealVector[]{shapSlice, boundsReg};
    }

    @Override
    public CompletableFuture<ShapResults> explainAsync(Prediction prediction, PredictionProvider model) {
        return this.explainAsync(prediction, model, (Consumer<ShapResults>) null);
    }

    /**
     * Explains a prediction asynchronously.
     *
     * @param intermediateResultsConsumer currently ignored; only final results are produced
     */
    @Override
    public CompletableFuture<ShapResults> explainAsync(Prediction prediction, PredictionProvider model, Consumer<ShapResults> intermediateResultsConsumer) {
        return this.explain(prediction, model);
    }
}

