/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.lime;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.LocalExplanationException;
import org.kie.kogito.explainability.local.lime.DatasetEncoder;
import org.kie.kogito.explainability.local.lime.DatasetNotSeparableException;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeInputs;
import org.kie.kogito.explainability.local.lime.SampleWeighter;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.DataUtils;
import org.kie.kogito.explainability.utils.LinearModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * LIME-style local explainer.
 *
 * <p>For a given {@link Prediction} it generates perturbed copies of the input, queries the model on
 * them, encodes the resulting dataset with a {@link DatasetEncoder} and fits a sample-weighted
 * {@link LinearModel} for every output. The fitted weights become the {@link FeatureImportance}s of a
 * {@link Saliency}; the returned map is keyed by output name.
 *
 * <p>If the perturbed dataset for some output is not separable enough, the whole explanation is retried
 * up to {@link LimeConfig#getNoOfRetries()} times. When {@link LimeConfig#adaptDatasetVariance()} is
 * enabled, each retry perturbs more features and draws more samples. On the last attempt a poorly
 * separable dataset is accepted (with a warning) instead of failing.
 */
public class LimeExplainer implements LocalExplainer<Map<String, Saliency>> {
    private static final Logger LOGGER = LoggerFactory.getLogger(LimeExplainer.class);
    private final LimeConfig limeConfig;

    /** Creates an explainer with a default {@link LimeConfig}. */
    public LimeExplainer() {
        this(new LimeConfig());
    }

    /**
     * Creates an explainer with the given configuration.
     *
     * @param limeConfig sampling, retry and separability settings
     */
    public LimeExplainer(LimeConfig limeConfig) {
        this.limeConfig = limeConfig;
    }

    /** Returns the configuration this explainer was built with. */
    public LimeConfig getLimeConfig() {
        return this.limeConfig;
    }

    /**
     * Explains a single prediction.
     *
     * @param prediction the prediction (input and outputs) to explain
     * @param model the model used to score the perturbed inputs
     * @return a future completing with one {@link Saliency} per output, keyed by output name
     * @throws LocalExplanationException if the input has no features or linearization yields none
     */
    @Override
    public CompletableFuture<Map<String, Saliency>> explainAsync(Prediction prediction, PredictionProvider model) {
        PredictionInput originalInput = prediction.getInput();
        if (originalInput.getFeatures().isEmpty()) {
            throw new LocalExplanationException("cannot explain a prediction whose input is empty");
        }
        List<PredictionInput> linearizedInputs = DataUtils.linearizeInputs(List.of(originalInput));
        // guard get(0): an empty linearization result is reported like an empty feature list
        if (linearizedInputs.isEmpty()) {
            throw new LocalExplanationException("input features linearization failed");
        }
        List<Feature> linearizedTargetInputFeatures = linearizedInputs.get(0).getFeatures();
        if (linearizedTargetInputFeatures.isEmpty()) {
            throw new LocalExplanationException("input features linearization failed");
        }
        List<Output> actualOutputs = prediction.getOutput().getOutputs();
        return explainRetryCycle(model, originalInput, linearizedTargetInputFeatures, actualOutputs,
                limeConfig.getNoOfRetries(), limeConfig.getNoOfSamples(), limeConfig.getPerturbationContext());
    }

    /**
     * Runs one explanation attempt and, if a dataset turns out not to be separable, schedules another.
     *
     * @param model the model used to score perturbed inputs
     * @param originalInput the (non-linearized) input that gets perturbed
     * @param linearizedTargetInputFeatures the linearized features importances are computed for
     * @param actualOutputs the outputs of the original prediction
     * @param noOfRetries remaining retries; when 0 the separability check is not enforced
     * @param noOfSamples number of perturbed samples to generate for this attempt
     * @param perturbationContext randomness and number of features to perturb per sample
     * @return a future completing with one {@link Saliency} per output, keyed by output name
     */
    protected CompletableFuture<Map<String, Saliency>> explainRetryCycle(PredictionProvider model, PredictionInput originalInput, List<Feature> linearizedTargetInputFeatures, List<Output> actualOutputs, int noOfRetries, int noOfSamples, PerturbationContext perturbationContext) {
        // use this attempt's sample count (it grows across retries when adaptDatasetVariance is on)
        List<PredictionInput> perturbedInputs = getPerturbedInputs(originalInput.getFeatures(), perturbationContext, noOfSamples);
        return model.predictAsync(perturbedInputs).thenCompose(predictionOutputs -> {
            try {
                // only enforce separability while there is still a retry left
                boolean strict = noOfRetries > 0;
                List<LimeInputs> limeInputsList = getLimeInputs(linearizedTargetInputFeatures, actualOutputs,
                        perturbedInputs, predictionOutputs, strict);
                return CompletableFuture.completedFuture(
                        getSaliencies(linearizedTargetInputFeatures, actualOutputs, limeInputsList));
            } catch (DatasetNotSeparableException e) {
                if (noOfRetries <= 0) {
                    throw e;
                }
                int newNoOfSamples;
                PerturbationContext newPerturbationContext;
                if (limeConfig.adaptDatasetVariance()) {
                    int featureCount = linearizedTargetInputFeatures.size();
                    int nextPerturbationSize = Math.max(perturbationContext.getNoOfPerturbations() + 1,
                            featureCount / noOfRetries);
                    // never perturb every feature
                    // NOTE(review): with a single feature this cap yields 0 perturbations — confirm intended
                    nextPerturbationSize = Math.min(featureCount - 1, nextPerturbationSize);
                    newPerturbationContext = new PerturbationContext(perturbationContext.getRandom(), nextPerturbationSize);
                    // guard the divisor: subclasses may call this with noOfRetries > 0 while the config has 0 retries
                    newNoOfSamples = noOfSamples + limeConfig.getNoOfSamples() / Math.max(1, limeConfig.getNoOfRetries());
                } else {
                    newPerturbationContext = perturbationContext;
                    newNoOfSamples = noOfSamples;
                }
                return explainRetryCycle(model, originalInput, linearizedTargetInputFeatures, actualOutputs,
                        noOfRetries - 1, newNoOfSamples, newPerturbationContext);
            }
        });
    }

    /**
     * Builds one {@link LimeInputs} per original output, in output order.
     *
     * @throws DatasetNotSeparableException in strict mode, for the first output whose data is not separable
     */
    private List<LimeInputs> getLimeInputs(List<Feature> linearizedTargetInputFeatures, List<Output> actualOutputs, List<PredictionInput> perturbedInputs, List<PredictionOutput> predictionOutputs, boolean strict) {
        List<LimeInputs> limeInputsList = new ArrayList<>(actualOutputs.size());
        for (int o = 0; o < actualOutputs.size(); ++o) {
            limeInputsList.add(prepareInputs(perturbedInputs, predictionOutputs, linearizedTargetInputFeatures,
                    o, actualOutputs.get(o), strict));
        }
        return limeInputsList;
    }

    /**
     * Fits a surrogate model per output and collects the resulting saliencies.
     *
     * @param limeInputsList must be index-aligned with {@code actualOutputs}
     * @return saliencies keyed by output name (a later output with the same name overwrites an earlier one)
     */
    private Map<String, Saliency> getSaliencies(List<Feature> linearizedTargetInputFeatures, List<Output> actualOutputs, List<LimeInputs> limeInputsList) {
        Map<String, Saliency> result = new HashMap<>();
        for (int o = 0; o < actualOutputs.size(); ++o) {
            Output originalOutput = actualOutputs.get(o);
            getSaliency(linearizedTargetInputFeatures, result, limeInputsList.get(o), originalOutput);
            LOGGER.debug("weights set for output {}", originalOutput);
        }
        return result;
    }

    /**
     * Encodes the perturbed dataset for one output, fits a weighted linear model on it and stores the
     * resulting {@link Saliency} into {@code result} under the output name. If the fit loss is NaN,
     * the saliency is stored with an empty importance list.
     */
    private void getSaliency(List<Feature> linearizedTargetInputFeatures, Map<String, Saliency> result, LimeInputs limeInputs, Output originalOutput) {
        List<FeatureImportance> featureImportanceList = new ArrayList<>(linearizedTargetInputFeatures.size());
        DatasetEncoder datasetEncoder = new DatasetEncoder(limeInputs.getPerturbedInputs(),
                limeInputs.getPerturbedOutputs(), linearizedTargetInputFeatures, originalOutput);
        List<Pair<double[], Double>> trainingSet = datasetEncoder.getEncodedTrainingSet();
        double[] sampleWeights = SampleWeighter.getSampleWeights(linearizedTargetInputFeatures, trainingSet);
        LinearModel linearModel = new LinearModel(linearizedTargetInputFeatures.size(), limeInputs.isClassification());
        double loss = linearModel.fit(trainingSet, sampleWeights);
        if (!Double.isNaN(loss)) {
            // weights are index-aligned with the linearized features
            double[] weights = linearModel.getWeights();
            int i = 0;
            for (Feature linearizedFeature : linearizedTargetInputFeatures) {
                featureImportanceList.add(new FeatureImportance(linearizedFeature, weights[i]));
                ++i;
            }
        }
        result.put(originalOutput.getName(), new Saliency(originalOutput, featureImportanceList));
    }

    /**
     * Prepares the training data for output {@code o}.
     *
     * <p>The dataset is considered separable when the perturbed outputs fall into more than one class
     * and the most populated class holds less than {@link LimeConfig#getSeparableDatasetRatio()} of the
     * samples. The problem is treated as classification when there are exactly two classes.
     *
     * @return inputs for the surrogate model; empty perturbed data when the output value is missing
     * @throws DatasetNotSeparableException if {@code strict} and the dataset is not separable
     */
    private LimeInputs prepareInputs(List<PredictionInput> perturbedInputs, List<PredictionOutput> perturbedOutputs, List<Feature> linearizedTargetInputFeatures, int o, Output currentOutput, boolean strict) {
        Value<?> fv = currentOutput.getValue();
        if (fv == null || fv.getUnderlyingObject() == null) {
            return new LimeInputs(false, linearizedTargetInputFeatures, currentOutput, Collections.emptyList(), Collections.emptyList());
        }
        Map<Double, Long> rawClassesBalance = getClassBalance(perturbedOutputs, fv, o);
        long max = rawClassesBalance.values().stream().max(Long::compareTo).orElse(1L);
        // share of samples that landed in the most populated class
        double separationRatio = (double) max / (double) perturbedInputs.size();
        List<Output> outputs = perturbedOutputs.stream().map(po -> po.getOutputs().get(o)).collect(Collectors.toList());
        boolean classification = rawClassesBalance.size() == 2;
        boolean separable = rawClassesBalance.size() > 1 && separationRatio < limeConfig.getSeparableDatasetRatio();
        if (!separable) {
            if (strict) {
                throw new DatasetNotSeparableException(currentOutput, rawClassesBalance);
            }
            // last attempt: keep going with the data we have, but let the user know
            LOGGER.warn("Using an hardly separable dataset for output '{}' of type '{}' with value '{}' ({})",
                    currentOutput.getName(), currentOutput.getType(), currentOutput.getValue(), rawClassesBalance);
        }
        return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
    }

    /**
     * Counts perturbed samples per encoded class for output {@code finalO}.
     *
     * @return map from encoded class value (see {@link #toClassValue}) to number of samples
     */
    private Map<Double, Long> getClassBalance(List<PredictionOutput> perturbedOutputs, Value<?> fv, int finalO) {
        Map<Double, Long> rawClassesBalance = perturbedOutputs.stream()
                .map(p -> p.getOutputs().get(finalO))
                .map(output -> toClassValue(output, fv))
                .collect(Collectors.groupingBy(Double::doubleValue, Collectors.counting()));
        LOGGER.debug("raw samples per class: {}", rawClassesBalance);
        return rawClassesBalance;
    }

    /**
     * Encodes a perturbed output as a class value: the numeric value itself for {@link Type#NUMBER}
     * outputs, otherwise {@code 1.0} when it matches the original value {@code fv} (by string form,
     * or both null) and {@code 0.0} when it does not.
     */
    private static double toClassValue(Output output, Value<?> fv) {
        if (Type.NUMBER.equals(output.getType())) {
            return output.getValue().asNumber();
        }
        Object underlying = output.getValue().getUnderlyingObject();
        boolean matchesOriginal = (underlying == null && fv.getUnderlyingObject() == null)
                || (underlying != null && output.getValue().asString().equals(fv.asString()));
        return matchesOriginal ? 1.0 : 0.0;
    }

    /**
     * Generates perturbed copies of {@code features}.
     *
     * <p>The sample count is the larger of {@code noOfSamples} and 2^(number of features).
     * NOTE(review): the 2^n floor grows very quickly (20 features already means ~1M samples) — confirm
     * whether this should be a cap (min) rather than a floor.
     *
     * @param noOfSamples requested number of samples for the current attempt
     */
    private List<PredictionInput> getPerturbedInputs(List<Feature> features, PerturbationContext perturbationContext, int noOfSamples) {
        double perturbedDataSize = Math.max((double) noOfSamples, Math.pow(2.0, features.size()));
        List<PredictionInput> perturbedInputs = new ArrayList<>();
        for (int i = 0; i < perturbedDataSize; ++i) {
            perturbedInputs.add(new PredictionInput(DataUtils.perturbFeatures(features, perturbationContext)));
        }
        return perturbedInputs;
    }
}

