/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.model;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;

/**
 * Unit tests for {@link Saliency}: selection of top / positive / negative feature importances
 * and merging of per-output saliency maps.
 */
class SaliencyTest {

    /**
     * Builds a {@link Saliency} for a single numeric output named {@code "name"}, with one mocked
     * numeric feature per given score, in the given order.
     *
     * @param scores the importance score to attach to each mocked feature
     * @return a saliency over the generated feature importances
     */
    private static Saliency saliencyWithScores(double... scores) {
        List<FeatureImportance> fis = new ArrayList<>(scores.length);
        for (double score : scores) {
            fis.add(new FeatureImportance(TestUtils.getMockedNumericFeature(), score));
        }
        return new Saliency(new Output("name", Type.NUMBER), fis);
    }

    /**
     * Extracts the scores of the given feature importances, preserving their order.
     *
     * @param importances the feature importances to read
     * @return the scores, one per importance
     */
    private static List<Double> scoresOf(List<FeatureImportance> importances) {
        return importances.stream().map(FeatureImportance::getScore).collect(Collectors.toList());
    }

    /** Expects the two largest-magnitude scores (-0.44 and 0.19) to be selected, dropping 0.04. */
    @Test
    void testGetTopFeatures() {
        Saliency saliency = saliencyWithScores(0.19, -0.44, 0.04);

        List<FeatureImportance> topFeatures = saliency.getTopFeatures(2);
        Assertions.assertNotNull(topFeatures);
        Assertions.assertEquals(2, topFeatures.size());

        List<Double> scores = scoresOf(topFeatures);
        Assertions.assertTrue(scores.contains(-0.44));
        Assertions.assertTrue(scores.contains(0.19));
    }

    /** Expects only the positive scores (0.19 and 0.04) to be selected; -0.44 is excluded. */
    @Test
    void testGetPositiveFeatures() {
        Saliency saliency = saliencyWithScores(0.19, -0.44, 0.04);

        List<FeatureImportance> positiveFeatures = saliency.getPositiveFeatures(2);
        Assertions.assertNotNull(positiveFeatures);
        Assertions.assertEquals(2, positiveFeatures.size());

        List<Double> scores = scoresOf(positiveFeatures);
        Assertions.assertTrue(scores.contains(0.04));
        Assertions.assertTrue(scores.contains(0.19));
    }

    /**
     * Expects only the single negative score (-0.44) to be returned, even though the requested
     * limit is 2.
     */
    @Test
    void testGetNegativeFeatures() {
        Saliency saliency = saliencyWithScores(0.19, -0.44, 0.04);

        List<FeatureImportance> negativeFeatures = saliency.getNegativeFeatures(2);
        Assertions.assertNotNull(negativeFeatures);
        Assertions.assertEquals(1, negativeFeatures.size());

        Assertions.assertTrue(scoresOf(negativeFeatures).contains(-0.44));
    }

    /**
     * With all scores tied at 0.1: top and positive selections honour the limit of 2, and the
     * negative selection is empty.
     */
    @Test
    void testSameImportantFeatures() {
        Saliency saliency = saliencyWithScores(0.1, 0.1, 0.1);

        List<FeatureImportance> topFeatures = saliency.getTopFeatures(2);
        Assertions.assertNotNull(topFeatures);
        Assertions.assertEquals(2, topFeatures.size());
        // Both selected entries must carry the tied score (checks every element, not just one).
        Assertions.assertEquals(List.of(0.1, 0.1), scoresOf(topFeatures));

        List<FeatureImportance> negativeFeatures = saliency.getNegativeFeatures(2);
        Assertions.assertNotNull(negativeFeatures);
        Assertions.assertTrue(negativeFeatures.isEmpty());

        List<FeatureImportance> positiveFeatures = saliency.getPositiveFeatures(2);
        Assertions.assertNotNull(positiveFeatures);
        Assertions.assertEquals(2, positiveFeatures.size());
    }

    /**
     * Merges two single-output saliency maps keyed by {@code "out"} and expects one merged entry
     * whose per-feature scores are the mean of the inputs, in feature order:
     * (0.1 + 0.2) / 2, (-0.4 + -0.2) / 2, (0.01 + 0.03) / 2.
     */
    @Test
    void testMergeSaliencyMaps() {
        List<FeatureImportance> fis1 = new ArrayList<>();
        fis1.add(new FeatureImportance(FeatureFactory.newTextFeature("f1", "foo"), 0.1));
        fis1.add(new FeatureImportance(FeatureFactory.newTextFeature("f2", "bar"), -0.4));
        fis1.add(new FeatureImportance(FeatureFactory.newNumericalFeature("f3", 10), 0.01));
        Saliency saliency1 = new Saliency(new Output("out", Type.NUMBER), fis1);

        List<FeatureImportance> fis2 = new ArrayList<>();
        fis2.add(new FeatureImportance(FeatureFactory.newTextFeature("f1", "foo"), 0.2));
        fis2.add(new FeatureImportance(FeatureFactory.newTextFeature("f2", "bar"), -0.2));
        fis2.add(new FeatureImportance(FeatureFactory.newNumericalFeature("f3", 10), 0.03));
        Saliency saliency2 = new Saliency(new Output("out", Type.NUMBER), fis2);

        Map<String, Saliency> map1 = new HashMap<>();
        map1.put("out", saliency1);
        Map<String, Saliency> map2 = new HashMap<>();
        map2.put("out", saliency2);

        Map<String, Saliency> merged = Saliency.merge(List.of(map1, map2));
        Assertions.assertNotNull(merged);
        Assertions.assertEquals(1, merged.size());

        Saliency mergedSaliency = merged.get("out");
        List<FeatureImportance> perFeatureImportance = mergedSaliency.getPerFeatureImportance();
        Assertions.assertNotNull(perFeatureImportance);
        Assertions.assertEquals(3, perFeatureImportance.size());
        Assertions.assertEquals(0.15, perFeatureImportance.get(0).getScore(), 0.001);
        Assertions.assertEquals(-0.3, perFeatureImportance.get(1).getScore(), 0.001);
        Assertions.assertEquals(0.02, perFeatureImportance.get(2).getScore(), 0.001);
    }
}

