/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.global.pdp;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.global.pdp.PartialDependencePlotExplainer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PartialDependenceGraph;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.PredictionProviderMetadata;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.DataUtils;

class PartialDependencePlotExplainerTest {
    PartialDependencePlotExplainerTest() {
    }

    private PredictionProviderMetadata getMetadata(final Random random) {
        return new PredictionProviderMetadata(){

            public DataDistribution getDataDistribution() {
                return DataUtils.generateRandomDataDistribution((int)3, (int)100, (Random)random);
            }

            public PredictionInput getInputShape() {
                LinkedList<Feature> features = new LinkedList<Feature>();
                features.add(FeatureFactory.newNumericalFeature((String)"f0", (Number)0));
                features.add(FeatureFactory.newNumericalFeature((String)"f1", (Number)0));
                features.add(FeatureFactory.newNumericalFeature((String)"f2", (Number)0));
                return new PredictionInput(features);
            }

            public PredictionOutput getOutputShape() {
                LinkedList<Output> outputs = new LinkedList<Output>();
                outputs.add(new Output("sum-but0", Type.BOOLEAN, new Value((Object)false), 0.0));
                return new PredictionOutput(outputs);
            }
        };
    }

    @Test
    void testPdpNumericClassifier() throws Exception {
        Random random = new Random();
        for (int seed = 0; seed < 5; ++seed) {
            random.setSeed(seed);
            PredictionProvider modelInfo = TestUtils.getSumSkipModel(0);
            PartialDependencePlotExplainer partialDependencePlotProvider = new PartialDependencePlotExplainer();
            List pdps = partialDependencePlotProvider.explainFromMetadata(modelInfo, this.getMetadata(random));
            org.junit.jupiter.api.Assertions.assertNotNull((Object)pdps);
            for (PartialDependenceGraph pdp : pdps) {
                org.junit.jupiter.api.Assertions.assertNotNull((Object)pdp.getFeature());
                org.junit.jupiter.api.Assertions.assertNotNull((Object)pdp.getX());
                org.junit.jupiter.api.Assertions.assertNotNull((Object)pdp.getY());
                org.junit.jupiter.api.Assertions.assertEquals((int)pdp.getX().length, (int)pdp.getY().length);
                this.assertGraph(pdp);
            }
            PartialDependenceGraph fixedFeatureGraph = (PartialDependenceGraph)pdps.get(0);
            org.junit.jupiter.api.Assertions.assertEquals((long)1L, (long)Arrays.stream(fixedFeatureGraph.getY()).distinct().count());
            Assertions.assertThat((long)Arrays.stream(((PartialDependenceGraph)pdps.get(1)).getY()).distinct().count()).isGreaterThan(1L);
            Assertions.assertThat((long)Arrays.stream(((PartialDependenceGraph)pdps.get(2)).getY()).distinct().count()).isGreaterThan(1L);
        }
    }

    private void assertGraph(PartialDependenceGraph pdp) {
        for (int i = 0; i < pdp.getX().length; ++i) {
            org.junit.jupiter.api.Assertions.assertNotEquals((double)Double.NaN, (double)pdp.getY()[i]);
            if (i <= 0) continue;
            org.junit.jupiter.api.Assertions.assertTrue((pdp.getX()[i] >= pdp.getX()[i - 1] ? 1 : 0) != 0);
        }
    }

    @Test
    void testBrokenPredict() {
        Config.INSTANCE.setAsyncTimeout(1L);
        Config.INSTANCE.setAsyncTimeUnit(TimeUnit.MILLISECONDS);
        Random random = new Random();
        for (int seed = 0; seed < 5; ++seed) {
            random.setSeed(seed);
            PartialDependencePlotExplainer partialDependencePlotProvider = new PartialDependencePlotExplainer();
            PredictionProvider brokenProvider = inputs -> CompletableFuture.supplyAsync(() -> {
                try {
                    Thread.sleep(1000L);
                    return Collections.emptyList();
                }
                catch (InterruptedException e) {
                    throw new RuntimeException("this is a test");
                }
            });
            org.junit.jupiter.api.Assertions.assertThrows(TimeoutException.class, () -> partialDependencePlotProvider.explainFromMetadata(brokenProvider, this.getMetadata(random)));
        }
        Config.INSTANCE.setAsyncTimeout(5L);
        Config.INSTANCE.setAsyncTimeUnit(Config.DEFAULT_ASYNC_TIMEUNIT);
    }
}

