/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.counterfactual;

import java.math.BigDecimal;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualSolution;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore;
import org.optaplanner.core.api.score.calculator.EasyScoreCalculator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CounterFactualScoreCalculator
implements EasyScoreCalculator<CounterfactualSolution, BendableBigDecimalScore> {
    private static final Logger logger = LoggerFactory.getLogger(CounterFactualScoreCalculator.class);

    public BendableBigDecimalScore calculateScore(CounterfactualSolution solution) {
        double primaryHardScore = 0.0;
        int secondaryHardScore = 0;
        int tertiaryHardScore = 0;
        double primarySoftScore = 0.0;
        int secondarySoftscore = 0;
        StringBuilder builder = new StringBuilder();
        for (CounterfactualEntity entity : solution.getEntities()) {
            double entityDistance = entity.distance();
            primarySoftScore += entityDistance;
            Feature f = entity.asFeature();
            builder.append(String.format("%s=%s (d:%f)", f.getName(), f.getValue().getUnderlyingObject(), entityDistance));
            if (!entity.isChanged()) continue;
            --secondarySoftscore;
            if (!entity.isConstrained()) continue;
            --secondaryHardScore;
        }
        logger.debug("Current solution: {}", (Object)builder);
        List<Feature> input = solution.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
        PredictionInput predictionInput = new PredictionInput(input);
        List<PredictionInput> inputs = List.of(predictionInput);
        CompletableFuture<List<PredictionOutput>> predictionAsync = solution.getModel().predictAsync(inputs);
        List<Output> goal = solution.getGoal();
        try {
            List<PredictionOutput> predictions = predictionAsync.get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            solution.setPredictionOutputs(predictions);
            double distance = 0.0;
            for (PredictionOutput predictionOutput : predictions) {
                List<Output> outputs = predictionOutput.getOutputs();
                if (outputs.size() != predictions.size()) {
                    throw new IllegalArgumentException("Prediction size must be equal to goal size");
                }
                for (int i = 0; i < outputs.size(); ++i) {
                    Output output = outputs.get(i);
                    Output goalOutput = goal.get(i);
                    double d = goalOutput.getValue().asNumber() - output.getValue().asNumber();
                    distance += d * d;
                    if (!(output.getScore() < goalOutput.getScore())) continue;
                    --tertiaryHardScore;
                }
                logger.debug("Distance penalty: {}", (Object)(primaryHardScore -= Math.sqrt(distance)));
                logger.debug("Changed constraints penalty: {}", (Object)secondaryHardScore);
                logger.debug("Confidence threshold penalty: {}", (Object)tertiaryHardScore);
            }
        }
        catch (ExecutionException e) {
            logger.error("Prediction returned an error {}", (Object)e.getMessage());
        }
        catch (InterruptedException e) {
            logger.error("Interrupted while waiting for prediction {}", (Object)e.getMessage());
            Thread.currentThread().interrupt();
        }
        catch (TimeoutException e) {
            logger.error("Timed out while waiting for prediction");
        }
        logger.debug("Feature distance: {}", (Object)(-Math.abs(primarySoftScore)));
        return BendableBigDecimalScore.of((BigDecimal[])new BigDecimal[]{BigDecimal.valueOf(primaryHardScore), BigDecimal.valueOf(secondaryHardScore), BigDecimal.valueOf(tertiaryHardScore)}, (BigDecimal[])new BigDecimal[]{BigDecimal.valueOf(-Math.abs(primarySoftScore)), BigDecimal.valueOf(secondarySoftscore)});
    }
}

