/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.counterfactual;

import java.math.BigDecimal;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualSolution;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.Type;
import org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore;
import org.optaplanner.core.api.score.calculator.EasyScoreCalculator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CounterFactualScoreCalculator
implements EasyScoreCalculator<CounterfactualSolution, BendableBigDecimalScore> {
    private static final Logger logger = LoggerFactory.getLogger(CounterFactualScoreCalculator.class);

    public static Double outputDistance(Output a, Output b) throws IllegalArgumentException {
        Type bType;
        double distance = 0.0;
        Type aType = a.getType();
        if (aType != (bType = b.getType())) {
            String message = "Features must have the same type, got " + aType.toString() + " and " + bType.toString();
            logger.error(message);
            throw new IllegalArgumentException(message);
        }
        if (a.getType() == Type.NUMBER) {
            distance = a.getValue().asNumber() - b.getValue().asNumber();
        } else if (a.getType() == Type.CATEGORICAL || a.getType() == Type.BOOLEAN) {
            distance = a.getValue().getUnderlyingObject().equals(b.getValue().getUnderlyingObject()) ? 0.0 : 1.0;
        } else {
            String message = "Feature type " + aType.toString() + " not supported";
            logger.error(message);
            throw new IllegalArgumentException(message);
        }
        return distance;
    }

    public BendableBigDecimalScore calculateScore(CounterfactualSolution solution) {
        double primaryHardScore = 0.0;
        int secondaryHardScore = 0;
        int tertiaryHardScore = 0;
        int secondarySoftscore = 0;
        StringBuilder builder = new StringBuilder();
        double inputSimilarities = 0.0;
        int numberOfEntities = solution.getEntities().size();
        for (CounterfactualEntity entity : solution.getEntities()) {
            double entitySimilarity = entity.similarity();
            inputSimilarities += entitySimilarity / (double)numberOfEntities;
            Feature f = entity.asFeature();
            builder.append(String.format("%s=%s (d:%f)", f.getName(), f.getValue().getUnderlyingObject(), entitySimilarity));
            if (!entity.isChanged()) continue;
            --secondarySoftscore;
            if (!entity.isConstrained()) continue;
            --secondaryHardScore;
        }
        double primarySoftScore = -Math.sqrt(1.0 - inputSimilarities);
        logger.debug("Current solution: {}", (Object)builder);
        List<Feature> input = solution.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
        PredictionInput predictionInput = new PredictionInput(input);
        List<PredictionInput> inputs = List.of(predictionInput);
        CompletableFuture<List<PredictionOutput>> predictionAsync = solution.getModel().predictAsync(inputs);
        List<Output> goal = solution.getGoal();
        try {
            List<PredictionOutput> predictions = predictionAsync.get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            solution.setPredictionOutputs(predictions);
            double outputDistance = 0.0;
            for (PredictionOutput predictionOutput : predictions) {
                List<Output> outputs = predictionOutput.getOutputs();
                if (goal.size() != outputs.size()) {
                    throw new IllegalArgumentException("Prediction size must be equal to goal size");
                }
                int numberOutputs = outputs.size();
                for (int i = 0; i < numberOutputs; ++i) {
                    Output output = outputs.get(i);
                    Output goalOutput = goal.get(i);
                    double d = CounterFactualScoreCalculator.outputDistance(output, goalOutput);
                    outputDistance += d * d;
                    if (!(output.getScore() < goalOutput.getScore())) continue;
                    --tertiaryHardScore;
                }
                logger.debug("Distance penalty: {}", (Object)(primaryHardScore -= Math.sqrt(outputDistance)));
                logger.debug("Changed constraints penalty: {}", (Object)secondaryHardScore);
                logger.debug("Confidence threshold penalty: {}", (Object)tertiaryHardScore);
            }
        }
        catch (ExecutionException e) {
            logger.error("Prediction returned an error {}", (Object)e.getMessage());
        }
        catch (InterruptedException e) {
            logger.error("Interrupted while waiting for prediction {}", (Object)e.getMessage());
            Thread.currentThread().interrupt();
        }
        catch (TimeoutException e) {
            logger.error("Timed out while waiting for prediction");
        }
        logger.debug("Feature distance: {}", (Object)(-Math.abs(primarySoftScore)));
        return BendableBigDecimalScore.of((BigDecimal[])new BigDecimal[]{BigDecimal.valueOf(primaryHardScore), BigDecimal.valueOf(secondaryHardScore), BigDecimal.valueOf(tertiaryHardScore)}, (BigDecimal[])new BigDecimal[]{BigDecimal.valueOf(-Math.abs(primarySoftScore)), BigDecimal.valueOf(secondarySoftscore)});
    }
}

