/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.utils.RandomChoice;

/**
 * Tests for {@link RandomChoice}: uniform and weighted sampling driven by a seeded {@link Random},
 * plus rejection of a weight list whose size does not match the population.
 */
class RandomChoiceTest {
    /** Number of independent sampling rounds in the statistical tests. */
    private static final int ROUNDS = 100;

    /** Number of draws per round in the statistical tests. */
    private static final int DRAWS_PER_ROUND = 1000;

    /** Population every test samples from. */
    List<String> obj = List.of("a", "b", "c", "d", "e");

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2})
    void testOnlyOneWeight(int seed) {
        Random rn = new Random(seed);
        // Only "c" carries non-zero weight, so every draw must be "c".
        List<Double> weights = List.of(0.0, 0.0, 1.0, 0.0, 0.0);
        RandomChoice<String> rc = new RandomChoice<>(obj, weights);
        Assertions.assertEquals(List.of("c", "c", "c"), rc.sample(3, rn));
    }

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2})
    void testTwoWeight(int seed) {
        Random rn = new Random(seed);
        // "a" and "c" share all the weight; the zero-weight elements must never appear.
        List<Double> weights = List.of(1.0, 0.0, 1.0, 0.0, 0.0);
        RandomChoice<String> rc = new RandomChoice<>(obj, weights);
        List<String> sample = rc.sample(5, rn);
        Assertions.assertEquals(5, sample.size());
        for (String drawn : sample) {
            Assertions.assertTrue(drawn.equals("a") || drawn.equals("c"),
                    () -> "unexpected zero-weight element sampled: " + drawn);
        }
    }

    @Test
    void weightMismatch() {
        // Three weights for a five-element population must be rejected.
        List<Double> weights = List.of(1.0, 1.0, 0.0);
        Assertions.assertThrows(IllegalArgumentException.class, () -> new RandomChoice<>(obj, weights));
    }

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2})
    void testUniform(int seed) {
        RandomChoice<String> rc = new RandomChoice<>(obj);
        Random rn = new Random(seed);
        for (int round = 0; round < ROUNDS; round++) {
            Map<String, Integer> counts = countOccurrences(rc.sample(DRAWS_PER_ROUND, rn));
            // Expected ~200 per element; the bounds are deliberately wide to avoid flakiness.
            for (String label : obj) {
                assertCountInRange(counts, label, 70, 324);
            }
        }
    }

    @ParameterizedTest
    @ValueSource(ints = {0, 1, 2})
    void testMultiWeight(int seed) {
        List<Double> weights = List.of(5.0, 4.0, 3.0, 2.0, 1.0);
        RandomChoice<String> rc = new RandomChoice<>(obj, weights);
        Random rn = new Random(seed);
        for (int round = 0; round < ROUNDS; round++) {
            Map<String, Integer> counts = countOccurrences(rc.sample(DRAWS_PER_ROUND, rn));
            assertCountInRange(counts, "a", 171, 475);
            assertCountInRange(counts, "b", 118, 401);
            assertCountInRange(counts, "c", 70, 324);
            assertCountInRange(counts, "d", 28, 242);
            assertCountInRange(counts, "e", 28, 151);
        }
    }

    /**
     * Tallies how many times each element occurs in a sample.
     *
     * @param sample the drawn elements
     * @return a map from element to occurrence count; elements never drawn are absent
     */
    private static Map<String, Integer> countOccurrences(List<String> sample) {
        Map<String, Integer> counts = new HashMap<>();
        for (String drawn : sample) {
            counts.merge(drawn, 1, Integer::sum);
        }
        return counts;
    }

    /**
     * Asserts that {@code lowerExclusive < count(label) < upperExclusive}.
     *
     * <p>A label that was never drawn is treated as a count of 0. A missing element therefore
     * produces a descriptive assertion failure instead of a NullPointerException from unboxing.
     *
     * @param counts occurrence counts produced by {@link #countOccurrences(List)}
     * @param label the element whose count is checked
     * @param lowerExclusive exclusive lower bound
     * @param upperExclusive exclusive upper bound
     */
    private static void assertCountInRange(Map<String, Integer> counts, String label,
            int lowerExclusive, int upperExclusive) {
        int count = counts.getOrDefault(label, 0);
        Assertions.assertTrue(count > lowerExclusive && count < upperExclusive,
                () -> String.format("count of '%s' was %d, expected strictly between %d and %d",
                        label, count, lowerExclusive, upperExclusive));
    }
}

