/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.handlers;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.enterprise.context.ApplicationScoped;
import javax.inject.Inject;
import org.kie.kogito.explainability.ConversionUtils;
import org.kie.kogito.explainability.PredictionProviderFactory;
import org.kie.kogito.explainability.api.BaseExplainabilityRequestDto;
import org.kie.kogito.explainability.api.BaseExplainabilityResultDto;
import org.kie.kogito.explainability.api.CounterfactualExplainabilityRequestDto;
import org.kie.kogito.explainability.api.CounterfactualExplainabilityResultDto;
import org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionDto;
import org.kie.kogito.explainability.api.CounterfactualSearchDomainDto;
import org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureDto;
import org.kie.kogito.explainability.handlers.LocalExplainerServiceHandler;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualResult;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.model.CounterfactualPrediction;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionFeatureDomain;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.models.BaseExplainabilityRequest;
import org.kie.kogito.explainability.models.CounterfactualExplainabilityRequest;
import org.kie.kogito.explainability.models.ModelIdentifier;
import org.kie.kogito.tracing.typedvalue.CollectionValue;
import org.kie.kogito.tracing.typedvalue.StructureValue;
import org.kie.kogito.tracing.typedvalue.TypedValue;

/**
 * {@link LocalExplainerServiceHandler} for counterfactual explanation requests.
 * <p>
 * Converts {@link CounterfactualExplainabilityRequestDto}s into {@link CounterfactualExplainabilityRequest}s,
 * builds the {@link CounterfactualPrediction} passed to the {@link CounterfactualExplainer}, and maps
 * {@link CounterfactualResult}s back into {@link CounterfactualExplainabilityResultDto}s.
 * Only flat models are supported: structured or collection inputs, goals and search domains are rejected.
 */
