/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.Assertions;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.ValidationUtils;
import org.mockito.Mockito;

public class TestUtils {
    /**
     * Builds a model that echoes a single input feature back as its only output.
     *
     * <p>For every {@link PredictionInput} the returned provider emits one {@link Output} named
     * {@code "feature-" + featureIndex} that carries the type and value of the feature found at
     * {@code featureIndex}, with a score of {@code 1.0}. The computation is submitted through
     * {@code CompletableFuture.supplyAsync}.
     *
     * @param featureIndex position (within each input's feature list) of the feature to pass through
     * @return a provider producing one single-output prediction per input, in input order
     */
    public static PredictionProvider getFeaturePassModel(int featureIndex) {
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> results = new LinkedList<>();
            for (PredictionInput input : inputs) {
                Feature selected = input.getFeatures().get(featureIndex);
                Output echoed = new Output("feature-" + featureIndex, selected.getType(), selected.getValue(), 1.0);
                results.add(new PredictionOutput(List.of(echoed)));
            }
            return results;
        });
    }

    /**
     * Builds a model that sums the numeric value of every feature except one.
     *
     * <p>Each prediction holds a single {@link Type#NUMBER} output named
     * {@code "sum-but" + skipFeatureIndex} with a score of {@code 1.0}.
     *
     * @param skipFeatureIndex position of the feature left out of the sum; an index that matches no
     *        feature simply means nothing is skipped
     * @return a provider producing one single-output prediction per input, in input order
     */
    public static PredictionProvider getSumSkipModel(int skipFeatureIndex) {
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> results = new LinkedList<>();
            for (PredictionInput input : inputs) {
                double sum = 0.0;
                int position = 0;
                for (Feature feature : input.getFeatures()) {
                    if (position != skipFeatureIndex) {
                        sum += feature.getValue().asNumber();
                    }
                    position++;
                }
                // The (Object) cast is retained from the original so constructor resolution is unchanged.
                Output total = new Output("sum-but" + skipFeatureIndex, Type.NUMBER, new Value((Object) sum), 1.0);
                results.add(new PredictionOutput(List.of(total)));
            }
            return results;
        });
    }

    /**
     * Builds a model that sums all feature values, perturbing each term with random noise.
     *
     * <p>For every feature one call to {@code rn.nextDouble()} is made; the drawn number is centred
     * around zero (by subtracting {@code 0.5}) and scaled by {@code noiseMagnitude} before being
     * added together with the feature value. The single {@link Type#NUMBER} output is named
     * {@code "noisy-sum"} and has a score of {@code 1.0}.
     *
     * @param rn source of randomness; it is consumed inside the asynchronous computation
     * @param noiseMagnitude scale factor applied to each noise term
     * @return a provider producing one single-output prediction per input, in input order
     */
    public static PredictionProvider getNoisySumModel(Random rn, double noiseMagnitude) {
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> results = new LinkedList<>();
            for (PredictionInput input : inputs) {
                double noisySum = 0.0;
                for (Feature feature : input.getFeatures()) {
                    // Read the feature first, then draw the noise, matching the original evaluation order.
                    double value = feature.getValue().asNumber();
                    double noise = (rn.nextDouble() - 0.5) * noiseMagnitude;
                    noisySum += value + noise;
                }
                // The (Object) cast is retained from the original so constructor resolution is unchanged.
                Output total = new Output("noisy-sum", Type.NUMBER, new Value((Object) noisySum), 1.0);
                results.add(new PredictionOutput(List.of(total)));
            }
            return results;
        });
    }

    /**
     * Builds a linear model: the output is the weighted sum of the numeric feature values.
     *
     * <p>Feature {@code i} is multiplied by {@code weights[i]}. The single {@link Type#NUMBER} output
     * is named {@code "linear-sum"} and has a score of {@code 1.0}.
     *
     * @param weights one weight per feature position; it is indexed by feature position, so an input
     *        with more features than weights fails with an {@link ArrayIndexOutOfBoundsException}
     *        inside the asynchronous computation
     * @return a provider producing one single-output prediction per input, in input order
     */
    public static PredictionProvider getLinearModel(double[] weights) {
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> results = new LinkedList<>();
            for (PredictionInput input : inputs) {
                List<Feature> features = input.getFeatures();
                double weightedSum = 0.0;
                for (int i = 0; i < features.size(); i++) {
                    double term = features.get(i).getValue().asNumber() * weights[i];
                    weightedSum += term;
                }
                // The (Object) cast is retained from the original so constructor resolution is unchanged.
                Output total = new Output("linear-sum", Type.NUMBER, new Value((Object) weightedSum), 1.0);
                results.add(new PredictionOutput(List.of(total)));
            }
            return results;
        });
    }

    /**
     * Builds a two-output variant of the "sum all features but one" model.
     *
     * <p>The first output, named {@code "sum-but" + skipFeatureIndex}, is the sum of every feature
     * value except the one at {@code skipFeatureIndex}. The second output, named
     * {@code "sum-but" + skipFeatureIndex + "*2"}, is that sum multiplied by two. Both are
     * {@link Type#NUMBER} outputs with a score of {@code 1.0}.
     *
     * @param skipFeatureIndex position of the feature left out of the sum
     * @return a provider producing one two-output prediction per input, in input order
     */
    public static PredictionProvider getSumSkipTwoOutputModel(int skipFeatureIndex) {
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> results = new LinkedList<>();
            String baseName = "sum-but" + skipFeatureIndex;
            for (PredictionInput input : inputs) {
                List<Feature> features = input.getFeatures();
                double sum = 0.0;
                for (int i = 0; i < features.size(); i++) {
                    if (i != skipFeatureIndex) {
                        sum += features.get(i).getValue().asNumber();
                    }
                }
                // The (Object) casts are retained from the original so constructor resolution is unchanged.
                Output plain = new Output(baseName, Type.NUMBER, new Value((Object) sum), 1.0);
                Output doubled = new Output(baseName + "*2", Type.NUMBER, new Value((Object) (sum * 2.0)), 1.0);
                results.add(new PredictionOutput(List.of(plain, doubled)));
            }
            return results;
        });
    }

    /**
     * Builds a model that mirrors every input feature as an output, except for one.
     *
     * <p>Each mirrored {@link Output} reuses the feature's name, type and value and gets a score of
     * {@code 1.0}. The feature at {@code featureIndex} produces no output.
     *
     * @param featureIndex position of the feature that is dropped
     * @return a provider producing one prediction per input, in input order
     */
    public static PredictionProvider getFeatureSkipModel(int featureIndex) {
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> results = new LinkedList<>();
            for (PredictionInput input : inputs) {
                List<Output> mirrored = new ArrayList<>();
                int position = 0;
                for (Feature feature : input.getFeatures()) {
                    if (position != featureIndex) {
                        mirrored.add(new Output(feature.getName(), feature.getType(), feature.getValue(), 1.0));
                    }
                    position++;
                }
                results.add(new PredictionOutput(mirrored));
            }
            return results;
        });
    }

    /**
     * Builds a model that reports whether one selected feature holds an even number.
     *
     * <p>The single output is named {@code "feature-" + featureIndex}, is declared as
     * {@link Type#BOOLEAN} and has a score of {@code 1.0}. Its value wraps the integer {@code 1}
     * when the feature value modulo two equals zero, and the integer {@code 0} otherwise (an
     * integer flag, not a {@code Boolean}).
     *
     * @param featureIndex position of the feature to inspect
     * @return a provider producing one single-output prediction per input, in input order
     */
    public static PredictionProvider getEvenFeatureModel(int featureIndex) {
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> results = new LinkedList<>();
            for (PredictionInput input : inputs) {
                double number = input.getFeatures().get(featureIndex).getValue().asNumber();
                int evenFlag = (number % 2.0 == 0.0) ? 1 : 0;
                // The (Object) cast boxes the flag to an Integer exactly as the original did.
                Output even = new Output("feature-" + featureIndex, Type.BOOLEAN, new Value((Object) evenFlag), 1.0);
                results.add(new PredictionOutput(List.of(even)));
            }
            return results;
        });
    }

    /**
     * Builds a model that reports whether the sum of all features but one is even.
     *
     * <p>The sum is truncated to an {@code int} before the parity check. The single output is named
     * {@code "sum-even-but" + skipFeatureIndex}, is declared as {@link Type#BOOLEAN} and has a score
     * of {@code 1.0}. Its value wraps the integer {@code 1} for an even sum and {@code 0} otherwise
     * (an integer flag, not a {@code Boolean}).
     *
     * @param skipFeatureIndex position of the feature left out of the sum
     * @return a provider producing one single-output prediction per input, in input order
     */
    public static PredictionProvider getEvenSumModel(int skipFeatureIndex) {
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> results = new LinkedList<>();
            for (PredictionInput input : inputs) {
                List<Feature> features = input.getFeatures();
                double sum = 0.0;
                for (int i = 0; i < features.size(); i++) {
                    if (i != skipFeatureIndex) {
                        sum += features.get(i).getValue().asNumber();
                    }
                }
                // Parity is evaluated on the truncated integer part of the sum.
                int truncated = (int) sum;
                int evenFlag = (truncated % 2 == 0) ? 1 : 0;
                // The (Object) cast boxes the flag to an Integer exactly as the original did.
                Output even = new Output("sum-even-but" + skipFeatureIndex, Type.BOOLEAN, new Value((Object) evenFlag), 1.0);
                results.add(new PredictionOutput(List.of(even)));
            }
            return results;
        });
    }

    /**
     * Builds a model that checks whether the sum of all features lies within a band around a centre.
     *
     * <p>The single {@link Type#BOOLEAN} output is named {@code "inside"} and is {@code true} when
     * the sum falls in the closed interval {@code [center - epsilon, center + epsilon]}. Its score is
     * {@code 1.0 - |sum - center|}, which is not clamped and therefore drops below zero once the sum
     * is more than one unit away from {@code center}.
     *
     * @param center middle of the accepted interval
     * @param epsilon half-width of the accepted interval
     * @return a provider producing one single-output prediction per input, in input order
     */
    public static PredictionProvider getSumThresholdModel(double center, double epsilon) {
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> results = new LinkedList<>();
            for (PredictionInput input : inputs) {
                double sum = 0.0;
                for (Feature feature : input.getFeatures()) {
                    sum += feature.getValue().asNumber();
                }
                double lowerBound = center - epsilon;
                double upperBound = center + epsilon;
                boolean inside = lowerBound <= sum && sum <= upperBound;
                double score = 1.0 - Math.abs(sum - center);
                // The (Object) cast is retained from the original so constructor resolution is unchanged.
                Output verdict = new Output("inside", Type.BOOLEAN, new Value((Object) inside), score);
                results.add(new PredictionOutput(List.of(verdict)));
            }
            return results;
        });
    }

    /**
     * Builds a toy spam classifier driven by a fixed list of black-listed words.
     *
     * <p>The string value of each feature is split on single spaces; an input is flagged as spam as
     * soon as one of the resulting tokens equals a black-listed word. Once an input is flagged, its
     * remaining features are not inspected. The single {@link Type#BOOLEAN} output is named
     * {@code "spam"} and has a score of {@code 1.0}.
     *
     * @return a provider producing one single-output prediction per input, in input order
     */
    public static PredictionProvider getDummyTextClassifier() {
        List<String> blackList = Arrays.asList("money", "$", "\u00a3", "bitcoin");
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> results = new LinkedList<>();
            for (PredictionInput input : inputs) {
                boolean spam = false;
                for (Feature feature : input.getFeatures()) {
                    if (!spam) {
                        String[] tokens = feature.getValue().asString().split(" ");
                        // anyMatch stops at the first black-listed token, like the original early exit.
                        spam = Arrays.stream(tokens).anyMatch(blackList::contains);
                    }
                }
                // The (Object) cast is retained from the original so constructor resolution is unchanged.
                Output verdict = new Output("spam", Type.BOOLEAN, new Value((Object) spam), 1.0);
                results.add(new PredictionOutput(List.of(verdict)));
            }
            return results;
        });
    }

    /**
     * Builds a model that applies an arithmetic operator, supplied as a feature, to the other features.
     *
     * <p>The operator is the string value of the first feature named {@code "operand"}; recognised
     * symbols are {@code "+"}, {@code "-"}, {@code "*"} and {@code "/"}. Starting from an accumulator
     * of {@code 0.0}, the operator is applied in turn to the numeric value of every feature whose
     * name is not {@code "operand"}. Any other symbol leaves the accumulator at {@code 0.0}. The
     * single {@link Type#NUMBER} output is named {@code "result"} and has a score of {@code 1.0}.
     *
     * <p>NOTE(review): because the accumulator starts at {@code 0.0}, {@code "*"} and {@code "/"}
     * operate on zero rather than on the first feature value. This is kept as-is since callers may
     * rely on it; confirm whether it is intended.
     *
     * @return a provider producing one single-output prediction per input, in input order; the
     *         asynchronous computation throws {@link IllegalArgumentException} when an input has no
     *         feature named {@code "operand"}
     */
    public static PredictionProvider getSymbolicArithmeticModel() {
        return inputs -> CompletableFuture.supplyAsync(() -> {
            final String operandFeatureName = "operand";
            List<PredictionOutput> predictionOutputs = new LinkedList<>();
            for (PredictionInput predictionInput : inputs) {
                // A typed list is required here: with a raw List the lambdas below see Object elements.
                List<Feature> features = predictionInput.getFeatures();
                String operandValue = features.stream()
                        .filter(f -> operandFeatureName.equals(f.getName()))
                        .map(f -> f.getValue().asString())
                        .findFirst()
                        .orElseThrow(() -> new IllegalArgumentException("No valid operand found in features"));
                double result = 0.0;
                for (Feature feature : features) {
                    if (operandFeatureName.equals(feature.getName())) {
                        continue;
                    }
                    switch (operandValue) {
                        case "+":
                            result += feature.getValue().asNumber();
                            break;
                        case "-":
                            result -= feature.getValue().asNumber();
                            break;
                        case "*":
                            result *= feature.getValue().asNumber();
                            break;
                        case "/":
                            result /= feature.getValue().asNumber();
                            break;
                        default:
                            // Unrecognised operators are ignored, leaving the accumulator untouched.
                            break;
                    }
                }
                // The (Object) cast is retained from the original so constructor resolution is unchanged.
                PredictionOutput predictionOutput = new PredictionOutput(
                        List.of(new Output("result", Type.NUMBER, new Value((Object) result), 1.0)));
                predictionOutputs.add(predictionOutput);
            }
            return predictionOutputs;
        });
    }

    /**
     * Builds a classifier that ignores its input and always answers {@code false}.
     *
     * <p>For every input a single {@link Type#BOOLEAN} output named {@code "class"} is produced,
     * wrapping {@code false} with a score of {@code 1.0}.
     *
     * @return a provider producing one single-output prediction per input
     */
    public static PredictionProvider getFixedOutputClassifier() {
        return inputs -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> results = new LinkedList<>();
            // The input content is irrelevant; only the number of inputs drives the number of outputs.
            inputs.forEach(ignored -> {
                // The (Object) cast is retained from the original so constructor resolution is unchanged.
                Output constant = new Output("class", Type.BOOLEAN, new Value((Object) false), 1.0);
                results.add(new PredictionOutput(List.of(constant)));
            });
            return results;
        });
    }

    /**
     * Creates a mocked numeric feature whose value is {@code 1.0}.
     *
     * @return the mock produced by {@link #getMockedNumericFeature(double)} for {@code 1.0}
     */
    public static Feature getMockedNumericFeature() {
        final double defaultValue = 1.0;
        return getMockedNumericFeature(defaultValue);
    }

    /**
     * Creates a Mockito mock of {@link Feature} stubbed with the given type and value.
     *
     * <p>The mock's name is {@code "f-"} followed by {@code type.name()}.
     *
     * @param type type returned by the mock's {@code getType()}
     * @param v value returned by the mock's {@code getValue()}
     * @return the stubbed feature mock
     */
    public static Feature getMockedFeature(Type type, Value v) {
        Feature mockedFeature = Mockito.mock(Feature.class);
        String mockedName = "f-" + type.name();
        Mockito.when(mockedFeature.getValue()).thenReturn(v);
        Mockito.when(mockedFeature.getName()).thenReturn(mockedName);
        Mockito.when(mockedFeature.getType()).thenReturn(type);
        return mockedFeature;
    }

    /**
     * Creates a Mockito mock of a text {@link Feature} named {@code "f-text"}.
     *
     * <p>The feature's value is itself a mock: its underlying object and string form are both
     * {@code s}, while its numeric form is {@link Double#NaN}.
     *
     * @param s text carried by the mocked feature
     * @return the stubbed feature mock of type {@link Type#TEXT}
     */
    public static Feature getMockedTextFeature(String s) {
        // Stub the value first, then attach it to the feature; the stubbings are independent.
        Value mockedValue = Mockito.mock(Value.class);
        Mockito.when(mockedValue.asString()).thenReturn(s);
        Mockito.when(mockedValue.asNumber()).thenReturn(Double.NaN);
        Mockito.when(mockedValue.getUnderlyingObject()).thenReturn(s);

        Feature mockedFeature = Mockito.mock(Feature.class);
        Mockito.when(mockedFeature.getName()).thenReturn("f-text");
        Mockito.when(mockedFeature.getType()).thenReturn(Type.TEXT);
        Mockito.when(mockedFeature.getValue()).thenReturn(mockedValue);
        return mockedFeature;
    }

    /**
     * Creates a Mockito mock of a numeric {@link Feature} named {@code "f-num"}.
     *
     * <p>The feature's value is itself a mock: its underlying object and numeric form are both
     * {@code d}, and its string form is {@code String.valueOf(d)}.
     *
     * @param d number carried by the mocked feature
     * @return the stubbed feature mock of type {@link Type#NUMBER}
     */
    public static Feature getMockedNumericFeature(double d) {
        // Stub the value first, then attach it to the feature; the stubbings are independent.
        Value mockedValue = Mockito.mock(Value.class);
        Mockito.when(mockedValue.asString()).thenReturn(String.valueOf(d));
        Mockito.when(mockedValue.asNumber()).thenReturn(d);
        Mockito.when(mockedValue.getUnderlyingObject()).thenReturn(d);

        Feature mockedFeature = Mockito.mock(Feature.class);
        Mockito.when(mockedFeature.getName()).thenReturn("f-num");
        Mockito.when(mockedFeature.getType()).thenReturn(Type.NUMBER);
        Mockito.when(mockedFeature.getValue()).thenReturn(mockedValue);
        return mockedFeature;
    }

    /**
     * Asserts that the local saliency stability validation completes without throwing.
     *
     * <p>Delegates to {@code ValidationUtils.validateLocalSaliencyStability} with the arguments
     * unchanged and wraps the call in {@code Assertions.assertDoesNotThrow}.
     *
     * @param model model whose explanations are validated
     * @param prediction prediction to explain
     * @param limeExplainer explainer under test
     * @param topK forwarded as the {@code topK} argument of the validation
     * @param minimumPositiveStabilityRate forwarded as the minimum positive stability rate
     * @param minimumNegativeStabilityRate forwarded as the minimum negative stability rate
     */
    public static void assertLimeStability(PredictionProvider model, Prediction prediction, LimeExplainer limeExplainer, int topK, double minimumPositiveStabilityRate, double minimumNegativeStabilityRate) {
        Assertions.assertDoesNotThrow(() -> {
            ValidationUtils.validateLocalSaliencyStability(
                    model,
                    prediction,
                    limeExplainer,
                    topK,
                    minimumPositiveStabilityRate,
                    minimumNegativeStabilityRate);
        });
    }

    /**
     * Fills a training set and a weight array with a deterministic, alternating pattern.
     *
     * <p>For each {@code i} in {@code [0, size)}:
     * <ul>
     * <li>the sample is a two-element vector whose entry {@code j} is {@code 0.0} when
     * {@code i + j} is even and {@code 1.0} otherwise;</li>
     * <li>the label is {@code 0.0} when {@code i} is a multiple of three and {@code 1.0} otherwise;</li>
     * <li>{@code weights[i]} is set to {@code 0.2} for even {@code i} and {@code 0.8} for odd {@code i}.</li>
     * </ul>
     *
     * @param size number of samples to generate; nothing happens when it is not positive
     * @param trainingSet list the generated (sample, label) pairs are appended to
     * @param weights array receiving one weight per generated sample
     * @throws ArrayIndexOutOfBoundsException if {@code weights} holds fewer than {@code size} elements
     */
    public static void fillBalancedDataForFiltering(int size, List<Pair<double[], Double>> trainingSet, double[] weights) {
        for (int i = 0; i < size; i++) {
            double[] x = new double[2];
            for (int j = 0; j < x.length; j++) {
                x[j] = (i + j) % 2 == 0 ? 0.0 : 1.0;
            }
            Double y = i % 3 == 0 ? 0.0 : 1.0;
            // Let Pair.of infer <double[], Double>; casting a Pair<Object, Object> to that type does not compile.
            trainingSet.add(Pair.of(x, y));
            weights[i] = i % 2 == 0 ? 0.2 : 0.8;
        }
    }
}

