/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Stream;
import javax.enterprise.inject.Instance;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.ThrowingSupplier;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.ExplanationServiceImpl;
import org.kie.kogito.explainability.PredictionProviderFactory;
import org.kie.kogito.explainability.TestUtils;
import org.kie.kogito.explainability.api.BaseExplainabilityResultDto;
import org.kie.kogito.explainability.api.CounterfactualExplainabilityResultDto;
import org.kie.kogito.explainability.api.ExplainabilityStatus;
import org.kie.kogito.explainability.api.FeatureImportanceDto;
import org.kie.kogito.explainability.api.LIMEExplainabilityResultDto;
import org.kie.kogito.explainability.api.SaliencyDto;
import org.kie.kogito.explainability.handlers.CounterfactualExplainerServiceHandler;
import org.kie.kogito.explainability.handlers.LimeExplainerServiceHandler;
import org.kie.kogito.explainability.handlers.LocalExplainerServiceHandlerRegistry;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.models.BaseExplainabilityRequest;
import org.kie.kogito.explainability.models.ModelIdentifier;
import org.kie.kogito.tracing.typedvalue.TypedValue;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

/**
 * Unit tests for {@link ExplanationServiceImpl}.
 *
 * <p>The service under test is wired to a real {@link LocalExplainerServiceHandlerRegistry}; only the
 * CDI {@link Instance} that registry is built from is mocked, so each test decides which handler spies
 * the registry is offered (see {@link #givenHandlers(Object...)}). The LIME and counterfactual
 * explainers are Mockito mocks, and the mocked {@link PredictionProviderFactory} returns
 * {@link #predictionProviderMock} for any arguments.
 */
class ExplanationServiceImplTest {

    /** Message carried by the exceptions injected in the failure-path tests. */
    private static final String ERROR_MESSAGE = "Something bad happened";

    private Instance instance;
    private ExplanationServiceImpl explanationService;
    private LimeExplainer limeExplainerMock;
    private LimeExplainerServiceHandler limeExplainerServiceHandlerSpy;
    private CounterfactualExplainer cfExplainerMock;
    private CounterfactualExplainerServiceHandler cfExplainerServiceHandlerSpy;
    private PredictionProvider predictionProviderMock;
    private Consumer<BaseExplainabilityResultDto> callbackMock;

    @BeforeEach
    void init() {
        instance = Mockito.mock(Instance.class);
        limeExplainerMock = Mockito.mock(LimeExplainer.class);
        cfExplainerMock = Mockito.mock(CounterfactualExplainer.class);
        PredictionProviderFactory predictionProviderFactory = Mockito.mock(PredictionProviderFactory.class);
        LocalExplainerServiceHandlerRegistry explainerServiceHandlerRegistry = new LocalExplainerServiceHandlerRegistry(instance);
        // Spies wrap real handlers: they run their actual logic, yet single methods can still be stubbed.
        limeExplainerServiceHandlerSpy = Mockito.spy(new LimeExplainerServiceHandler(limeExplainerMock, predictionProviderFactory));
        cfExplainerServiceHandlerSpy = Mockito.spy(new CounterfactualExplainerServiceHandler(cfExplainerMock, predictionProviderFactory));
        predictionProviderMock = Mockito.mock(PredictionProvider.class);
        callbackMock = Mockito.mock(Consumer.class);
        explanationService = new ExplanationServiceImpl(explainerServiceHandlerRegistry);
        Mockito.when(predictionProviderFactory.createPredictionProvider(
                        (String) ArgumentMatchers.any(),
                        (ModelIdentifier) ArgumentMatchers.any(),
                        (Map) ArgumentMatchers.any()))
                .thenReturn(predictionProviderMock);
    }

    @Test
    void testLIMEExplainAsyncSucceeded() {
        assertLimeSuccess(explainWithCallback(TestUtils.LIME_REQUEST));
    }

