/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.local.lime;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.kie.kogito.explainability.local.LocalExplanationException;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.utils.DataUtils;

/**
 * Encodes a set of perturbed inputs and their predicted outputs into a numeric training set for
 * the LIME explainer.
 *
 * <p>Every feature of the target input contributes one column (text features contribute one
 * column per space-separated word of the original text). Row {@code i} of the training set holds
 * the encoded values of the {@code i}-th perturbed input, paired with the encoded {@code i}-th
 * predicted output.
 */
class DatasetEncoder {
    /**
     * Maximum absolute difference between a sample's kernel value and the target's kernel value for
     * a numeric feature to be encoded as 1.0 ("same cluster as the target").
     */
    private static final double CLUSTER_THRESHOLD = 0.001;
    private final List<PredictionInput> perturbedInputs;
    private final List<Output> predictedOutputs;
    private final PredictionInput targetInput;
    private final Output originalOutput;

    /**
     * @param perturbedInputs  the perturbed samples (linearized before encoding)
     * @param perturbedOutputs the model outputs for the perturbed samples, in the same order
     * @param targetInput      the input being explained; its features drive the column layout
     * @param targetOutput     the model output for {@code targetInput}
     */
    DatasetEncoder(List<PredictionInput> perturbedInputs, List<Output> perturbedOutputs, PredictionInput targetInput, Output targetOutput) {
        this.perturbedInputs = perturbedInputs;
        this.predictedOutputs = perturbedOutputs;
        this.targetInput = targetInput;
        this.originalOutput = targetOutput;
    }

    /**
     * Builds the encoded training set: one {@code (features, target)} pair per predicted output.
     *
     * @return the encoded samples; an empty list when there are no (linearized) perturbed inputs, no
     *         predicted outputs, the target input has no features, or the original output is {@code null}
     */
    List<Pair<double[], Double>> getEncodedTrainingSet() {
        List<PredictionInput> flatInputs = DataUtils.linearizeInputs(this.perturbedInputs);
        if (flatInputs.isEmpty() || this.predictedOutputs.isEmpty() || this.targetInput.getFeatures().isEmpty() || this.originalOutput == null) {
            return new ArrayList<>();
        }
        // Columns are ArrayLists, so the per-row random access below is O(1).
        List<List<Double>> columnData = this.getColumnData(flatInputs);
        List<Pair<double[], Double>> trainingSet = new ArrayList<>(this.predictedOutputs.size());
        int row = 0;
        for (Output output : this.predictedOutputs) {
            double[] x = new double[columnData.size()];
            for (int c = 0; c < x.length; c++) {
                x[c] = columnData.get(c).get(row);
            }
            Double y = this.encodeOutput(output);
            trainingSet.add(new ImmutablePair<double[], Double>(x, y));
            row++;
        }
        return trainingSet;
    }

    /**
     * Encodes one predicted output as the target value of its training row: NUMBER and BOOLEAN
     * outputs (judged by the original output's type) use their numeric value, any other type is 1.0
     * when its underlying object equals the original output's (null == null counts as equal) and 0.0
     * otherwise.
     */
    private double encodeOutput(Output output) {
        Type outputType = this.originalOutput.getType();
        if (Type.NUMBER.equals(outputType) || Type.BOOLEAN.equals(outputType)) {
            return output.getValue().asNumber();
        }
        Object originalObject = this.originalOutput.getValue().getUnderlyingObject();
        Object outputObject = output.getValue().getUnderlyingObject();
        return Objects.equals(originalObject, outputObject) ? 1.0 : 0.0;
    }

    /**
     * Encodes every feature of the target input into one or more columns, each holding one value per
     * input in {@code flatInputs} (in order).
     *
     * @throws LocalExplanationException if a feature type has no encoding
     */
    private List<List<Double>> getColumnData(List<PredictionInput> flatInputs) {
        List<Feature> targetFeatures = this.targetInput.getFeatures();
        List<List<Double>> columnData = new ArrayList<>();
        for (int t = 0; t < targetFeatures.size(); t++) {
            Feature originalFeature = targetFeatures.get(t);
            switch (originalFeature.getType()) {
                case NUMBER:
                    encodeNumbers(flatInputs, this.targetInput, columnData, t);
                    break;
                case TEXT:
                    encodeText(flatInputs, columnData, originalFeature);
                    break;
                case CATEGORICAL:
                case BINARY:
                case TIME:
                case URI:
                case DURATION:
                case VECTOR:
                case CURRENCY:
                case UNDEFINED:
                    encodeEquals(flatInputs, columnData, t, originalFeature);
                    break;
                case BOOLEAN:
                    encodeBooleans(flatInputs, columnData, t);
                    break;
                default:
                    throw new LocalExplanationException("could not encode features of type " + originalFeature.getType());
            }
        }
        return columnData;
    }

