/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.model;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.UUID;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.model.Dataset;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;

/**
 * Unit tests for {@link Dataset}: construction from a list of predictions, and filtering by
 * prediction input, prediction output and individual feature.
 */
class DatasetTest {

    /** Number of predictions produced by {@link #createDatasetFeatureFiltering(Random)}. */
    private static final int DATASET_SIZE = 1000;

    /** Number of features (f-0 .. f-9) in each input produced by {@link #createDatasetFeatureFiltering(Random)}. */
    private static final int FEATURES_PER_INPUT = 10;

    @Test
    void testEmpty() {
        List<Prediction> predictions = new ArrayList<>();
        Dataset dataset = new Dataset(predictions);
        assertDatasetEmpty(dataset);
    }

    @Test
    void testNotEmpty() {
        Dataset dataset = new Dataset(singlePrediction());
        assertDatasetNotEmpty(dataset);
    }

    @Test
    void testInputFilter() {
        List<Prediction> predictions = singlePrediction();

        // The only prediction has exactly one input feature: it survives a "size == 1" filter...
        assertDatasetNotEmpty(new Dataset(predictions).filterByInput(pi -> pi.getFeatures().size() == 1));
        // ...and is dropped by a "size == 2" filter.
        assertDatasetEmpty(new Dataset(predictions).filterByInput(pi -> pi.getFeatures().size() == 2));
    }

