/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.global.pdp;

import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.global.GlobalExplainer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.FeatureDistribution;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PartialDependenceGraph;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.PredictionProviderMetadata;
import org.kie.kogito.explainability.utils.DataUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Global explainer that builds partial dependence plots (PDPs) for a {@link PredictionProvider}.
 *
 * <p>For each input feature and each model output, the feature is swept over values produced by
 * {@link DataUtils#generateSamples} between the min and max of its {@link FeatureDistribution}. The
 * other features take values from synthetic background data produced by
 * {@link DataUtils#generateData} from each distribution's mean and standard deviation. The plotted
 * value at each sweep point is the mean model output over that background data.
 */
public class PartialDependencePlotExplainer implements GlobalExplainer<List<PartialDependenceGraph>> {
    private static final Logger LOGGER = LoggerFactory.getLogger(PartialDependencePlotExplainer.class);
    private static final int DEFAULT_SERIES_LENGTH = 100;

    /** Number of sweep points per feature, and number of background rows averaged at each point. */
    private final int seriesLength;
    /** Source of randomness for the synthetic background data. */
    private final Random random;

    /**
     * Creates an explainer.
     *
     * @param seriesLength number of sweep points per feature and number of background rows averaged
     *     at each point; must be positive
     * @param random source of randomness used to generate the background data
     * @throws IllegalArgumentException if {@code seriesLength} is not positive
     * @throws NullPointerException if {@code random} is {@code null}
     */
    public PartialDependencePlotExplainer(int seriesLength, Random random) {
        if (seriesLength <= 0) {
            throw new IllegalArgumentException("seriesLength must be positive, got " + seriesLength);
        }
        this.seriesLength = seriesLength;
        this.random = Objects.requireNonNull(random, "random");
    }

    /** Creates an explainer with the default series length (100) and a {@link SecureRandom}. */
    public PartialDependencePlotExplainer() {
        this(DEFAULT_SERIES_LENGTH, new SecureRandom());
    }

    /**
     * Computes one partial dependence graph per (feature, output) pair.
     *
     * <p>Graphs are returned in feature-major order: every output of feature 0, then every output of
     * feature 1, and so on. Each graph pairs the sampled feature values with the mean model output at
     * each value. When an output's numeric value is NaN, its score is used instead.
     *
     * @param model the model to query
     * @param metadata supplies the input/output shapes and the per-feature data distribution
     * @return the partial dependence graphs, in feature-major order
     * @throws InterruptedException if interrupted while waiting for a prediction
     * @throws ExecutionException if a prediction fails
     * @throws TimeoutException if a prediction does not complete within the configured async timeout
     */
    @Override
    public List<PartialDependenceGraph> explain(PredictionProvider model, PredictionProviderMetadata metadata)
            throws InterruptedException, ExecutionException, TimeoutException {
        long start = System.nanoTime();
        List<PartialDependenceGraph> pdps = new ArrayList<>();
        DataDistribution dataDistribution = metadata.getDataDistribution();
        int noOfFeatures = metadata.getInputShape().getFeatures().size();
        int noOfOutputs = metadata.getOutputShape().getOutputs().size();
        List<FeatureDistribution> featureDistributions = dataDistribution.getFeatureDistributions();
        for (int featureIndex = 0; featureIndex < noOfFeatures; featureIndex++) {
            FeatureDistribution featureDistribution = featureDistributions.get(featureIndex);
            for (int outputIndex = 0; outputIndex < noOfOutputs; outputIndex++) {
                // The sweep values and background data are regenerated for every (feature, output)
                // pair, so each graph owns its own arrays.
                double[] featureXSvalues = DataUtils.generateSamples(
                        featureDistribution.getMin(), featureDistribution.getMax(), seriesLength);
                double[][] trainingData = generateDistributions(noOfFeatures, featureDistributions);
                double[] marginalImpacts = new double[featureXSvalues.length];
                for (int i = 0; i < featureXSvalues.length; i++) {
                    List<PredictionInput> predictionInputs =
                            prepareInputs(noOfFeatures, featureIndex, featureXSvalues, trainingData, i);
                    List<PredictionOutput> predictionOutputs = getOutputs(model, predictionInputs);
                    marginalImpacts[i] = meanOutput(predictionOutputs, outputIndex);
                }
                pdps.add(new PartialDependenceGraph(
                        metadata.getInputShape().getFeatures().get(featureIndex), featureXSvalues, marginalImpacts));
            }
        }
        // nanoTime is monotonic, so the elapsed time is unaffected by wall-clock adjustments.
        long elapsedMillis = (System.nanoTime() - start) / 1_000_000L;
        LOGGER.debug("explanation time: {}ms", elapsedMillis);
        return pdps;
    }

    /**
     * Averages output {@code outputIndex} across {@code predictionOutputs}. When an output's numeric
     * value is NaN, its score is used instead.
     *
     * @param predictionOutputs one prediction per background row
     * @param outputIndex index of the model output to average
     * @return the mean, or {@code 0.0} when there are no predictions
     */
    private static double meanOutput(List<PredictionOutput> predictionOutputs, int outputIndex) {
        // Divide by the number of predictions actually returned, so the result is a true mean even
        // if the provider does not return exactly one prediction per input row.
        int count = predictionOutputs.size();
        double mean = 0.0;
        for (PredictionOutput predictionOutput : predictionOutputs) {
            Output output = predictionOutput.getOutputs().get(outputIndex);
            double v = output.getValue().asNumber();
            if (Double.isNaN(v)) {
                v = output.getScore();
            }
            mean += v / count;
        }
        return mean;
    }

    /**
     * Runs a batch of predictions and waits at most the async timeout configured in
     * {@link Config#INSTANCE}. Failures are logged and then rethrown unchanged.
     *
     * @param model the model to query
     * @param predictionInputs the batch to predict
     * @return the model's predictions for the batch
     * @throws InterruptedException if interrupted while waiting
     * @throws ExecutionException if the prediction fails
     * @throws TimeoutException if the prediction does not complete in time
     */
    private List<PredictionOutput> getOutputs(PredictionProvider model, List<PredictionInput> predictionInputs)
            throws InterruptedException, ExecutionException, TimeoutException {
        try {
            return model.predictAsync(predictionInputs)
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            // Passing e as the trailing argument makes SLF4J record the stack trace as well.
            LOGGER.error("Impossible to obtain prediction {}", e.getMessage(), e);
            throw e;
        }
    }

    /**
     * Builds {@code seriesLength} prediction inputs for sweep point {@code i}.
     *
     * <p>In row {@code j}, feature {@code featureIndex} is fixed to {@code featureXSvalues[i]}, and
     * every other feature {@code f} takes {@code trainingData[f][j]}.
     */
    private List<PredictionInput> prepareInputs(int noOfFeatures, int featureIndex, double[] featureXSvalues,
            double[][] trainingData, int i) {
        List<PredictionInput> predictionInputs = new ArrayList<>(seriesLength);
        double sweptValue = featureXSvalues[i];
        for (int j = 0; j < seriesLength; j++) {
            // Allocate a fresh buffer per row so no two inputs can alias the same array.
            double[] inputs = new double[noOfFeatures];
            for (int f = 0; f < noOfFeatures; f++) {
                inputs[f] = f == featureIndex ? sweptValue : trainingData[f][j];
            }
            predictionInputs.add(new PredictionInput(DataUtils.doublesToFeatures(inputs)));
        }
        return predictionInputs;
    }

    /**
     * Generates synthetic background data for every feature by calling
     * {@link DataUtils#generateData} with each distribution's mean and standard deviation.
     *
     * @return a matrix indexed as {@code [feature][row]}; each row presumably holds
     *     {@code seriesLength} values (the length is whatever {@code generateData} returns)
     */
    private double[][] generateDistributions(int noOfFeatures, List<FeatureDistribution> featureDistributions) {
        // Rows are filled by generateData, so only the outer array is allocated here.
        double[][] trainingData = new double[noOfFeatures][];
        for (int i = 0; i < noOfFeatures; i++) {
            FeatureDistribution distribution = featureDistributions.get(i);
            trainingData[i] = DataUtils.generateData(
                    distribution.getMean(), distribution.getStdDev(), seriesLength, random);
        }
        return trainingData;
    }
}

