/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.lime;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import org.kie.kogito.explainability.local.lime.HighScoreNumericFeatureZones;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInputsDataDistribution;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.utils.DataUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HighScoreNumericFeatureZonesProvider {
    private static final Logger LOGGER = LoggerFactory.getLogger(HighScoreNumericFeatureZonesProvider.class);

    private HighScoreNumericFeatureZonesProvider() {
    }

    public static Map<String, HighScoreNumericFeatureZones> getHighScoreFeatureZones(DataDistribution dataDistribution, PredictionProvider predictionProvider, List<Feature> features, int maxNoOfSamples) {
        double min;
        double max;
        HashMap<String, HighScoreNumericFeatureZones> numericFeatureZonesMap = new HashMap<String, HighScoreNumericFeatureZones>();
        List<Object> scoreSortedPredictions = new ArrayList<Prediction>();
        try {
            scoreSortedPredictions.addAll(DataUtils.getScoreSortedPredictions(predictionProvider, new PredictionInputsDataDistribution(dataDistribution.sample(maxNoOfSamples))));
        }
        catch (ExecutionException e) {
            LOGGER.error("Could not sort predictions by score {}", (Object)e.getMessage());
        }
        catch (InterruptedException e) {
            LOGGER.error("Interrupted while waiting for sorting predictions by score {}", (Object)e.getMessage());
            Thread.currentThread().interrupt();
        }
        catch (TimeoutException e) {
            LOGGER.error("Timed out while waiting for sorting predictions by score", (Throwable)e);
        }
        if (!scoreSortedPredictions.isEmpty() && (max = ((Prediction)scoreSortedPredictions.get(0)).getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum()) != (min = ((Prediction)scoreSortedPredictions.get(scoreSortedPredictions.size() - 1)).getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum())) {
            double threshold = scoreSortedPredictions.stream().map(p -> p.getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum()).mapToDouble(d -> d).average().orElse((max + min) / 2.0);
            scoreSortedPredictions = scoreSortedPredictions.stream().filter(p -> p.getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum() > threshold).collect(Collectors.toList());
            for (int j = 0; j < features.size(); ++j) {
                Feature feature = features.get(j);
                if (!Type.NUMBER.equals((Object)feature.getType())) continue;
                int finalJ = j;
                List topValues = scoreSortedPredictions.stream().map(prediction -> prediction.getInput().getFeatures().get(finalJ).getValue().asNumber()).distinct().collect(Collectors.toList());
                double[] highScoreFeaturePoints = topValues.stream().flatMapToDouble(DoubleStream::of).toArray();
                double center = DataUtils.getMean(highScoreFeaturePoints);
                double tolerance = DataUtils.getStdDev(highScoreFeaturePoints, center) / 2.0;
                HighScoreNumericFeatureZones highScoreNumericFeatureZones = new HighScoreNumericFeatureZones(highScoreFeaturePoints, tolerance);
                numericFeatureZonesMap.put(feature.getName(), highScoreNumericFeatureZones);
            }
        }
        return numericFeatureZonesMap;
    }
}