    /**
     * Encodes the numeric feature at index {@code t} by feature scaling plus kernel-based clustering.
     *
     * <p>Values are min-max scaled over the perturbed values together with the original value, passed
     * through {@code DataUtils.gaussianKernel(x, 0, 1)}, and encoded as 1.0 when within
     * {@link #CLUSTER_THRESHOLD} of the original value's kernel output, 0.0 otherwise. The original
     * value only participates in the scaling; the resulting column has one entry per perturbed input.
     */
    private static void encodeNumbers(List<PredictionInput> predictionInputs, PredictionInput originalInputs, List<List<Double>> columnData, int t) {
        int size = predictionInputs.size();
        double[] doubles = new double[size + 1];
        int i = 0;
        for (PredictionInput pi : predictionInputs) {
            doubles[i++] = pi.getFeatures().get(t).getValue().asNumber();
        }
        double originalValue = originalInputs.getFeatures().get(t).getValue().asNumber();
        doubles[size] = originalValue;
        double min = DoubleStream.of(doubles).min().getAsDouble();
        double max = DoubleStream.of(doubles).max().getAsDouble();
        double range = max - min;
        // Scale the original exactly like the samples so a constant column (range == 0) still matches.
        double threshold = DataUtils.gaussianKernel(minMaxScale(originalValue, min, range), 0.0, 1.0);
        List<Double> featureValues = new ArrayList<>(size);
        for (int j = 0; j < size; j++) {
            double kernel = DataUtils.gaussianKernel(minMaxScale(doubles[j], min, range), 0.0, 1.0);
            // Symmetric distance: a one-sided "kernel - threshold < eps" would mark every sample
            // with a smaller kernel value as belonging to the target's cluster.
            featureValues.add(Math.abs(kernel - threshold) < CLUSTER_THRESHOLD ? 1.0 : 0.0);
        }
        columnData.add(featureValues);
    }

    /** Min-max scales {@code value}; an undefined result (e.g. zero {@code range}) maps to 1.0. */
    private static double minMaxScale(double value, double min, double range) {
        double scaled = (value - min) / range;
        return Double.isNaN(scaled) ? 1.0 : scaled;
    }

    /**
     * Encodes a text feature as one column per space-separated word of the original text: 1.0 when
     * the perturbed input's same-named feature contains that word, 0.0 otherwise (including when the
     * perturbed input has no feature with that name). Duplicate original words yield duplicate columns.
     */
    private static void encodeText(List<PredictionInput> predictionInputs, List<List<Double>> columnData, Feature originalFeature) {
        String originalString = originalFeature.getValue().asString();
        String featureName = originalFeature.getName();
        // Split each perturbed text once rather than once per original word.
        List<Set<String>> perturbedWordSets = new ArrayList<>(predictionInputs.size());
        for (PredictionInput pi : predictionInputs) {
            perturbedWordSets.add(wordsOf(pi, featureName));
        }
        for (String word : originalString.split(" ")) {
            List<Double> featureValues = new ArrayList<>(perturbedWordSets.size());
            for (Set<String> words : perturbedWordSets) {
                featureValues.add(words.contains(word) ? 1.0 : 0.0);
            }
            columnData.add(featureValues);
        }
    }

    /** Returns the space-separated words of the first feature named {@code featureName}, or an empty set. */
    private static Set<String> wordsOf(PredictionInput input, String featureName) {
        for (Feature feature : input.getFeatures()) {
            if (feature.getName().equals(featureName)) {
                return new HashSet<>(Arrays.asList(feature.getValue().asString().split(" ")));
            }
        }
        return Collections.emptySet();
    }

    /**
     * Encodes the feature at index {@code t} as 1.0 when its underlying object equals the original
     * feature's (null-safe; null == null counts as equal) and 0.0 otherwise.
     */
    private static void encodeEquals(List<PredictionInput> predictionInputs, List<List<Double>> columnData, int t, Feature originalFeature) {
        Object originalObject = originalFeature.getValue().getUnderlyingObject();
        List<Double> featureValues = new ArrayList<>(predictionInputs.size());
        for (PredictionInput pi : predictionInputs) {
            Object perturbedObject = pi.getFeatures().get(t).getValue().getUnderlyingObject();
            featureValues.add(Objects.equals(originalObject, perturbedObject) ? 1.0 : 0.0);
        }
        columnData.add(featureValues);
    }

    /** Encodes the boolean feature at index {@code t} using its value's numeric form. */
    private static void encodeBooleans(List<PredictionInput> predictionInputs, List<List<Double>> columnData, int t) {
        List<Double> featureValues = new ArrayList<>(predictionInputs.size());
        for (PredictionInput pi : predictionInputs) {
            featureValues.add(pi.getFeatures().get(t).getValue().asNumber());
        }
        columnData.add(featureValues);
    }
}

