/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.kie.kogito.explainability.model.Dataset;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Value;

/**
 * Static fairness metrics for models exposed through a {@link PredictionProvider}.
 *
 * <p>Group metrics compare a privileged group, chosen by a caller-supplied predicate, with the
 * complementary unprivileged group (the negated predicate). The individual metric measures how
 * often a sample's predicted outputs agree with the predicted outputs of its neighbours.
 */
public class FairnessMetrics {

    /**
     * Smoothing term added to every rate denominator in the confusion-matrix based metrics, so
     * that a group with no examples of a class contributes a rate of 0 instead of NaN.
     */
    private static final double EPSILON = 1.0E-10;

    private FairnessMetrics() {
    }

    /**
     * Computes the individual consistency of the model's predictions over {@code samples}.
     *
     * <p>Starting from 1, the score is reduced by
     * {@code 1 / (|neighbours| * |outputs| * |samples|)} for each combination of a sample, one of
     * its predicted outputs, and one neighbour. The reduction applies only when that neighbour's
     * prediction has an output with the same name but a different value. Neighbour predictions
     * lacking an output with that name are ignored. The result is 1 when no disagreement is found.
     *
     * @param proximityFunction given a sample and the full sample list, returns the inputs treated
     *     as that sample's neighbours (presumably its nearest neighbours; defined by the caller)
     * @param samples inputs to evaluate; the whole list is also passed to {@code proximityFunction}
     * @param predictionProvider model used to predict both the samples and their neighbours
     * @return the consistency score
     * @throws ExecutionException if an asynchronous prediction completes exceptionally
     * @throws InterruptedException if interrupted while waiting for a prediction
     */
    public static double individualConsistency(BiFunction<PredictionInput, List<PredictionInput>, List<PredictionInput>> proximityFunction, List<PredictionInput> samples, PredictionProvider predictionProvider) throws ExecutionException, InterruptedException {
        double consistency = 1.0;
        for (PredictionInput input : samples) {
            List<PredictionOutput> predictionOutputs = predictionProvider.predictAsync(List.of(input)).get();
            PredictionOutput predictionOutput = predictionOutputs.get(0);
            List<PredictionInput> neighbors = proximityFunction.apply(input, samples);
            List<PredictionOutput> neighborsOutputs = predictionProvider.predictAsync(neighbors).get();
            // Loop-invariant share of the score removed per disagreement. Computed in double
            // arithmetic so the product of the three sizes cannot overflow an int and no float
            // rounding is introduced.
            double penalty = 1.0 / ((double) neighbors.size() * predictionOutput.getOutputs().size() * samples.size());
            for (Output output : predictionOutput.getOutputs()) {
                Value originalValue = output.getValue();
                for (PredictionOutput neighborOutput : neighborsOutputs) {
                    Output currentOutput = neighborOutput.getByName(output.getName()).orElse(null);
                    if (currentOutput != null && !originalValue.equals(currentOutput.getValue())) {
                        consistency -= penalty;
                    }
                }
            }
        }
        return consistency;
    }

    /**
     * Computes the statistical parity difference:
     * {@code P(favorable | unprivileged) - P(favorable | privileged)}.
     *
     * @param groupSelector predicate selecting the privileged group; its negation selects the
     *     unprivileged group
     * @param samples inputs to split into the two groups and predict
     * @param model model producing the predictions
     * @param favorableOutput output whose name and value define a favorable prediction
     * @return the difference of the two favorable-label probabilities; NaN if either group is empty
     * @throws ExecutionException if an asynchronous prediction completes exceptionally
     * @throws InterruptedException if interrupted while waiting for a prediction
     * @throws NoSuchElementException if a prediction lacks an output named like
     *     {@code favorableOutput}
     */
    public static double groupStatisticalParityDifference(Predicate<PredictionInput> groupSelector, List<PredictionInput> samples, PredictionProvider model, Output favorableOutput) throws ExecutionException, InterruptedException {
        double probabilityUnprivileged = getFavorableLabelProbability(groupSelector.negate(), samples, model, favorableOutput);
        double probabilityPrivileged = getFavorableLabelProbability(groupSelector, samples, model, favorableOutput);
        return probabilityUnprivileged - probabilityPrivileged;
    }