    @Test
    void testLIMEExplainAsyncSucceededWithoutCallback() {
        assertLimeSuccess(explainWithoutCallback(TestUtils.LIME_REQUEST));
    }

    /**
     * Stubs a successful LIME explanation, runs {@code invocation} and checks that the saliency map
     * returned by the explainer is converted into a SUCCEEDED {@link LIMEExplainabilityResultDto}.
     *
     * @param invocation deferred call to the service that yields its result
     */
    private void assertLimeSuccess(ThrowingSupplier<BaseExplainabilityResultDto> invocation) {
        givenHandlers(limeExplainerServiceHandlerSpy);
        Mockito.when(limeExplainerMock.explainAsync(
                        ArgumentMatchers.any(Prediction.class),
                        ArgumentMatchers.eq(predictionProviderMock),
                        ArgumentMatchers.any(Consumer.class)))
                .thenReturn(CompletableFuture.completedFuture(TestUtils.SALIENCY_MAP));

        LIMEExplainabilityResultDto limeResultDto =
                assertResultOfType(Assertions.assertDoesNotThrow(invocation), LIMEExplainabilityResultDto.class);

        Assertions.assertEquals(TestUtils.EXECUTION_ID, limeResultDto.getExecutionId());
        Assertions.assertSame(ExplainabilityStatus.SUCCEEDED, limeResultDto.getStatus());
        Assertions.assertNull(limeResultDto.getStatusDetails());
        Assertions.assertEquals(TestUtils.SALIENCY_MAP.size(), limeResultDto.getSaliencies().size());
        Assertions.assertTrue(limeResultDto.getSaliencies().containsKey("key"));

        SaliencyDto saliencyDto = limeResultDto.getSaliencies().get("key");
        Assertions.assertEquals(TestUtils.SALIENCY.getPerFeatureImportance().size(), saliencyDto.getFeatureImportance().size());

        FeatureImportanceDto firstFeature = saliencyDto.getFeatureImportance().get(0);
        Assertions.assertEquals(TestUtils.FEATURE_IMPORTANCE_1.getFeature().getName(), firstFeature.getFeatureName());
        Assertions.assertEquals(TestUtils.FEATURE_IMPORTANCE_1.getScore(), firstFeature.getScore(), 0.01);
    }

    @Test
    void testCounterfactualsExplainAsyncSucceeded() {
        assertCounterfactualSuccess(explainWithCallback(TestUtils.COUNTERFACTUAL_REQUEST));
    }

    @Test
    void testCounterfactualsExplainAsyncSucceededWithoutCallback() {
        assertCounterfactualSuccess(explainWithoutCallback(TestUtils.COUNTERFACTUAL_REQUEST));
    }

    /**
     * Stubs a successful counterfactual search, runs {@code invocation} and checks that the
     * explainer's result is converted into a SUCCEEDED {@link CounterfactualExplainabilityResultDto}.
     *
     * @param invocation deferred call to the service that yields its result
     */
    private void assertCounterfactualSuccess(ThrowingSupplier<BaseExplainabilityResultDto> invocation) {
        givenHandlers(cfExplainerServiceHandlerSpy);
        Mockito.when(cfExplainerMock.explainAsync(
                        ArgumentMatchers.any(Prediction.class),
                        ArgumentMatchers.eq(predictionProviderMock),
                        ArgumentMatchers.any(Consumer.class)))
                .thenReturn(CompletableFuture.completedFuture(TestUtils.COUNTERFACTUAL_RESULT));

        CounterfactualExplainabilityResultDto counterfactualResultDto =
                assertResultOfType(Assertions.assertDoesNotThrow(invocation), CounterfactualExplainabilityResultDto.class);

        Assertions.assertEquals(TestUtils.EXECUTION_ID, counterfactualResultDto.getExecutionId());
        Assertions.assertEquals(TestUtils.COUNTERFACTUAL_ID, counterfactualResultDto.getCounterfactualId());
        Assertions.assertSame(ExplainabilityStatus.SUCCEEDED, counterfactualResultDto.getStatus());
        Assertions.assertNull(counterfactualResultDto.getStatusDetails());
        Assertions.assertEquals(TestUtils.COUNTERFACTUAL_RESULT.getEntities().size(), counterfactualResultDto.getInputs().size());
        Assertions.assertEquals(TestUtils.COUNTERFACTUAL_RESULT.getOutput().size(), counterfactualResultDto.getOutputs().size());
        Assertions.assertTrue(counterfactualResultDto.getOutputs().containsKey("output1"));

        TypedValue value = counterfactualResultDto.getOutputs().get("output1");
        Assertions.assertTrue(value.isUnit());
        Assertions.assertEquals(Double.class.getSimpleName(), value.toUnit().getType());
        Assertions.assertEquals(555.0, value.toUnit().getValue().asDouble());
    }

