/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.counterfactual;

import java.math.BigDecimal;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualSolution;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.Type;
import org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore;
import org.optaplanner.core.api.score.calculator.EasyScoreCalculator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Score calculator for counterfactual search.
 *
 * <p>The score returned by {@link #calculateScore(CounterfactualSolution)} has three hard
 * levels and two soft levels:
 * <ul>
 * <li>hard 0: negated Euclidean distance between the model's predicted outputs and the goal;</li>
 * <li>hard 1: negated number of entities that are both changed and constrained;</li>
 * <li>hard 2: negated number of outputs whose score is below the corresponding goal score;</li>
 * <li>soft 0: negated square root of (1 - mean entity similarity);</li>
 * <li>soft 1: negated number of changed entities.</li>
 * </ul>
 */
public class CounterFactualScoreCalculator
implements EasyScoreCalculator<CounterfactualSolution, BendableBigDecimalScore> {
    private static final Logger logger = LoggerFactory.getLogger(CounterFactualScoreCalculator.class);

    /**
     * Computes the distance between a prediction and a goal using a threshold of zero.
     *
     * @param prediction the predicted output
     * @param goal the desired output
     * @return the distance between the two outputs
     * @throws IllegalArgumentException if the outputs have different types or an unsupported type
     */
    public static Double outputDistance(Output prediction, Output goal) throws IllegalArgumentException {
        return CounterFactualScoreCalculator.outputDistance(prediction, goal, 0.0);
    }

    /**
     * Computes the distance between a prediction and a goal.
     *
     * <p>For {@code NUMBER} outputs the distance is the absolute difference, divided by the larger
     * of the two absolute values unless either value is exactly zero (in which case the plain
     * absolute difference is used). Distances strictly below {@code threshold} are reported as
     * {@code 0.0}. For {@code CATEGORICAL}, {@code BOOLEAN} and {@code TEXT} outputs the distance
     * is {@code 0.0} when the underlying objects are equal and {@code 1.0} otherwise; the
     * threshold is not applied to these types.
     *
     * @param prediction the predicted output
     * @param goal the desired output
     * @param threshold numeric distances strictly below this value are treated as zero
     * @return the distance between the two outputs
     * @throws IllegalArgumentException if the outputs have different types or an unsupported type
     */
    public static Double outputDistance(Output prediction, Output goal, double threshold) throws IllegalArgumentException {
        final Type predictionType = prediction.getType();
        final Type goalType = goal.getType();
        if (predictionType != goalType) {
            String message = String.format("Features must have the same type. Feature '%s', has type '%s' and '%s'", prediction.getName(), predictionType, goalType);
            logger.error(message);
            throw new IllegalArgumentException(message);
        }

        if (predictionType == Type.NUMBER) {
            final double predictionValue = prediction.getValue().asNumber();
            final double goalValue = goal.getValue().asNumber();
            final double difference = Math.abs(predictionValue - goalValue);
            // Normalise by the largest magnitude rather than the largest signed value, so that
            // negative inputs cannot produce a negative (and therefore "below threshold") distance.
            final double scale = Math.max(Math.abs(predictionValue), Math.abs(goalValue));
            final double distance = (predictionValue == 0.0 || goalValue == 0.0) ? difference : difference / scale;
            return distance < threshold ? 0.0 : distance;
        }

        if (predictionType == Type.CATEGORICAL || predictionType == Type.BOOLEAN || predictionType == Type.TEXT) {
            final Object predicted = prediction.getValue().getUnderlyingObject();
            final Object expected = goal.getValue().getUnderlyingObject();
            // Null-safe equality: two absent values are considered identical.
            final boolean same = predicted == null ? expected == null : predicted.equals(expected);
            return same ? 0.0 : 1.0;
        }

        String message = String.format("Feature '%s' has unsupported type '%s'", prediction.getName(), predictionType);
        logger.error(message);
        throw new IllegalArgumentException(message);
    }

    /**
     * Calculates the bendable score of a candidate counterfactual solution.
     *
     * <p>The solution's entities are converted to features, sent to the solution's model, and the
     * resulting predictions are stored back on the solution via
     * {@code setPredictionOutputs} before being compared with the goal.
     *
     * <p>If the prediction fails, times out or is interrupted, the error is logged and the score
     * is built from whatever has been accumulated so far. On interruption the thread's interrupt
     * flag is restored.
     *
     * @param solution the candidate solution to score
     * @return a score with three hard levels and two soft levels
     * @throws IllegalArgumentException if a prediction's number of outputs differs from the goal's,
     *         or if an output and its goal cannot be compared
     */
    public BendableBigDecimalScore calculateScore(CounterfactualSolution solution) {
        double primaryHardScore = 0.0;
        int secondaryHardScore = 0;
        int tertiaryHardScore = 0;
        int secondarySoftScore = 0;

        // Only pay for per-entity string formatting when the result will actually be logged.
        final boolean debugEnabled = logger.isDebugEnabled();
        final StringBuilder builder = new StringBuilder();

        double inputSimilarities = 0.0;
        final int numberOfEntities = solution.getEntities().size();
        for (CounterfactualEntity entity : solution.getEntities()) {
            final double entitySimilarity = entity.similarity();
            inputSimilarities += entitySimilarity / (double) numberOfEntities;
            if (debugEnabled) {
                final Feature f = entity.asFeature();
                builder.append(String.format("%s=%s (d:%f)", f.getName(), f.getValue().getUnderlyingObject(), entitySimilarity));
            }
            if (entity.isChanged()) {
                --secondarySoftScore;
                if (entity.isConstrained()) {
                    --secondaryHardScore;
                }
            }
        }

        // Clamp at zero: accumulated rounding can push the mean similarity marginally above 1,
        // which would make the square root NaN and BigDecimal.valueOf(NaN) throw.
        final double primarySoftScore = -Math.sqrt(Math.max(0.0, 1.0 - inputSimilarities));
        logger.debug("Current solution: {}", builder);

        final List<Feature> input = solution.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
        final PredictionInput predictionInput = new PredictionInput(input);
        final List<PredictionInput> inputs = List.of(predictionInput);
        final CompletableFuture<List<PredictionOutput>> predictionAsync = solution.getModel().predictAsync(inputs);
        final List<Output> goal = solution.getGoal();

        // NOTE(review): when the prediction fails below, primaryHardScore stays at 0.0, which is
        // the best possible value for that level - confirm this is the intended fallback.
        try {
            final List<PredictionOutput> predictions = predictionAsync.get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            solution.setPredictionOutputs(predictions);
            double outputDistance = 0.0;
            for (PredictionOutput predictionOutput : predictions) {
                final List<Output> outputs = predictionOutput.getOutputs();
                if (goal.size() != outputs.size()) {
                    throw new IllegalArgumentException("Prediction size must be equal to goal size");
                }
                final int numberOutputs = outputs.size();
                for (int i = 0; i < numberOutputs; ++i) {
                    final Output output = outputs.get(i);
                    final Output goalOutput = goal.get(i);
                    final double d = CounterFactualScoreCalculator.outputDistance(output, goalOutput, solution.getGoalThreshold());
                    outputDistance += d * d;
                    // Penalise predictions that are less confident than the goal requires.
                    if (output.getScore() < goalOutput.getScore()) {
                        --tertiaryHardScore;
                    }
                }
                primaryHardScore -= Math.sqrt(outputDistance);
                logger.debug("Distance penalty: {}", primaryHardScore);
                logger.debug("Changed constraints penalty: {}", secondaryHardScore);
                logger.debug("Confidence threshold penalty: {}", tertiaryHardScore);
            }
        }
        catch (ExecutionException e) {
            logger.error("Prediction returned an error {}", e.getMessage(), e);
        }
        catch (InterruptedException e) {
            logger.error("Interrupted while waiting for prediction {}", e.getMessage(), e);
            Thread.currentThread().interrupt();
        }
        catch (TimeoutException e) {
            logger.error("Timed out while waiting for prediction", e);
        }

        logger.debug("Feature distance: {}", primarySoftScore);

        final BigDecimal[] hardScores = new BigDecimal[]{
                BigDecimal.valueOf(primaryHardScore),
                BigDecimal.valueOf(secondaryHardScore),
                BigDecimal.valueOf(tertiaryHardScore)};
        final BigDecimal[] softScores = new BigDecimal[]{
                BigDecimal.valueOf(primarySoftScore),
                BigDecimal.valueOf(secondarySoftScore)};
        return BendableBigDecimalScore.of(hardScores, softScores);
    }
}