    /**
     * Computes the disparate impact ratio:
     * {@code P(favorable | unprivileged) / P(favorable | privileged)}.
     *
     * <p>Plain floating-point division is used. If the privileged probability is 0, the result is
     * positive infinity, or NaN when the unprivileged probability is also 0.
     *
     * @param groupSelector predicate selecting the privileged group; its negation selects the
     *     unprivileged group
     * @param samples inputs to split into the two groups and predict
     * @param model model producing the predictions
     * @param favorableOutput output whose name and value define a favorable prediction
     * @return the ratio of the two favorable-label probabilities; NaN if either group is empty
     * @throws ExecutionException if an asynchronous prediction completes exceptionally
     * @throws InterruptedException if interrupted while waiting for a prediction
     * @throws NoSuchElementException if a prediction lacks an output named like
     *     {@code favorableOutput}
     */
    public static double groupDisparateImpactRatio(Predicate<PredictionInput> groupSelector, List<PredictionInput> samples, PredictionProvider model, Output favorableOutput) throws ExecutionException, InterruptedException {
        double probabilityUnprivileged = getFavorableLabelProbability(groupSelector.negate(), samples, model, favorableOutput);
        double probabilityPrivileged = getFavorableLabelProbability(groupSelector, samples, model, favorableOutput);
        return probabilityUnprivileged / probabilityPrivileged;
    }

    /**
     * Returns the fraction of selected samples whose prediction has an output equal in value to
     * {@code favorableOutput}, matched by name. An empty selection yields NaN (0 / 0).
     */
    private static double getFavorableLabelProbability(Predicate<PredictionInput> groupSelector, List<PredictionInput> samples, PredictionProvider model, Output favorableOutput) throws ExecutionException, InterruptedException {
        String outputName = favorableOutput.getName();
        Value outputValue = favorableOutput.getValue();
        List<PredictionOutput> selectedOutputs = getSelectedPredictionOutputs(groupSelector, samples, model);
        long numFavorableSelected = selectedOutputs.stream()
                .map(po -> po.getByName(outputName)
                        .orElseThrow(() -> new NoSuchElementException("prediction has no output named '" + outputName + "'")))
                .filter(o -> o.getValue().equals(outputValue))
                .count();
        return (double) numFavorableSelected / selectedOutputs.size();
    }

    /** Predicts, in a single call, the subset of {@code samples} accepted by {@code groupSelector}. */
    private static List<PredictionOutput> getSelectedPredictionOutputs(Predicate<PredictionInput> groupSelector, List<PredictionInput> samples, PredictionProvider model) throws InterruptedException, ExecutionException {
        List<PredictionInput> selected = samples.stream().filter(groupSelector).collect(Collectors.toList());
        return model.predictAsync(selected).get();
    }

    /**
     * Computes the average odds difference between the unprivileged and privileged parts of
     * {@code dataset}. It is the mean of the true-positive-rate difference and the
     * false-positive-rate difference, each taken as unprivileged minus privileged.
     *
     * <p>An example is actually positive when {@code outputSelector} accepts its recorded output.
     * It is predicted positive when {@code outputSelector} accepts the model's prediction for it.
     * Every rate denominator is smoothed by a tiny epsilon, so a group with no examples of a class
     * contributes a rate of 0 rather than NaN.
     *
     * @param inputSelector predicate selecting the privileged inputs; its negation selects the
     *     unprivileged ones
     * @param outputSelector predicate defining the positive class on an output
     * @param dataset examples with recorded outputs
     * @param model model whose predictions are evaluated
     * @return the average odds difference
     * @throws ExecutionException if an asynchronous prediction completes exceptionally
     * @throws InterruptedException if interrupted while waiting for a prediction
     * @throws IllegalArgumentException if the model returns a different number of predictions than
     *     it was given inputs
     */
    public static double groupAverageOddsDifference(Predicate<PredictionInput> inputSelector, Predicate<PredictionOutput> outputSelector, Dataset dataset, PredictionProvider model) throws ExecutionException, InterruptedException {
        ConfusionCounts privileged = predictAndCount(dataset.filterByInput(inputSelector), model, outputSelector);
        ConfusionCounts unprivileged = predictAndCount(dataset.filterByInput(inputSelector.negate()), model, outputSelector);
        double tprDifference = unprivileged.truePositiveRate() - privileged.truePositiveRate();
        double fprDifference = unprivileged.falsePositiveRate() - privileged.falsePositiveRate();
        return tprDifference / 2.0 + fprDifference / 2.0;
    }