    /** A handler that throws from {@code supports} makes {@code explainAsync} throw a RuntimeException. */
    @Test
    void testServiceCallFailed() {
        RuntimeException exception = new RuntimeException(ERROR_MESSAGE);
        givenHandlers(limeExplainerServiceHandlerSpy);
        Mockito.doThrow(exception).when(limeExplainerServiceHandlerSpy).supports(ArgumentMatchers.any());

        Assertions.assertThrows(RuntimeException.class, () -> explainWithCallback(TestUtils.LIME_REQUEST).get());
    }

    /** With no registered handlers the request is rejected with an IllegalArgumentException. */
    @Test
    void testServiceCallFailedNoMatchingServiceHandlers() {
        givenHandlers();

        Assertions.assertThrows(IllegalArgumentException.class, () -> explainWithCallback(TestUtils.LIME_REQUEST).get());
    }

    /** An exception thrown by the LIME explainer is reported as a FAILED result, not propagated. */
    @Test
    void testLIMEExplainAsyncFailed() {
        RuntimeException exception = new RuntimeException(ERROR_MESSAGE);
        givenHandlers(limeExplainerServiceHandlerSpy);
        Mockito.when(limeExplainerMock.explainAsync(
                        ArgumentMatchers.any(Prediction.class),
                        ArgumentMatchers.eq(predictionProviderMock),
                        ArgumentMatchers.any(Consumer.class)))
                .thenThrow(exception);

        LIMEExplainabilityResultDto failedResultDto = assertResultOfType(
                Assertions.assertDoesNotThrow(explainWithCallback(TestUtils.LIME_REQUEST)),
                LIMEExplainabilityResultDto.class);

        Assertions.assertEquals(TestUtils.EXECUTION_ID, failedResultDto.getExecutionId());
        Assertions.assertSame(ExplainabilityStatus.FAILED, failedResultDto.getStatus());
        Assertions.assertEquals(ERROR_MESSAGE, failedResultDto.getStatusDetails());
    }

    /** An exception thrown by the counterfactual explainer is reported as a FAILED result, not propagated. */
    @Test
    void testCounterfactualsExplainAsyncFailed() {
        RuntimeException exception = new RuntimeException(ERROR_MESSAGE);
        givenHandlers(cfExplainerServiceHandlerSpy);
        Mockito.when(cfExplainerMock.explainAsync(
                        ArgumentMatchers.any(Prediction.class),
                        ArgumentMatchers.eq(predictionProviderMock),
                        ArgumentMatchers.any(Consumer.class)))
                .thenThrow(exception);

        CounterfactualExplainabilityResultDto failedResultDto = assertResultOfType(
                Assertions.assertDoesNotThrow(explainWithCallback(TestUtils.COUNTERFACTUAL_REQUEST)),
                CounterfactualExplainabilityResultDto.class);

        Assertions.assertEquals(TestUtils.EXECUTION_ID, failedResultDto.getExecutionId());
        Assertions.assertSame(ExplainabilityStatus.FAILED, failedResultDto.getStatus());
        Assertions.assertEquals(ERROR_MESSAGE, failedResultDto.getStatusDetails());
    }

