/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.shap;

import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealVector;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.local.shap.ShapStatistics;

/**
 * Unit tests for the getters, setters and counters exposed by {@link ShapStatistics}.
 *
 * <p>Every test builds its own {@code ShapStatistics} from the fixture fields declared
 * below, so no state is shared between test methods through the object under test.
 */
class ShapStatisticsTest {
    // Constructor arguments used to build the object under test in every test method.
    int numSubsetSizes = 5;
    int largestPairedSubsetSize = 3;
    int[] numSubsetsAtSize = {1, 3, 5, 3, 1};
    int numSamplesRemaining = 10;

    // Vectors used to check that each weight setter is mirrored by its getter.
    RealVector weightOfSubsetSize = MatrixUtils.createRealVector(new double[]{0.1, 0.3, 0.5, 0.3, 0.1});
    RealVector finalRemainingWeights = MatrixUtils.createRealVector(new double[]{0.2, 0.4, 0.6, 0.4, 0.2});
    RealVector remainingWeights = MatrixUtils.createRealVector(new double[]{0.02, 0.04, 0.06, 0.04, 0.02});

    ShapStatisticsTest() {
    }

    /**
     * Creates a new {@code ShapStatistics} from the fixture fields of this test class.
     *
     * @return a freshly constructed instance
     */
    private ShapStatistics newStats() {
        return new ShapStatistics(numSubsetSizes, largestPairedSubsetSize, numSubsetsAtSize, numSamplesRemaining);
    }

    /** The number of subset sizes passed to the constructor is returned by its getter. */
    @Test
    void getNumSubsetSizesTest() {
        ShapStatistics stats = newStats();
        Assertions.assertEquals(numSubsetSizes, stats.getNumSubsetSizes());
    }

    /** The per-size subset counts passed to the constructor are returned by their getter. */
    @Test
    void getNumSubsetsAtSizeTest() {
        ShapStatistics stats = newStats();
        Assertions.assertArrayEquals(numSubsetsAtSize, stats.getNumSubsetsAtSize());
    }

    /** The largest paired subset size passed to the constructor is returned by its getter. */
    @Test
    void getLargestPairedSubsetSizeTest() {
        ShapStatistics stats = newStats();
        Assertions.assertEquals(largestPairedSubsetSize, stats.getLargestPairedSubsetSize());
    }

    /** The full-subset counter starts at zero and grows by one per increment call. */
    @Test
    void numFullSubsetsTest() {
        ShapStatistics stats = newStats();
        Assertions.assertEquals(0, stats.getNumFullSubsets());

        stats.incrementNumFullSubsets();
        Assertions.assertEquals(1, stats.getNumFullSubsets());

        stats.incrementNumFullSubsets();
        Assertions.assertEquals(2, stats.getNumFullSubsets());
    }

    /** A vector stored via {@code setWeightOfSubsetSize} is read back element-for-element. */
    @Test
    void weightOfSubsetSizeTest() {
        ShapStatistics stats = newStats();
        stats.setWeightOfSubsetSize(weightOfSubsetSize);

        double[] expected = weightOfSubsetSize.toArray();
        double[] actual = stats.getWeightOfSubsetSize().toArray();
        Assertions.assertArrayEquals(expected, actual);
    }

    /** A vector stored via {@code setRemainingWeights} is read back element-for-element. */
    @Test
    void remainingWeightsTest() {
        ShapStatistics stats = newStats();
        stats.setRemainingWeights(remainingWeights);

        double[] expected = remainingWeights.toArray();
        double[] actual = stats.getRemainingWeights().toArray();
        Assertions.assertArrayEquals(expected, actual);
    }

    /** A vector stored via {@code setFinalRemainingWeights} is read back element-for-element. */
    @Test
    void finalRemainingWeightsTest() {
        ShapStatistics stats = newStats();
        stats.setFinalRemainingWeights(finalRemainingWeights);

        double[] expected = finalRemainingWeights.toArray();
        double[] actual = stats.getFinalRemainingWeights().toArray();
        Assertions.assertArrayEquals(expected, actual);
    }

    /**
     * The remaining-sample count starts at the constructor value and is reduced by the
     * amount passed to {@code decreaseNumSamplesRemainingBy}; a negative amount is
     * expected to raise the count again.
     */
    @Test
    void numSamplesRemainingTest() {
        ShapStatistics stats = newStats();
        Assertions.assertEquals(numSamplesRemaining, stats.getNumSamplesRemaining());

        stats.decreaseNumSamplesRemainingBy(3);
        Assertions.assertEquals(numSamplesRemaining - 3, stats.getNumSamplesRemaining());

        // Decreasing by -5 must add 5 back onto the running count.
        stats.decreaseNumSamplesRemainingBy(-5);
        Assertions.assertEquals(numSamplesRemaining - 3 + 5, stats.getNumSamplesRemaining());
    }
}