    /**
     * Computes the average predictive value difference between the unprivileged and privileged
     * parts of {@code dataset}. It is the mean of the positive-predictive-value difference
     * ({@code tp / (tp + fp)}) and the false-omission-rate difference ({@code fn / (fn + tn)}),
     * each taken as unprivileged minus privileged.
     *
     * <p>Positives are defined as in
     * {@link #groupAverageOddsDifference(Predicate, Predicate, Dataset, PredictionProvider)}. Every
     * rate denominator is smoothed by a tiny epsilon, so an empty class yields 0 rather than NaN.
     *
     * @param inputSelector predicate selecting the privileged inputs; its negation selects the
     *     unprivileged ones
     * @param outputSelector predicate defining the positive class on an output
     * @param dataset examples with recorded outputs
     * @param model model whose predictions are evaluated
     * @return the average predictive value difference
     * @throws ExecutionException if an asynchronous prediction completes exceptionally
     * @throws InterruptedException if interrupted while waiting for a prediction
     * @throws IllegalArgumentException if the model returns a different number of predictions than
     *     it was given inputs
     */
    public static double groupAveragePredictiveValueDifference(Predicate<PredictionInput> inputSelector, Predicate<PredictionOutput> outputSelector, Dataset dataset, PredictionProvider model) throws ExecutionException, InterruptedException {
        ConfusionCounts privileged = predictAndCount(dataset.filterByInput(inputSelector), model, outputSelector);
        ConfusionCounts unprivileged = predictAndCount(dataset.filterByInput(inputSelector.negate()), model, outputSelector);
        double ppvDifference = unprivileged.positivePredictiveValue() - privileged.positivePredictiveValue();
        double forDifference = unprivileged.falseOmissionRate() - privileged.falseOmissionRate();
        return ppvDifference / 2.0 + forDifference / 2.0;
    }

    /** Predicts the inputs of {@code dataset} and tallies the predictions against its recorded outputs. */
    private static ConfusionCounts predictAndCount(Dataset dataset, PredictionProvider model, Predicate<PredictionOutput> outputSelector) throws ExecutionException, InterruptedException {
        return countMatchingOutputSelector(dataset, model.predictAsync(dataset.getInputs()).get(), outputSelector);
    }

    /**
     * Builds the confusion-matrix counts of {@code predictionOutputs} against the recorded outputs
     * of {@code dataset}, pairing them by position.
     *
     * @throws IllegalArgumentException if the two sequences have different sizes
     */
    private static ConfusionCounts countMatchingOutputSelector(Dataset dataset, List<PredictionOutput> predictionOutputs, Predicate<PredictionOutput> outputSelector) {
        // Explicit check: a bare assert is disabled by default and would let a size mismatch
        // either fail with an opaque index error or silently ignore extra predictions.
        if (predictionOutputs.size() != dataset.getData().size()) {
            throw new IllegalArgumentException("dataset and predictions must have same size");
        }
        int tp = 0;
        int tn = 0;
        int fp = 0;
        int fn = 0;
        int i = 0;
        for (Prediction trainingExample : dataset.getData()) {
            boolean actualPositive = outputSelector.test(trainingExample.getOutput());
            boolean predictedPositive = outputSelector.test(predictionOutputs.get(i));
            if (actualPositive) {
                if (predictedPositive) {
                    ++tp;
                } else {
                    ++fn;
                }
            } else if (predictedPositive) {
                ++fp;
            } else {
                ++tn;
            }
            ++i;
        }
        return new ConfusionCounts(tp, tn, fp, fn);
    }

    /** Returns {@code numerator / (denominator + EPSILON)}; 0 when both arguments are 0. */
    private static double smoothedRatio(double numerator, double denominator) {
        return numerator / (denominator + EPSILON);
    }

    /** Immutable confusion-matrix counts for one group, relative to a caller-defined positive class. */
    private static final class ConfusionCounts {
        private final int tp;
        private final int tn;
        private final int fp;
        private final int fn;

        ConfusionCounts(int tp, int tn, int fp, int fn) {
            this.tp = tp;
            this.tn = tn;
            this.fp = fp;
            this.fn = fn;
        }

        /** tp / (tp + fn), smoothed. */
        double truePositiveRate() {
            return smoothedRatio(tp, tp + fn);
        }

        /** fp / (fp + tn), smoothed. */
        double falsePositiveRate() {
            return smoothedRatio(fp, fp + tn);
        }

        /** tp / (tp + fp), smoothed. */
        double positivePredictiveValue() {
            return smoothedRatio(tp, tp + fp);
        }

        /** fn / (fn + tn), smoothed. */
        double falseOmissionRate() {
            return smoothedRatio(fn, fn + tn);
        }
    }
}