    /** With both handlers registered, a LIME request is routed to the LIME handler. */
    @Test
    void testServiceHandlerLookupLIME() {
        givenHandlers(limeExplainerServiceHandlerSpy, cfExplainerServiceHandlerSpy);
        // any() (rather than any(Class)) is kept deliberately: it also matches null arguments.
        Mockito.when(limeExplainerMock.explainAsync(
                        ArgumentMatchers.any(),
                        ArgumentMatchers.any(),
                        (Consumer) ArgumentMatchers.any()))
                .thenReturn(CompletableFuture.completedFuture(TestUtils.SALIENCY_MAP));

        assertResultOfType(
                Assertions.assertDoesNotThrow(explainWithCallback(TestUtils.LIME_REQUEST)),
                LIMEExplainabilityResultDto.class);
    }

    /** With both handlers registered, a counterfactual request is routed to the counterfactual handler. */
    @Test
    void testServiceHandlerLookupCounterfactuals() {
        givenHandlers(limeExplainerServiceHandlerSpy, cfExplainerServiceHandlerSpy);
        // any() (rather than any(Class)) is kept deliberately: it also matches null arguments.
        Mockito.when(cfExplainerMock.explainAsync(
                        ArgumentMatchers.any(),
                        ArgumentMatchers.any(),
                        (Consumer) ArgumentMatchers.any()))
                .thenReturn(CompletableFuture.completedFuture(TestUtils.COUNTERFACTUAL_RESULT));

        assertResultOfType(
                Assertions.assertDoesNotThrow(explainWithCallback(TestUtils.COUNTERFACTUAL_REQUEST)),
                CounterfactualExplainabilityResultDto.class);
    }

    /**
     * Stubs the mocked CDI {@link Instance} so the registry sees exactly {@code handlers}.
     * Calling it with no arguments registers no handler at all.
     *
     * @param handlers the handler instances to expose, in lookup order
     */
    private void givenHandlers(Object... handlers) {
        Mockito.when(instance.stream()).thenReturn(Stream.of(handlers));
    }

    /**
     * Returns a deferred call to {@code explainAsync(request, callbackMock)} that blocks for the result.
     *
     * @param request the explainability request to submit
     * @return a supplier that performs the call only when invoked
     */
    private ThrowingSupplier<BaseExplainabilityResultDto> explainWithCallback(BaseExplainabilityRequest request) {
        return () -> awaitResult(explanationService.explainAsync(request, callbackMock).toCompletableFuture());
    }

    /**
     * Returns a deferred call to the callback-less {@code explainAsync(request)} that blocks for the result.
     *
     * @param request the explainability request to submit
     * @return a supplier that performs the call only when invoked
     */
    private ThrowingSupplier<BaseExplainabilityResultDto> explainWithoutCallback(BaseExplainabilityRequest request) {
        return () -> awaitResult(explanationService.explainAsync(request).toCompletableFuture());
    }

    /**
     * Waits for {@code future} using the timeout configured in {@link Config#INSTANCE}.
     *
     * @param future the pending explanation
     * @return the completed result, cast to {@link BaseExplainabilityResultDto}
     * @throws Exception whatever {@link CompletableFuture#get(long, java.util.concurrent.TimeUnit)} throws
     *                   (interruption, execution failure or timeout)
     */
    private static BaseExplainabilityResultDto awaitResult(CompletableFuture<?> future) throws Exception {
        return (BaseExplainabilityResultDto) future.get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }

    /**
     * Asserts {@code resultDto} is non-null and an instance of {@code type}, then returns it cast.
     *
     * @param resultDto the value returned by the service
     * @param type      the expected concrete result type
     * @param <T>       the expected concrete result type
     * @return {@code resultDto} cast to {@code type}
     */
    private static <T> T assertResultOfType(Object resultDto, Class<T> type) {
        Assertions.assertNotNull(resultDto);
        Assertions.assertTrue(type.isInstance(resultDto));
        return type.cast(resultDto);
    }
}

