/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.global.pdp;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.FakeRandom;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.global.pdp.PartialDependencePlotExplainer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PartialDependenceGraph;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.PredictionProviderMetadata;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.DataUtils;

class PartialDependencePlotExplainerTest {
    PartialDependencePlotExplainer partialDependencePlotProvider = new PartialDependencePlotExplainer();
    PredictionProviderMetadata metadata = new PredictionProviderMetadata(){

        public DataDistribution getDataDistribution() {
            return DataUtils.generateRandomDataDistribution((int)3, (int)100, (Random)new FakeRandom());
        }

        public PredictionInput getInputShape() {
            LinkedList<Feature> features = new LinkedList<Feature>();
            features.add(FeatureFactory.newNumericalFeature((String)"f0", (Number)0));
            features.add(FeatureFactory.newNumericalFeature((String)"f1", (Number)0));
            features.add(FeatureFactory.newNumericalFeature((String)"f2", (Number)0));
            return new PredictionInput(features);
        }

        public PredictionOutput getOutputShape() {
            LinkedList<Output> outputs = new LinkedList<Output>();
            outputs.add(new Output("spam", Type.BOOLEAN, new Value((Object)false), 0.0));
            return new PredictionOutput(outputs);
        }
    };

    PartialDependencePlotExplainerTest() {
    }

    @Test
    void testPdpNumericClassifier() throws Exception {
        PredictionProvider modelInfo = TestUtils.getSumSkipModel(0);
        List pdps = this.partialDependencePlotProvider.explain(modelInfo, this.metadata);
        Assertions.assertNotNull((Object)pdps);
        for (PartialDependenceGraph pdp : pdps) {
            Assertions.assertNotNull((Object)pdp.getFeature());
            Assertions.assertNotNull((Object)pdp.getX());
            Assertions.assertNotNull((Object)pdp.getY());
            Assertions.assertEquals((int)pdp.getX().length, (int)pdp.getY().length);
            this.assertGraph(pdp);
        }
        PartialDependenceGraph fixedFeatureGraph = (PartialDependenceGraph)pdps.get(0);
        Assertions.assertEquals((long)1L, (long)Arrays.stream(fixedFeatureGraph.getY()).distinct().count());
        Assertions.assertArrayEquals((double[])((PartialDependenceGraph)pdps.get(1)).getY(), (double[])((PartialDependenceGraph)pdps.get(2)).getY());
    }

    private void assertGraph(PartialDependenceGraph pdp) {
        for (int i = 0; i < pdp.getX().length; ++i) {
            Assertions.assertNotEquals((double)Double.NaN, (double)pdp.getY()[i]);
            if (i <= 0) continue;
            Assertions.assertTrue((pdp.getX()[i] > pdp.getX()[i - 1] ? 1 : 0) != 0);
        }
    }

    @Test
    void testBrokenPredict() {
        Config.INSTANCE.setAsyncTimeout(1L);
        Config.INSTANCE.setAsyncTimeUnit(TimeUnit.MILLISECONDS);
        PredictionProvider brokenProvider = inputs -> CompletableFuture.supplyAsync(() -> {
            try {
                Thread.sleep(1000L);
                return Collections.emptyList();
            }
            catch (InterruptedException e) {
                throw new RuntimeException("this is a test");
            }
        });
        Assertions.assertThrows(TimeoutException.class, () -> this.partialDependencePlotProvider.explain(brokenProvider, this.metadata));
        Config.INSTANCE.setAsyncTimeout(5L);
        Config.INSTANCE.setAsyncTimeUnit(Config.DEFAULT_ASYNC_TIMEUNIT);
    }
}

