/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.counterfactual;

import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualResult;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualSolution;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntityFactory;
import org.kie.kogito.explainability.model.CounterfactualPrediction;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionFeatureDomain;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Local explainer that produces {@link CounterfactualResult}s for a {@link CounterfactualPrediction}:
 * it builds {@link CounterfactualEntity} instances from the prediction's input, domain and constraints,
 * searches for a {@link CounterfactualSolution} with respect to the prediction's goal outputs, and then
 * asks the model to predict the outputs for the solution's features.
 *
 * <p>NOTE(review): this file was produced by the CFR decompiler and the body of the asynchronous search
 * inside {@link #explainAsync} could not be decompiled. Until it is restored from the original source,
 * every future returned by {@code explainAsync} completes exceptionally with an
 * {@link IllegalStateException}.
 */
public class CounterfactualExplainer implements LocalExplainer<CounterfactualResult> {

    /** Stamps every solution it receives with a fresh random solution id. */
    public static final Consumer<CounterfactualSolution> assignSolutionId =
            counterfactual -> counterfactual.setSolutionId(UUID.randomUUID());

    private static final Logger logger = LoggerFactory.getLogger(CounterfactualExplainer.class);

    private final CounterfactualConfig counterfactualConfig;

    /** Creates an explainer backed by a default {@link CounterfactualConfig}. */
    public CounterfactualExplainer() {
        this(new CounterfactualConfig());
    }

    /**
     * Creates an explainer backed by the given configuration.
     *
     * @param counterfactualConfig configuration supplying the executor and goal threshold used by
     *                             {@link #explainAsync}
     */
    public CounterfactualExplainer(CounterfactualConfig counterfactualConfig) {
        this.counterfactualConfig = counterfactualConfig;
    }

    /** @return the configuration this explainer was created with */
    public CounterfactualConfig getCounterfactualConfig() {
        return this.counterfactualConfig;
    }

    /**
     * Adapts a consumer of {@link CounterfactualResult}s into a consumer of intermediate
     * {@link CounterfactualSolution}s. Only feasible solutions are forwarded; each one is converted
     * into a result tagged with the next value of {@code sequenceId}.
     *
     * @param consumer   receiver of the converted intermediate results
     * @param sequenceId shared counter used to number results in emission order
     * @return a consumer that filters, converts and forwards solutions
     */
    private Consumer<CounterfactualSolution> createSolutionConsumer(Consumer<CounterfactualResult> consumer,
            AtomicLong sequenceId) {
        return counterfactualSolution -> {
            if (counterfactualSolution.getScore().isFeasible()) {
                CounterfactualResult result = new CounterfactualResult(
                        counterfactualSolution.getEntities(),
                        counterfactualSolution.getPredictionOutputs(),
                        counterfactualSolution.getScore().isFeasible(),
                        counterfactualSolution.getSolutionId(),
                        counterfactualSolution.getExecutionId(),
                        sequenceId.incrementAndGet());
                consumer.accept(result);
            }
        };
    }

    /**
     * Asynchronously searches for a counterfactual of {@code prediction} and returns it together
     * with the model's outputs for that counterfactual.
     *
     * @param prediction                  must be a {@link CounterfactualPrediction}; otherwise a
     *                                    {@link ClassCastException} is thrown
     * @param model                       model used to score the final solution
     * @param intermediateResultsConsumer receiver for intermediate results
     * @return a future completing with the final {@link CounterfactualResult}; the final result is
     *         numbered after any intermediate results sharing the same sequence counter
     */
    @Override
    public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider model,
            Consumer<CounterfactualResult> intermediateResultsConsumer) {
        final AtomicLong sequenceId = new AtomicLong(0L);
        final CounterfactualPrediction cfPrediction = (CounterfactualPrediction) prediction;
        final PredictionFeatureDomain featureDomain = cfPrediction.getDomain();
        final List<Boolean> constraints = cfPrediction.getConstraints();
        final UUID executionId = cfPrediction.getExecutionId();
        // NOTE(review): maxRunningTimeSeconds, initial, intermediateResultsConsumer, logger and
        // createSolutionConsumer are not used by any code that decompiled; presumably the missing
        // search body below consumed them. Kept so that body can be restored in place.
        final Long maxRunningTimeSeconds = cfPrediction.getMaxRunningTimeSeconds();
        final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(
                prediction.getInput(), featureDomain, constraints, cfPrediction.getDataDistribution());
        final List<Output> goal = prediction.getOutput().getOutputs();
        // The UUID argument is ignored: every initial solution gets a fresh random solution id.
        final Function<UUID, CounterfactualSolution> initial = uuid -> new CounterfactualSolution(
                entities, model, goal, UUID.randomUUID(), executionId, this.counterfactualConfig.getGoalThreshold());

        final CompletableFuture<CounterfactualSolution> cfSolution = CompletableFuture.supplyAsync(() -> {
            // TODO: CFR 0.152 failed to decompile this lambda (ConfusedCFRException: "Tried to end
            // blocks [4[CATCHBLOCK]], but top level block is 2[TRYBLOCK]"), so the original search
            // logic is missing. Restore it from the original source; until then every request fails here.
            throw new IllegalStateException("Decompilation failed");
        }, this.counterfactualConfig.getExecutor());

        // Once a solution exists, score it with the model and combine both into the final result.
        // Failures of either stage propagate to the returned future.
        return cfSolution.thenCompose(solution -> model.predictAsync(List.of(toPredictionInput(solution)))
                .thenApply(outputs -> new CounterfactualResult(
                        solution.getEntities(),
                        outputs,
                        solution.getScore().isFeasible(),
                        UUID.randomUUID(),
                        solution.getExecutionId(),
                        sequenceId.incrementAndGet())));
    }

    /**
     * Builds a single model input from a solution, with one feature per solution entity.
     *
     * @param solution solution whose entities become the input features
     * @return the model input for {@code solution}
     */
    private static PredictionInput toPredictionInput(CounterfactualSolution solution) {
        return new PredictionInput(solution.getEntities().stream()
                .map(CounterfactualEntity::asFeature)
                .collect(Collectors.toList()));
    }
}

