/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.handlers;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.enterprise.context.ApplicationScoped;
import javax.inject.Inject;
import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.kie.kogito.explainability.ConversionUtils;
import org.kie.kogito.explainability.PredictionProviderFactory;
import org.kie.kogito.explainability.api.BaseExplainabilityRequest;
import org.kie.kogito.explainability.api.BaseExplainabilityResult;
import org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest;
import org.kie.kogito.explainability.api.CounterfactualExplainabilityResult;
import org.kie.kogito.explainability.api.CounterfactualSearchDomain;
import org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionValue;
import org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue;
import org.kie.kogito.explainability.api.HasNameValue;
import org.kie.kogito.explainability.api.NamedTypedValue;
import org.kie.kogito.explainability.handlers.LocalExplainerServiceHandler;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualResult;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.model.CounterfactualPrediction;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.tracing.typedvalue.CollectionValue;
import org.kie.kogito.tracing.typedvalue.StructureValue;
import org.kie.kogito.tracing.typedvalue.TypedValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@ApplicationScoped
public class CounterfactualExplainerServiceHandler
implements LocalExplainerServiceHandler<CounterfactualResult, CounterfactualExplainabilityRequest> {
    private static final Logger LOGGER = LoggerFactory.getLogger(CounterfactualExplainerServiceHandler.class);
    private final Long kafkaMaxRecordAgeSeconds;
    private final CounterfactualExplainer explainer;
    private final PredictionProviderFactory predictionProviderFactory;

    @Inject
    public CounterfactualExplainerServiceHandler(CounterfactualExplainer explainer, PredictionProviderFactory predictionProviderFactory, @ConfigProperty(name="mp.messaging.incoming.trusty-explainability-request.throttled.unprocessed-record-max-age.ms", defaultValue="60000") Long kafkaMaxRecordAgeMilliSeconds) {
        this.explainer = explainer;
        this.predictionProviderFactory = predictionProviderFactory;
        this.kafkaMaxRecordAgeSeconds = Math.floorDiv((long)kafkaMaxRecordAgeMilliSeconds, 1000);
    }

    @Override
    public <T extends BaseExplainabilityRequest> boolean supports(Class<T> type) {
        return CounterfactualExplainabilityRequest.class.isAssignableFrom(type);
    }

    @Override
    public PredictionProvider getPredictionProvider(CounterfactualExplainabilityRequest request) {
        return this.predictionProviderFactory.createPredictionProvider(request.getServiceUrl(), request.getModelIdentifier(), request.getGoals());
    }

    @Override
    public Prediction getPrediction(CounterfactualExplainabilityRequest request) {
        List<NamedTypedValue> goals = this.toMapBasedSorting(request.getGoals());
        Collection searchDomains = request.getSearchDomains();
        Collection originalInputs = request.getOriginalInputs();
        Long maxRunningTimeSeconds = request.getMaxRunningTimeSeconds();
        if (Objects.nonNull(maxRunningTimeSeconds) && maxRunningTimeSeconds > this.kafkaMaxRecordAgeSeconds) {
            LOGGER.info(String.format("Maximum Running Timeout set to '%d's since the provided value '%d's exceeded the Messaging sub-system configuration '%d's.", this.kafkaMaxRecordAgeSeconds, maxRunningTimeSeconds, this.kafkaMaxRecordAgeSeconds));
            maxRunningTimeSeconds = this.kafkaMaxRecordAgeSeconds;
        }
        if (this.isUnsupportedModel(originalInputs, goals, searchDomains)) {
            throw new IllegalArgumentException("Counterfactual explanations only support flat models.");
        }
        PredictionInput input = new PredictionInput(ConversionUtils.toFeatureList(originalInputs, searchDomains));
        PredictionOutput output = new PredictionOutput(ConversionUtils.toOutputList(goals));
        return new CounterfactualPrediction(input, output, null, UUID.fromString(request.getExecutionId()), maxRunningTimeSeconds);
    }

    private boolean isUnsupportedModel(Collection<NamedTypedValue> originalInputs, Collection<NamedTypedValue> goals, Collection<CounterfactualSearchDomain> searchDomains) {
        return this.isUnsupportedTypedValue(originalInputs) || this.isUnsupportedTypedValue(goals) || this.isUnsupportedCounterfactualSearchDomain(searchDomains);
    }

    private boolean isUnsupportedTypedValue(Collection<? extends HasNameValue<?>> values) {
        return values.stream().map(HasNameValue::getValue).anyMatch(tv -> tv instanceof StructureValue || tv instanceof CollectionValue);
    }

    private boolean isUnsupportedCounterfactualSearchDomain(Collection<CounterfactualSearchDomain> domains) {
        return domains.stream().map(CounterfactualSearchDomain::getValue).anyMatch(domain -> domain instanceof CounterfactualSearchDomainStructureValue || domain instanceof CounterfactualSearchDomainCollectionValue);
    }

    private List<NamedTypedValue> toMapBasedSorting(Collection<NamedTypedValue> goals) {
        Map goalsMap = goals != null ? (Map)goals.stream().collect(HashMap::new, (m, v) -> m.put(v.getName(), v.getValue()), HashMap::putAll) : Collections.emptyMap();
        return goalsMap.entrySet().stream().map(e -> new NamedTypedValue((String)e.getKey(), (TypedValue)e.getValue())).collect(Collectors.toList());
    }

    @Override
    public BaseExplainabilityResult createSucceededResult(CounterfactualExplainabilityRequest request, CounterfactualResult result) {
        return this.buildResultFromExplanation(request, result, CounterfactualExplainabilityResult.Stage.FINAL);
    }

    @Override
    public BaseExplainabilityResult createFailedResult(CounterfactualExplainabilityRequest request, Throwable throwable) {
        return CounterfactualExplainabilityResult.buildFailed((String)request.getExecutionId(), (String)request.getCounterfactualId(), (String)throwable.getMessage());
    }

    @Override
    public BaseExplainabilityResult createIntermediateResult(CounterfactualExplainabilityRequest request, CounterfactualResult result) {
        return this.buildResultFromExplanation(request, result, CounterfactualExplainabilityResult.Stage.INTERMEDIATE);
    }

    private CounterfactualExplainabilityResult buildResultFromExplanation(CounterfactualExplainabilityRequest request, CounterfactualResult result, CounterfactualExplainabilityResult.Stage stage) {
        List<Feature> features = result.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
        List predictionOutputs = result.getOutput();
        if (Objects.isNull(predictionOutputs)) {
            throw new NullPointerException(String.format("Null Outputs produced for Explanation with ExecutionId '%s' and CounterfactualId '%s'", request.getExecutionId(), request.getCounterfactualId()));
        }
        if (predictionOutputs.isEmpty()) {
            throw new IllegalStateException(String.format("No Outputs produced for Explanation with ExecutionId '%s' and CounterfactualId '%s'", request.getExecutionId(), request.getCounterfactualId()));
        }
        if (predictionOutputs.size() > 1) {
            throw new IllegalStateException(String.format("Multiple Output sets produced for Explanation with ExecutionId '%s' and CounterfactualId '%s'", request.getExecutionId(), request.getCounterfactualId()));
        }
        List outputs = ((PredictionOutput)predictionOutputs.get(0)).getOutputs();
        return CounterfactualExplainabilityResult.buildSucceeded((String)request.getExecutionId(), (String)request.getCounterfactualId(), (String)result.getSolutionId().toString(), (Long)result.getSequenceId(), (Boolean)result.isValid(), (CounterfactualExplainabilityResult.Stage)stage, ConversionUtils.fromFeatureList(features), ConversionUtils.fromOutputs(outputs));
    }

    public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider predictionProvider, Consumer<CounterfactualResult> intermediateResultsConsumer) {
        return this.explainer.explainAsync(prediction, predictionProvider, intermediateResultsConsumer);
    }
}