    @Test
    void testOutFilter() {
        List<Prediction> predictions = singlePrediction();

        // The only prediction has exactly one output: it survives a "size == 1" filter...
        assertDatasetNotEmpty(new Dataset(predictions).filterByOutput(po -> po.getOutputs().size() == 1));
        // ...and is dropped by a "size == 2" filter.
        assertDatasetEmpty(new Dataset(predictions).filterByOutput(po -> po.getOutputs().size() == 2));
    }

    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2 })
    void testFilterByFeatureName(int seed) {
        Random random = new Random(seed);
        Dataset dataset = createDatasetFeatureFiltering(random);
        Assertions.assertThat(dataset.getData()).hasSize(DATASET_SIZE);
        // Drawn after the dataset is built, so the sampled index is fully determined by the seed.
        int index = random.nextInt(dataset.getData().size());
        Assertions.assertThat(featuresAt(dataset, index)).hasSize(FEATURES_PER_INPUT);

        Predicate<Feature> isF3 = f -> "f-3".equals(f.getName());
        Dataset filtered = dataset.filterByFeature(isF3.negate());

        // Every prediction is kept; only the matching feature is removed from each input.
        Assertions.assertThat(filtered.getData()).hasSize(DATASET_SIZE);
        List<String> names = featuresAt(filtered, index).stream()
                .map(Feature::getName)
                .collect(Collectors.toList());
        Assertions.assertThat(names)
                .hasSize(FEATURES_PER_INPUT - 1)
                .doesNotContain("f-3");
    }

    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2 })
    void testFilterByFeatureType(int seed) {
        Random random = new Random(seed);
        Dataset dataset = createDatasetFeatureFiltering(random);
        Assertions.assertThat(dataset.getData()).hasSize(DATASET_SIZE);
        int index = random.nextInt(dataset.getData().size());
        Assertions.assertThat(featuresAt(dataset, index)).hasSize(FEATURES_PER_INPUT);

        Predicate<Feature> isNumber = f -> f.getType().equals(Type.NUMBER);
        Dataset filtered = dataset.filterByFeature(isNumber.negate());

        // Removing the eight numeric features leaves the boolean f-8 and the text f-9.
        Assertions.assertThat(filtered.getData()).hasSize(DATASET_SIZE);
        List<Type> types = featuresAt(filtered, index).stream()
                .map(Feature::getType)
                .collect(Collectors.toList());
        Assertions.assertThat(types)
                .hasSize(2)
                .doesNotContain(Type.NUMBER);
    }

    @ParameterizedTest
    @ValueSource(ints = { 0, 1, 2 })
    void testFilterByFeatureValue(int seed) {
        Random random = new Random(seed);
        Dataset dataset = createDatasetFeatureFiltering(random);
        Assertions.assertThat(dataset.getData()).hasSize(DATASET_SIZE);
        int index = random.nextInt(dataset.getData().size());
        Assertions.assertThat(featuresAt(dataset, index)).hasSize(FEATURES_PER_INPUT);

        Predicate<Feature> aboveHundred = f -> f.getValue().asNumber() > 100.0;
        Dataset filtered = dataset.filterByFeature(aboveHundred);

        // Only f-4 .. f-7 (generated as 100 + [0, 100)) are expected to pass the filter.
        Assertions.assertThat(filtered.getData()).hasSize(DATASET_SIZE);
        List<Feature> kept = featuresAt(filtered, index);
        Assertions.assertThat(kept).hasSize(4);

        // A typed List<Double> (not a raw List) is required for the numeric comparison to compile.
        List<Double> values = kept.stream()
                .map(Feature::getValue)
                .map(Value::asNumber)
                .collect(Collectors.toList());
        Assertions.assertThat(values).allMatch(v -> v > 100.0);

        List<Type> types = kept.stream()
                .map(Feature::getType)
                .collect(Collectors.toList());
        Assertions.assertThat(types).allMatch(Type.NUMBER::equals);
    }

    /**
     * Builds a dataset of {@link #DATASET_SIZE} predictions for the feature-filtering tests.
     *
     * <p>Each prediction input has {@link #FEATURES_PER_INPUT} features:
     * <ul>
     *   <li>{@code f-0 .. f-3}: numerical, {@code random.nextDouble() * 100.0};</li>
     *   <li>{@code f-4 .. f-7}: numerical, {@code 100.0 + random.nextDouble() * 100.0};</li>
     *   <li>{@code f-8}: boolean, always {@code true};</li>
     *   <li>{@code f-9}: text, a random UUID string (not drawn from {@code random}).</li>
     * </ul>
     * Each prediction output is a single BOOLEAN output named {@code "output"} with value {@code false}.
     *
     * @param random source of the numerical feature values; it is consumed in a fixed order so the
     *               caller's subsequent draws are reproducible for a given seed
     * @return the generated dataset
     */
    private Dataset createDatasetFeatureFiltering(Random random) {
        List<Prediction> predictions = new ArrayList<>(DATASET_SIZE);
        for (int i = 0; i < DATASET_SIZE; ++i) {
            List<Feature> features = new ArrayList<>(FEATURES_PER_INPUT);
            for (int j = 0; j < 4; ++j) {
                features.add(FeatureFactory.newNumericalFeature("f-" + j, random.nextDouble() * 100.0));
            }
            for (int j = 4; j < 8; ++j) {
                features.add(FeatureFactory.newNumericalFeature("f-" + j, 100.0 + random.nextDouble() * 100.0));
            }
            features.add(FeatureFactory.newBooleanFeature("f-8", true));
            features.add(FeatureFactory.newTextFeature("f-9", UUID.randomUUID().toString()));
            PredictionOutput output = new PredictionOutput(
                    List.of(new Output("output", Type.BOOLEAN, new Value(false), 1.0)));
            predictions.add(new SimplePrediction(new PredictionInput(features), output));
        }
        return new Dataset(predictions);
    }

    /**
     * Returns a fresh mutable list holding one prediction: one mocked numeric input feature and
     * one UNDEFINED output named {@code "name"}.
     */
    private static List<Prediction> singlePrediction() {
        List<Prediction> predictions = new ArrayList<>();
        predictions.add(new SimplePrediction(
                new PredictionInput(List.of(TestUtils.getMockedNumericFeature())),
                new PredictionOutput(List.of(new Output("name", Type.UNDEFINED)))));
        return predictions;
    }

    /** Returns the input features of the prediction at {@code index} in {@code dataset}. */
    private static List<Feature> featuresAt(Dataset dataset, int index) {
        return dataset.getData().get(index).getInput().getFeatures();
    }

    /** Asserts that the dataset's data, inputs and outputs are all empty. */
    private static void assertDatasetEmpty(Dataset dataset) {
        Assertions.assertThat(dataset.getData()).isEmpty();
        Assertions.assertThat(dataset.getInputs()).isEmpty();
        Assertions.assertThat(dataset.getOutputs()).isEmpty();
    }

    /** Asserts that the dataset's data, inputs and outputs are all non-empty. */
    private static void assertDatasetNotEmpty(Dataset dataset) {
        Assertions.assertThat(dataset.getData()).isNotEmpty();
        Assertions.assertThat(dataset.getInputs()).isNotEmpty();
        Assertions.assertThat(dataset.getOutputs()).isNotEmpty();
    }
}

