/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.global.pdp;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.global.GlobalExplainer;
import org.kie.kogito.explainability.global.pdp.PartialDependencePlotConfig;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureDistribution;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PartialDependenceGraph;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionInputsDataDistribution;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.PredictionProviderMetadata;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Global explainer that builds one {@link PartialDependenceGraph} per (feature, output) pair.
 *
 * <p>For every feature of the data distribution a sorted series of distinct values is sampled.
 * Each of those values is substituted into a set of sampled inputs, the model is queried, and
 * the outputs observed for that value are collapsed into a single y value.
 */
public class PartialDependencePlotExplainer
implements GlobalExplainer<List<PartialDependenceGraph>> {
    private static final Logger LOGGER = LoggerFactory.getLogger(PartialDependencePlotExplainer.class);

    /** Orders sampled values by their numeric form, breaking ties by their string form. */
    private static final Comparator<Value> VALUE_ORDER =
            Comparator.<Value>comparingDouble(Value::asNumber).thenComparing(Value::asString);

    private final PartialDependencePlotConfig config;

    /**
     * Creates an explainer using the given configuration.
     *
     * @param config configuration; its series length drives the number of samples taken
     */
    public PartialDependencePlotExplainer(PartialDependencePlotConfig config) {
        this.config = config;
    }

    /** Creates an explainer using a default-constructed {@link PartialDependencePlotConfig}. */
    public PartialDependencePlotExplainer() {
        this(new PartialDependencePlotConfig());
    }

    /**
     * Builds the graphs from the data distribution and output shape carried by the metadata.
     *
     * @param model the model to query
     * @param metadata supplies the data distribution and the number of outputs
     * @return one graph per (feature, output) pair
     */
    @Override
    public List<PartialDependenceGraph> explainFromMetadata(PredictionProvider model, PredictionProviderMetadata metadata) throws InterruptedException, ExecutionException, TimeoutException {
        int outputSize = metadata.getOutputShape().getOutputs().size();
        return this.explainFromDataDistribution(model, outputSize, metadata.getDataDistribution());
    }

    /**
     * Builds the graphs from existing predictions: their inputs form the data distribution and
     * the number of outputs is read from one of the predictions (0 when there are none).
     *
     * @param model the model to query
     * @param predictions predictions whose inputs are used as the data distribution
     * @return one graph per (feature, output) pair; empty when {@code predictions} is empty
     */
    @Override
    public List<PartialDependenceGraph> explainFromPredictions(PredictionProvider model, Collection<Prediction> predictions) throws InterruptedException, ExecutionException, TimeoutException {
        // findAny() on an empty collection yields an empty Optional, so no separate isEmpty() check is needed.
        int outputSize = predictions.stream()
                .findAny()
                .map(p -> p.getOutput().getOutputs().size())
                .orElse(0);
        List<PredictionInput> inputs = predictions.stream()
                .map(Prediction::getInput)
                .collect(Collectors.toList());
        return this.explainFromDataDistribution(model, outputSize, new PredictionInputsDataDistribution(inputs));
    }

    /**
     * Iterates over every feature distribution and every output index, producing one graph each.
     *
     * @param model the model to query
     * @param outputSize number of outputs to produce graphs for
     * @param dataDistribution source of the sampled inputs and of the per-feature distributions
     * @return the graphs, grouped by feature and, within a feature, ordered by output index
     */
    private List<PartialDependenceGraph> explainFromDataDistribution(PredictionProvider model, int outputSize, DataDistribution dataDistribution) throws InterruptedException, ExecutionException, TimeoutException {
        long start = System.currentTimeMillis();
        int seriesLength = this.config.getSeriesLength();
        List<PartialDependenceGraph> pdps = new ArrayList<>();
        List<FeatureDistribution> featureDistributions = dataDistribution.asFeatureDistributions();
        List<PredictionInput> trainingData = dataDistribution.sample(seriesLength);
        for (FeatureDistribution featureDistribution : featureDistributions) {
            // A single comparator replaces two chained stable sorts (by string, then by number).
            List<Value> xsValues = featureDistribution.sample(seriesLength).stream()
                    .sorted(VALUE_ORDER)
                    .distinct()
                    .collect(Collectors.toList());
            List<Feature> featureXSvalues = xsValues.stream()
                    .map(v -> FeatureFactory.copyOf(featureDistribution.getFeature(), v))
                    .collect(Collectors.toList());
            for (int outputIndex = 0; outputIndex < outputSize; ++outputIndex) {
                pdps.add(this.getPartialDependenceGraph(model, trainingData, xsValues, featureXSvalues, outputIndex));
            }
        }
        long end = System.currentTimeMillis();
        LOGGER.debug("explanation time: {}ms", end - start);
        return pdps;
    }

    /**
     * Builds the graph for a single feature and a single output.
     *
     * @param model the model to query
     * @param trainingData sampled inputs in which the feature value is substituted
     * @param xsValues the x values of the graph
     * @param featureXSvalues the feature under analysis, one copy per x value
     * @param outputIndex index of the output to observe in each prediction
     * @return the graph for the feature/output pair
     * @throws IllegalArgumentException if the model produced no output at all
     */
    private PartialDependenceGraph getPartialDependenceGraph(PredictionProvider model, List<PredictionInput> trainingData, List<Value> xsValues, List<Feature> featureXSvalues, int outputIndex) throws InterruptedException, ExecutionException, TimeoutException {
        Output outputDecision = null;
        Feature feature = null;
        List<Map<Value, Long>> valueCounts = new ArrayList<>(featureXSvalues.size());
        for (int i = 0; i < featureXSvalues.size(); ++i) {
            Feature featureXs = featureXSvalues.get(i);
            if (feature == null) {
                // The graph's feature is a copy of the analysed feature carrying a null value.
                feature = FeatureFactory.copyOf(featureXs, new Value(null));
            }
            // Always reserve the slot for this x value so that counts stay aligned with
            // xsValues even if the model returns no outputs for it.
            valueCounts.add(new HashMap<>());
            List<PredictionInput> predictionInputs = this.prepareInputs(featureXs, trainingData);
            List<PredictionOutput> predictionOutputs = this.getOutputs(model, predictionInputs);
            for (PredictionOutput predictionOutput : predictionOutputs) {
                Output output = predictionOutput.getOutputs().get(outputIndex);
                if (outputDecision == null) {
                    outputDecision = new Output(output.getName(), output.getType());
                }
                this.updateValueCounts(valueCounts, i, output);
            }
        }
        if (outputDecision == null) {
            throw new IllegalArgumentException("cannot produce PDP for null decision");
        }
        List<Value> yValues = this.collapseMarginalImpacts(valueCounts, outputDecision.getType());
        return new PartialDependenceGraph(feature, outputDecision, xsValues, yValues);
    }

    /**
     * Collapses, for each x value, the observed output counts into a single y value.
     *
     * <p>For {@link Type#NUMBER} outputs the y value is the count-weighted sum of the observed
     * numbers divided by the configured series length (0.0 when nothing was observed). For any
     * other type it is the string form of the most frequently observed value (a {@code null}
     * payload when nothing was observed).
     *
     * @param valueCounts per x value, the number of occurrences of each observed output value
     * @param type the type of the output
     * @return one y value per entry of {@code valueCounts}
     */
    private List<Value> collapseMarginalImpacts(List<Map<Value, Long>> valueCounts, Type type) {
        List<Value> yValues = new ArrayList<>(valueCounts.size());
        if (Type.NUMBER.equals(type)) {
            // NOTE(review): the divisor is the configured series length, not the number of
            // outputs actually observed; this is only an average if the two match — confirm.
            int seriesLength = this.config.getSeriesLength();
            for (Map<Value, Long> counts : valueCounts) {
                double mean = counts.entrySet().stream()
                        .mapToDouble(e -> e.getKey().asNumber() * e.getValue() / seriesLength)
                        .sum();
                yValues.add(new Value(mean));
            }
        } else {
            for (Map<Value, Long> counts : valueCounts) {
                long max = 0L;
                String mostFrequent = null;
                for (Map.Entry<Value, Long> entry : counts.entrySet()) {
                    // Strict comparison: on ties the first entry encountered wins.
                    if (entry.getValue() > max) {
                        max = entry.getValue();
                        mostFrequent = entry.getKey().asString();
                    }
                }
                yValues.add(new Value(mostFrequent));
            }
        }
        return yValues;
    }

    /**
     * Records one more occurrence of the output's value for the given x-value index.
     *
     * @param valueCounts per x value, the occurrence counts; grown as needed
     * @param featureValueIndex index of the x value the output was observed for
     * @param output the observed output
     */
    private void updateValueCounts(List<Map<Value, Long>> valueCounts, int featureValueIndex, Output output) {
        // Pad up to the requested index so a count can never land in the slot of another x value.
        while (valueCounts.size() <= featureValueIndex) {
            valueCounts.add(new HashMap<>());
        }
        valueCounts.get(featureValueIndex).merge(output.getValue(), 1L, Long::sum);
    }

    /**
     * Queries the model and waits for the result, bounded by the timeout read from {@link Config}.
     *
     * @param model the model to query
     * @param predictionInputs the inputs to predict
     * @return the outputs returned by the model
     */
    private List<PredictionOutput> getOutputs(PredictionProvider model, List<PredictionInput> predictionInputs) throws InterruptedException, ExecutionException, TimeoutException {
        return model.predictAsync(predictionInputs).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }

    /**
     * Creates one new input per training sample, with the analysed feature's value replaced.
     *
     * @param featureXs the feature (name and value) to substitute
     * @param trainingData the samples to derive the inputs from
     * @return the new inputs, in the order of {@code trainingData}
     */
    private List<PredictionInput> prepareInputs(Feature featureXs, List<PredictionInput> trainingData) {
        List<PredictionInput> predictionInputs = new ArrayList<>(trainingData.size());
        for (PredictionInput trainingSample : trainingData) {
            List<Feature> newFeatures = this.replaceFeatures(featureXs, trainingSample.getFeatures());
            predictionInputs.add(new PredictionInput(newFeatures));
        }
        return predictionInputs;
    }

    /**
     * Copies a list of features, giving every feature whose name equals that of
     * {@code featureXs} the value of {@code featureXs}; composite features are rebuilt
     * recursively from their nested features.
     *
     * @param featureXs the feature (name and value) to substitute
     * @param features the features to copy
     * @return the copied features, in the same order
     */
    private List<Feature> replaceFeatures(Feature featureXs, List<Feature> features) {
        List<Feature> newFeatures = new ArrayList<>(features.size());
        for (Feature f : features) {
            Feature newFeature;
            if (f.getName().equals(featureXs.getName())) {
                newFeature = FeatureFactory.copyOf(f, featureXs.getValue());
            } else if (Type.COMPOSITE == f.getType()) {
                // Assumes a composite feature's underlying object is a list of features — the
                // recursive call below depends on it.
                @SuppressWarnings("unchecked")
                List<Feature> elements = (List<Feature>) f.getValue().getUnderlyingObject();
                newFeature = FeatureFactory.newCompositeFeature(f.getName(), this.replaceFeatures(featureXs, elements));
            } else {
                newFeature = FeatureFactory.copyOf(f, f.getValue());
            }
            newFeatures.add(newFeature);
        }
        return newFeatures;
    }
}

