/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.counterfactual;

import java.math.BigDecimal;
import java.time.Duration;
import java.time.LocalTime;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualSolution;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.utils.CompositeFeatureUtils;
import org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore;
import org.optaplanner.core.api.score.calculator.EasyScoreCalculator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CounterFactualScoreCalculator
implements EasyScoreCalculator<CounterfactualSolution, BendableBigDecimalScore> {
    private static final double DEFAULT_DISTANCE = 1.0;
    private static final Logger logger = LoggerFactory.getLogger(CounterFactualScoreCalculator.class);
    private static final Set<Type> SUPPORTED_CATEGORICAL_TYPES = Set.of(Type.CATEGORICAL, Type.BOOLEAN, Type.TEXT, Type.CURRENCY, Type.BINARY, Type.UNDEFINED);

    public static Double outputDistance(Output prediction, Output goal) throws IllegalArgumentException {
        return CounterFactualScoreCalculator.outputDistance(prediction, goal, 0.0);
    }

    public static Double outputDistance(Output prediction, Output goal, double threshold) throws IllegalArgumentException {
        Type goalType;
        Type predictionType = prediction.getType();
        if (predictionType != (goalType = goal.getType())) {
            if (Objects.nonNull(prediction.getValue().getUnderlyingObject())) {
                String message = String.format("Features must have the same type. Feature '%s', has type '%s' and '%s'", prediction.getName(), predictionType.toString(), goalType.toString());
                logger.error(message);
                throw new IllegalArgumentException(message);
            }
            return 1.0;
        }
        if (predictionType == Type.NUMBER) {
            double predictionValue = prediction.getValue().asNumber();
            double goalValue = goal.getValue().asNumber();
            double difference = Math.abs(predictionValue - goalValue);
            if (Double.isNaN(predictionValue) || Double.isNaN(goalValue)) {
                String message = String.format("Unsupported NaN or NULL for numeric feature '%s'", prediction.getName());
                logger.error(message);
                throw new IllegalArgumentException(message);
            }
            double distance = predictionValue == 0.0 || goalValue == 0.0 ? difference : difference / Math.max(predictionValue, goalValue);
            if (distance < threshold) {
                return 0.0;
            }
            return distance;
        }
        if (predictionType == Type.DURATION) {
            Duration predictionValue = (Duration)prediction.getValue().getUnderlyingObject();
            Duration goalValue = (Duration)goal.getValue().getUnderlyingObject();
            if (Objects.isNull(predictionValue) || Objects.isNull(goalValue)) {
                return 1.0;
            }
            double difference = predictionValue.minus(goalValue).abs().getSeconds();
            double distance = predictionValue.isZero() || goalValue.isZero() ? difference : difference / (double)Math.max(predictionValue.getSeconds(), goalValue.getSeconds());
            if (distance < threshold) {
                return 0.0;
            }
            return distance;
        }
        if (predictionType == Type.TIME) {
            LocalTime predictionValue = (LocalTime)prediction.getValue().getUnderlyingObject();
            LocalTime goalValue = (LocalTime)goal.getValue().getUnderlyingObject();
            if (Objects.isNull(predictionValue) || Objects.isNull(goalValue)) {
                return 1.0;
            }
            double interval = LocalTime.MIN.until(LocalTime.MAX, ChronoUnit.SECONDS);
            double distance = (double)Math.abs(predictionValue.until(goalValue, ChronoUnit.SECONDS)) / interval;
            if (distance < threshold) {
                return 0.0;
            }
            return distance;
        }
        if (SUPPORTED_CATEGORICAL_TYPES.contains((Object)predictionType)) {
            Object predictionValueObject;
            Object goalValueObject = goal.getValue().getUnderlyingObject();
            return Objects.equals(goalValueObject, predictionValueObject = prediction.getValue().getUnderlyingObject()) ? 0.0 : 1.0;
        }
        String message = String.format("Feature '%s' has unsupported type '%s'", prediction.getName(), predictionType.toString());
        logger.error(message);
        throw new IllegalArgumentException(message);
    }

    private BendableBigDecimalScore calculateInputScore(CounterfactualSolution solution) {
        StringBuilder builder = new StringBuilder();
        int secondarySoftScore = 0;
        int secondaryHardScore = 0;
        double inputSimilarities = 0.0;
        int numberOfEntities = solution.getEntities().size();
        for (CounterfactualEntity entity : solution.getEntities()) {
            double entitySimilarity = entity.similarity();
            inputSimilarities += entitySimilarity / (double)numberOfEntities;
            Feature f = entity.asFeature();
            builder.append(String.format("%s=%s (d:%f)", f.getName(), f.getValue().getUnderlyingObject(), entitySimilarity));
            if (!entity.isChanged()) continue;
            --secondarySoftScore;
            if (!entity.isConstrained()) continue;
            --secondaryHardScore;
        }
        logger.debug("Current solution: {}", (Object)builder);
        double primarySoftScore = -Math.sqrt(Math.abs(1.0 - inputSimilarities));
        logger.debug("Changed constraints penalty: {}", (Object)secondaryHardScore);
        logger.debug("Feature distance: {}", (Object)(-Math.abs(primarySoftScore)));
        return BendableBigDecimalScore.of((BigDecimal[])new BigDecimal[]{BigDecimal.ZERO, BigDecimal.valueOf(secondaryHardScore), BigDecimal.ZERO}, (BigDecimal[])new BigDecimal[]{BigDecimal.valueOf(-Math.abs(primarySoftScore)), BigDecimal.valueOf(secondarySoftScore)});
    }

    private BendableBigDecimalScore calculateOutputScore(CounterfactualSolution solution) {
        List<PredictionOutput> predictions = solution.getPredictionOutputs();
        List<Output> goal = solution.getGoal();
        double outputDistance = 0.0;
        int tertiaryHardScore = 0;
        double primaryHardScore = 0.0;
        for (PredictionOutput predictionOutput : predictions) {
            List<Output> outputs = predictionOutput.getOutputs();
            if (goal.size() != outputs.size()) {
                throw new IllegalArgumentException("Prediction size must be equal to goal size");
            }
            int numberOutputs = outputs.size();
            for (int i = 0; i < numberOutputs; ++i) {
                Output output = outputs.get(i);
                Output goalOutput = goal.get(i);
                double d = CounterFactualScoreCalculator.outputDistance(output, goalOutput, solution.getGoalThreshold());
                outputDistance += d * d;
                if (!(output.getScore() < goalOutput.getScore())) continue;
                --tertiaryHardScore;
            }
            logger.debug("Distance penalty: {}", (Object)(primaryHardScore -= Math.sqrt(outputDistance)));
            logger.debug("Confidence threshold penalty: {}", (Object)tertiaryHardScore);
        }
        return BendableBigDecimalScore.of((BigDecimal[])new BigDecimal[]{BigDecimal.valueOf(primaryHardScore), BigDecimal.ZERO, BigDecimal.valueOf(tertiaryHardScore)}, (BigDecimal[])new BigDecimal[]{BigDecimal.ZERO, BigDecimal.ZERO});
    }

    public BendableBigDecimalScore calculateScore(CounterfactualSolution solution) {
        BendableBigDecimalScore currentScore = this.calculateInputScore(solution);
        List<Feature> flattenedFeatures = solution.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
        List<Feature> input = CompositeFeatureUtils.unflattenFeatures(flattenedFeatures, solution.getOriginalFeatures());
        List<PredictionInput> inputs = List.of(new PredictionInput(input));
        CompletableFuture<List<PredictionOutput>> predictionAsync = solution.getModel().predictAsync(inputs);
        try {
            List<PredictionOutput> predictions = predictionAsync.get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            solution.setPredictionOutputs(predictions);
            BendableBigDecimalScore outputScore = this.calculateOutputScore(solution);
            currentScore = currentScore.add(outputScore);
        }
        catch (ExecutionException e) {
            logger.error("Prediction returned an error {}", (Object)e.getMessage());
        }
        catch (InterruptedException e) {
            logger.error("Interrupted while waiting for prediction {}", (Object)e.getMessage());
            Thread.currentThread().interrupt();
        }
        catch (TimeoutException e) {
            logger.error("Timed out while waiting for prediction");
        }
        return currentScore;
    }
}

