/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.lime;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.LocalExplanationException;
import org.kie.kogito.explainability.local.lime.DatasetEncoder;
import org.kie.kogito.explainability.local.lime.DatasetNotSeparableException;
import org.kie.kogito.explainability.local.lime.SampleWeighter;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.DataUtils;
import org.kie.kogito.explainability.utils.LinearModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LimeExplainer
implements LocalExplainer<Saliency> {
    private static final Logger LOGGER = LoggerFactory.getLogger(LimeExplainer.class);
    private static final double SEPARABLE_DATASET_RATIO = 0.99;
    private final int noOfSamples;
    private final int noOfPerturbations;
    private final int noOfRetries;

    public LimeExplainer(int noOfSamples, int noOfPerturbations, int noOfRetries) {
        this.noOfSamples = noOfSamples;
        this.noOfPerturbations = noOfPerturbations;
        this.noOfRetries = noOfRetries;
    }

    public LimeExplainer(int noOfSamples, int noOfPerturbations) {
        this.noOfSamples = noOfSamples;
        this.noOfPerturbations = noOfPerturbations;
        this.noOfRetries = 3;
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    @Override
    public Saliency explain(Prediction prediction, PredictionProvider model) {
        long start = System.currentTimeMillis();
        LinkedList<FeatureImportance> saliencies = new LinkedList<FeatureImportance>();
        PredictionInput originalInput = prediction.getInput();
        List<Feature> inputFeatures = originalInput.getFeatures();
        if (inputFeatures.size() <= 0) throw new LocalExplanationException("cannot explain a prediction whose input is empty");
        List<PredictionInput> linearizedInputs = DataUtils.linearizeInputs(List.of(originalInput));
        if (linearizedInputs.size() <= 0) throw new LocalExplanationException("input features linearization failed");
        PredictionInput targetInput = linearizedInputs.get(0);
        List<Feature> linearizedTargetInputFeatures = targetInput.getFeatures();
        List<Output> actualOutputs = prediction.getOutput().getOutputs();
        int noOfInputFeatures = inputFeatures.size();
        int noOfOutputFeatures = linearizedTargetInputFeatures.size();
        double[] weights = new double[noOfOutputFeatures];
        for (int o = 0; o < actualOutputs.size(); ++o) {
            boolean separableDataset = false;
            LinkedList<PredictionInput> trainingInputs = new LinkedList<PredictionInput>();
            LinkedList<PredictionOutput> trainingOutputs = new LinkedList<PredictionOutput>();
            Output currentOutput = actualOutputs.get(o);
            if (currentOutput.getValue() != null && currentOutput.getValue().getUnderlyingObject() != null) {
                Map<Double, Long> rawClassesBalance = new HashMap<Double, Long>();
                boolean classification = false;
                for (int tries = this.noOfRetries; tries > 0; --tries) {
                    Long max;
                    List<PredictionInput> perturbedInputs = this.getPerturbedInputs(originalInput, noOfInputFeatures);
                    List<PredictionOutput> perturbedOutputs = model.predict(perturbedInputs);
                    Value fv = currentOutput.getValue();
                    int finalO = o;
                    rawClassesBalance = perturbedOutputs.stream().map(p -> p.getOutputs().get(finalO)).map(output -> Type.NUMBER.equals((Object)output.getType()) ? output.getValue().asNumber() : (output.getValue().getUnderlyingObject() == null && fv.getUnderlyingObject() == null || output.getValue().getUnderlyingObject() != null && output.getValue().asString().equals(fv.asString()) ? 1.0 : 0.0)).collect(Collectors.groupingBy(Double::doubleValue, Collectors.counting()));
                    LOGGER.debug("raw samples per class: {}", (Object)rawClassesBalance);
                    if (rawClassesBalance.size() <= 1 || !((double)(max = rawClassesBalance.values().stream().max(Long::compareTo).orElse(1L)).longValue() / (double)perturbedInputs.size() < 0.99)) continue;
                    separableDataset = true;
                    classification = rawClassesBalance.size() == 2;
                    trainingInputs.addAll((Collection<PredictionInput>)perturbedInputs);
                    trainingOutputs.addAll(perturbedOutputs);
                    break;
                }
                if (!separableDataset) {
                    throw new DatasetNotSeparableException(currentOutput, rawClassesBalance);
                }
                LinkedList<Output> predictedOutputs = new LinkedList<Output>();
                for (PredictionOutput trainingOutput : trainingOutputs) {
                    Output output2 = trainingOutput.getOutputs().get(o);
                    predictedOutputs.add(output2);
                }
                Output originalOutput = prediction.getOutput().getOutputs().get(o);
                DatasetEncoder datasetEncoder = new DatasetEncoder(trainingInputs, predictedOutputs, targetInput, originalOutput);
                List<Pair<double[], Double>> trainingSet = datasetEncoder.getEncodedTrainingSet();
                double[] sampleWeights = SampleWeighter.getSampleWeights(targetInput, trainingSet);
                LinearModel linearModel = new LinearModel(linearizedTargetInputFeatures.size(), classification);
                double loss = linearModel.fit(trainingSet, sampleWeights);
                if (Double.isNaN(loss)) continue;
                weights = Arrays.stream(linearModel.getWeights()).map(x -> x / (double)actualOutputs.size()).toArray();
                LOGGER.debug("weights updated for output {}", (Object)currentOutput);
                continue;
            }
            LOGGER.debug("skipping explanation of empty output {}", (Object)currentOutput);
        }
        for (int i = 0; i < weights.length; ++i) {
            FeatureImportance featureImportance = new FeatureImportance(linearizedTargetInputFeatures.get(i), weights[i]);
            saliencies.add(featureImportance);
        }
        long end = System.currentTimeMillis();
        LOGGER.debug("explanation time: {}ms", (Object)(end - start));
        return new Saliency(saliencies);
    }

    private List<PredictionInput> getPerturbedInputs(PredictionInput predictionInput, int noOfFeatures) {
        LinkedList<PredictionInput> perturbedInputs = new LinkedList<PredictionInput>();
        double perturbedDataSize = Math.max((double)this.noOfSamples, Math.pow(2.0, noOfFeatures));
        int i = 0;
        while ((double)i < perturbedDataSize) {
            perturbedInputs.add(DataUtils.perturbFeatures(predictionInput, this.noOfPerturbations));
            ++i;
        }
        return perturbedInputs;
    }
}

