/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.lime.optim;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.local.lime.optim.LimeConfigEntityFactory;
import org.kie.kogito.explainability.local.lime.optim.LimeStabilitySolution;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;
import org.kie.kogito.explainability.utils.LocalSaliencyStability;
import org.optaplanner.core.api.score.buildin.simplebigdecimal.SimpleBigDecimalScore;
import org.optaplanner.core.api.score.calculator.EasyScoreCalculator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LimeStabilityScoreCalculator
implements EasyScoreCalculator<LimeStabilitySolution, SimpleBigDecimalScore> {
    private static final Logger LOGGER = LoggerFactory.getLogger(LimeStabilityScoreCalculator.class);
    private static final BigDecimal TWO = BigDecimal.valueOf(2.0);
    private static final BigDecimal ZERO = BigDecimal.valueOf(0L);

    public SimpleBigDecimalScore calculateScore(LimeStabilitySolution solution) {
        LimeConfig config = LimeConfigEntityFactory.toLimeConfig(solution);
        BigDecimal stabilityScore = BigDecimal.ZERO;
        List<Prediction> predictions = solution.getPredictions();
        if (!predictions.isEmpty()) {
            stabilityScore = this.getStabilityScore(solution, config, predictions);
        }
        return SimpleBigDecimalScore.of((BigDecimal)stabilityScore);
    }

    private BigDecimal getStabilityScore(LimeStabilitySolution solution, LimeConfig config, List<Prediction> predictions) {
        double succeededEvaluations = 0.0;
        BigDecimal stabilityScore = BigDecimal.ZERO;
        LimeExplainer limeExplainer = new LimeExplainer(config);
        for (Prediction prediction : predictions) {
            try {
                LocalSaliencyStability stability = ExplainabilityMetrics.getLocalSaliencyStability(solution.getModel(), prediction, limeExplainer, TWO.intValue(), 5);
                for (String decision : stability.getDecisions()) {
                    BigDecimal decisionMarginalScore = this.getDecisionMarginalScore(TWO, stability, decision);
                    stabilityScore = stabilityScore.add(decisionMarginalScore);
                    succeededEvaluations += 1.0;
                }
            }
            catch (ExecutionException e) {
                LOGGER.error("Saliency stability calculation returned an error {}", (Object)e.getMessage());
            }
            catch (InterruptedException e) {
                LOGGER.error("Interrupted while waiting for saliency stability calculation {}", (Object)e.getMessage());
                Thread.currentThread().interrupt();
            }
            catch (TimeoutException e) {
                LOGGER.error("Timed out while waiting for saliency stability calculation", (Throwable)e);
            }
        }
        if (succeededEvaluations > 0.0) {
            stabilityScore = stabilityScore.divide(BigDecimal.valueOf(succeededEvaluations), RoundingMode.CEILING);
        }
        return stabilityScore;
    }

    private BigDecimal getDecisionMarginalScore(BigDecimal topK, LocalSaliencyStability stability, String decision) {
        BigDecimal positiveStabilityScore = ZERO;
        BigDecimal negativeStabilityScore = ZERO;
        for (int i = 1; i <= topK.intValue(); ++i) {
            positiveStabilityScore = positiveStabilityScore.add(BigDecimal.valueOf(stability.getPositiveStabilityScore(decision, i)));
            negativeStabilityScore = negativeStabilityScore.add(BigDecimal.valueOf(stability.getNegativeStabilityScore(decision, i)));
        }
        positiveStabilityScore = positiveStabilityScore.divide(topK, RoundingMode.CEILING);
        negativeStabilityScore = negativeStabilityScore.divide(topK, RoundingMode.CEILING);
        return positiveStabilityScore.add(negativeStabilityScore).divide(TWO.multiply(BigDecimal.valueOf(stability.getDecisions().size())), RoundingMode.CEILING);
    }
}