@ApplicationScoped
public class CounterfactualExplainerServiceHandler
        implements LocalExplainerServiceHandler<CounterfactualResult, CounterfactualExplainabilityRequest, CounterfactualExplainabilityRequestDto> {

    private final CounterfactualExplainer explainer;
    private final PredictionProviderFactory predictionProviderFactory;

    /**
     * @param explainer                 explainer used by {@link #explainAsync}
     * @param predictionProviderFactory factory used to build the {@link PredictionProvider} for a request
     */
    @Inject
    public CounterfactualExplainerServiceHandler(CounterfactualExplainer explainer, PredictionProviderFactory predictionProviderFactory) {
        this.explainer = explainer;
        this.predictionProviderFactory = predictionProviderFactory;
    }

    /**
     * @return {@code true} when {@code type} is {@link CounterfactualExplainabilityRequest} or a subtype of it
     */
    @Override
    public <T extends BaseExplainabilityRequest> boolean supports(Class<T> type) {
        return CounterfactualExplainabilityRequest.class.isAssignableFrom(type);
    }

    /**
     * @return {@code true} when {@code type} is {@link CounterfactualExplainabilityRequestDto} or a subtype of it
     */
    @Override
    public <T extends BaseExplainabilityRequestDto> boolean supportsDto(Class<T> type) {
        return CounterfactualExplainabilityRequestDto.class.isAssignableFrom(type);
    }

    /**
     * Copies the fields of the DTO into a {@link CounterfactualExplainabilityRequest}, converting the model
     * identifier with {@link ModelIdentifier#from}.
     */
    @Override
    public CounterfactualExplainabilityRequest explainabilityRequestFrom(CounterfactualExplainabilityRequestDto dto) {
        return new CounterfactualExplainabilityRequest(dto.getExecutionId(),
                dto.getCounterfactualId(),
                dto.getServiceUrl(),
                ModelIdentifier.from(dto.getModelIdentifier()),
                dto.getOriginalInputs(),
                dto.getGoals(),
                dto.getSearchDomains());
    }

    /**
     * Creates a prediction provider for the request's service URL, model identifier and goals.
     */
    @Override
    public PredictionProvider getPredictionProvider(CounterfactualExplainabilityRequest request) {
        return predictionProviderFactory.createPredictionProvider(request.getServiceUrl(), request.getModelIdentifier(), request.getGoals());
    }

    /**
     * Builds the {@link CounterfactualPrediction} to be explained from the request's original inputs, goals
     * and search domains.
     *
     * @throws IllegalArgumentException if any input, goal or search domain is a structure or collection
     *                                  (only flat models are supported), or if the execution id is not a valid UUID
     */
    @Override
    public Prediction getPrediction(CounterfactualExplainabilityRequest request) {
        Map<String, TypedValue> originalInputs = request.getOriginalInputs();
        Map<String, TypedValue> goals = request.getGoals();
        Map<String, CounterfactualSearchDomainDto> searchDomains = request.getSearchDomains();

        if (isUnsupportedModel(originalInputs, goals, searchDomains)) {
            throw new IllegalArgumentException("Counterfactual explanations only support flat models.");
        }

        PredictionInput input = new PredictionInput(ConversionUtils.toFeatureList(originalInputs));
        PredictionOutput output = new PredictionOutput(ConversionUtils.toOutputList(goals));
        PredictionFeatureDomain featureDomain = new PredictionFeatureDomain(ConversionUtils.toFeatureDomainList(searchDomains));
        List<Boolean> featureConstraints = ConversionUtils.toFeatureConstraintList(searchDomains);

        return new CounterfactualPrediction(input, output, featureDomain, featureConstraints, null, UUID.fromString(request.getExecutionId()));
    }

    /** Returns {@code true} if any input, required output or search domain is non-flat. */
    private boolean isUnsupportedModel(Map<String, TypedValue> originalInputs,
            Map<String, TypedValue> requiredOutputs,
            Map<String, CounterfactualSearchDomainDto> searchDomains) {
        return isUnsupportedTypedValue(originalInputs.values())
                || isUnsupportedTypedValue(requiredOutputs.values())
                || isUnsupportedCounterfactualSearchDomain(searchDomains.values());
    }

    /** Returns {@code true} if any value is a {@link StructureValue} or {@link CollectionValue}. */
    private boolean isUnsupportedTypedValue(Collection<TypedValue> typedValues) {
        return typedValues.stream().anyMatch(tv -> tv instanceof StructureValue || tv instanceof CollectionValue);
    }

    /** Returns {@code true} if any domain is a structure or collection search domain. */
    private boolean isUnsupportedCounterfactualSearchDomain(Collection<CounterfactualSearchDomainDto> domains) {
        return domains.stream()
                .anyMatch(domain -> domain instanceof CounterfactualSearchDomainStructureDto || domain instanceof CounterfactualSearchDomainCollectionDto);
    }

    /** Builds a {@link CounterfactualExplainabilityResultDto.Stage#FINAL} result DTO from {@code result}. */
    @Override
    public BaseExplainabilityResultDto createSucceededResultDto(CounterfactualExplainabilityRequest request, CounterfactualResult result) {
        return buildResultDtoFromExplanation(request, result, CounterfactualExplainabilityResultDto.Stage.FINAL);
    }

    /** Builds a failed result DTO carrying the message of {@code throwable}. */
    @Override
    public BaseExplainabilityResultDto createFailedResultDto(CounterfactualExplainabilityRequest request, Throwable throwable) {
        return CounterfactualExplainabilityResultDto.buildFailed(request.getExecutionId(), request.getCounterfactualId(), throwable.getMessage());
    }

    /** Builds a {@link CounterfactualExplainabilityResultDto.Stage#INTERMEDIATE} result DTO from {@code result}. */
    @Override
    public BaseExplainabilityResultDto createIntermediateResultDto(CounterfactualExplainabilityRequest request, CounterfactualResult result) {
        return buildResultDtoFromExplanation(request, result, CounterfactualExplainabilityResultDto.Stage.INTERMEDIATE);
    }

    /**
     * Converts a counterfactual result into a succeeded result DTO for the given stage.
     *
     * @throws NullPointerException  if the result carries no output list
     * @throws IllegalStateException if the result carries zero or more than one output set
     */
    private CounterfactualExplainabilityResultDto buildResultDtoFromExplanation(CounterfactualExplainabilityRequest request,
            CounterfactualResult result,
            CounterfactualExplainabilityResultDto.Stage stage) {
        List<Feature> features = result.getEntities()
                .stream()
                .map(CounterfactualEntity::asFeature)
                .collect(Collectors.toList());

        // Message is only formatted when the outputs are actually missing.
        List<PredictionOutput> predictionOutputs = Objects.requireNonNull(result.getOutput(),
                () -> String.format("Null Outputs produced for Explanation with ExecutionId '%s' and CounterfactualId '%s'",
                        request.getExecutionId(), request.getCounterfactualId()));
        if (predictionOutputs.isEmpty()) {
            throw new IllegalStateException(String.format("No Outputs produced for Explanation with ExecutionId '%s' and CounterfactualId '%s'",
                    request.getExecutionId(), request.getCounterfactualId()));
        }
        if (predictionOutputs.size() > 1) {
            throw new IllegalStateException(String.format("Multiple Output sets produced for Explanation with ExecutionId '%s' and CounterfactualId '%s'",
                    request.getExecutionId(), request.getCounterfactualId()));
        }

        PredictionOutput predictionOutput = predictionOutputs.get(0);
        return CounterfactualExplainabilityResultDto.buildSucceeded(request.getExecutionId(),
                request.getCounterfactualId(),
                result.getSolutionId().toString(),
                result.isValid(),
                stage,
                ConversionUtils.fromFeatureList(features),
                ConversionUtils.fromOutputs(predictionOutput.getOutputs()));
    }

    /**
     * Delegates to {@link CounterfactualExplainer#explainAsync}, forwarding intermediate results to
     * {@code intermediateResultsConsumer}.
     */
    public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction,
            PredictionProvider predictionProvider,
            Consumer<CounterfactualResult> intermediateResultsConsumer) {
        return explainer.explainAsync(prediction, predictionProvider, intermediateResultsConsumer);
    }
}

