/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.shap;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.local.shap.ShapSyntheticDataSample;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.PredictionInput;

/**
 * Unit tests for {@link ShapSyntheticDataSample}: generation of synthetic data from a
 * background matrix and a feature mask, plus the mask, fixed-flag and weight accessors.
 */
class ShapSyntheticDataSampleTest {

    /**
     * Builds numerical features named {@code f1}, {@code f2}, ... holding the given values in order.
     *
     * @param values feature values; each keeps its boxed type (e.g. {@code -1} stays an Integer)
     * @return a new mutable list with one feature per value
     */
    private static List<Feature> numericalFeatures(Number... values) {
        List<Feature> features = new ArrayList<>(values.length);
        for (int i = 0; i < values.length; i++) {
            features.add(FeatureFactory.newNumericalFeature("f" + (i + 1), values[i]));
        }
        return features;
    }

    /**
     * Creates the sample under test: five features all set to -1, a mask that is false for
     * f3 and f4, a two-row background matrix, weight 0.5 and the fixed flag set.
     */
    private ShapSyntheticDataSample generateShapSample() {
        PredictionInput pi = new PredictionInput(numericalFeatures(-1, -1, -1, -1, -1));
        boolean[] mask = {true, true, false, false, true};
        RealMatrix background = MatrixUtils.createRealMatrix(new double[][] {
                {0.0, 1.0, 2.0, 3.0, 4.0},
                {5.0, 6.0, 7.0, 8.0, 9.0}
        });
        double weight = 0.5;
        boolean fixed = true;
        return new ShapSyntheticDataSample(pi, mask, background, weight, fixed);
    }

    /**
     * Expected synthetic data for {@link #generateShapSample()}: one input per background row,
     * where masked-in features (f1, f2, f5) keep the original -1 and masked-out features
     * (f3, f4) take the corresponding background value.
     */
    private List<PredictionInput> generateExpectedSynthData() {
        List<PredictionInput> synthData = new ArrayList<>();
        synthData.add(new PredictionInput(numericalFeatures(-1, -1, 2.0, 3.0, -1)));
        synthData.add(new PredictionInput(numericalFeatures(-1, -1, 7.0, 8.0, -1)));
        return synthData;
    }

    @Test
    void testSyntheticCreation() {
        ShapSyntheticDataSample shapSamp = generateShapSample();
        List<PredictionInput> expectedSynth = generateExpectedSynthData();
        List<PredictionInput> generatedSynth = shapSamp.getSyntheticData();

        // Check sizes first: iterating only over the generated data would let an empty or
        // truncated result pass without comparing anything.
        Assertions.assertEquals(expectedSynth.size(), generatedSynth.size());
        for (int i = 0; i < expectedSynth.size(); i++) {
            List<Feature> expectedFeatures = expectedSynth.get(i).getFeatures();
            List<Feature> generatedFeatures = generatedSynth.get(i).getFeatures();
            Assertions.assertEquals(expectedFeatures.size(), generatedFeatures.size());
            for (int j = 0; j < expectedFeatures.size(); j++) {
                // JUnit convention: expected value first, actual second.
                Assertions.assertEquals(expectedFeatures.get(j), generatedFeatures.get(j));
            }
        }
    }

    @Test
    void testIsFixed() {
        ShapSyntheticDataSample shapSamp = generateShapSample();
        Assertions.assertTrue(shapSamp.isFixed());
    }

    @Test
    void testGetMask() {
        ShapSyntheticDataSample shapSamp = generateShapSample();
        // A separate array from the one used to build the sample, so the check compares contents.
        boolean[] mask = {true, true, false, false, true};
        Assertions.assertArrayEquals(mask, shapSamp.getMask());
    }

    @Test
    void testWeight() {
        ShapSyntheticDataSample shapSamp = generateShapSample();
        Assertions.assertEquals(0.5, shapSamp.getWeight());
        // The expected values below show that incrementWeight() adds exactly 1.0.
        shapSamp.incrementWeight();
        Assertions.assertEquals(1.5, shapSamp.getWeight());
        shapSamp.setWeight(2.5);
        Assertions.assertEquals(2.5, shapSamp.getWeight());
    }
}

