/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.model;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Value;

/**
 * Saliency explanation for a single {@link Output}: the importance score attributed to each input
 * {@link Feature} for that output.
 *
 * <p>The per-feature importance list is copied at construction time and exposed read-only, so
 * later changes to the list passed to the constructor are not visible through this object.
 */
public class Saliency {

    /** Orders importances by descending absolute score, i.e. largest magnitude first. */
    private static final Comparator<FeatureImportance> BY_ABSOLUTE_SCORE_DESC =
            (f0, f1) -> Double.compare(Math.abs(f1.getScore()), Math.abs(f0.getScore()));

    private final Output output;
    private final List<FeatureImportance> perFeatureImportance;

    /**
     * Creates a saliency for the given output.
     *
     * @param output the output this saliency explains
     * @param perFeatureImportance importance scores per feature; copied defensively, must not be
     *     {@code null}
     */
    public Saliency(Output output, List<FeatureImportance> perFeatureImportance) {
        this.output = output;
        // unmodifiableList alone is only a view over the caller's list; copy first so that later
        // mutations by the caller cannot change this object's state.
        this.perFeatureImportance =
                Collections.unmodifiableList(new ArrayList<>(perFeatureImportance));
    }

    /**
     * Returns the feature importances in the order supplied at construction.
     *
     * @return an unmodifiable list of feature importances
     */
    public List<FeatureImportance> getPerFeatureImportance() {
        return this.perFeatureImportance;
    }

    /**
     * Returns the output this saliency explains.
     *
     * @return the output
     */
    public Output getOutput() {
        return this.output;
    }

    /**
     * Returns up to {@code k} features with the largest absolute score, largest first.
     *
     * @param k maximum number of features to return
     * @return a new list of at most {@code k} feature importances
     * @throws IllegalArgumentException if {@code k} is negative (from {@code Stream.limit})
     */
    public List<FeatureImportance> getTopFeatures(int k) {
        return perFeatureImportance.stream()
                .sorted(BY_ABSOLUTE_SCORE_DESC)
                .limit(k)
                .collect(Collectors.toList());
    }

    /**
     * Returns up to {@code k} features with a non-negative score, ordered by descending absolute
     * score.
     *
     * @param k maximum number of features to return
     * @return a new list of at most {@code k} feature importances with {@code score >= 0}
     * @throws IllegalArgumentException if {@code k} is negative (from {@code Stream.limit})
     */
    public List<FeatureImportance> getPositiveFeatures(int k) {
        // Filtering before the (stable) sort yields the same order as sorting first, with less work.
        return perFeatureImportance.stream()
                .filter(f -> f.getScore() >= 0.0)
                .sorted(BY_ABSOLUTE_SCORE_DESC)
                .limit(k)
                .collect(Collectors.toList());
    }

    /**
     * Returns up to {@code k} features with a negative score, ordered by descending absolute score
     * (most negative first).
     *
     * @param k maximum number of features to return
     * @return a new list of at most {@code k} feature importances with {@code score < 0}
     * @throws IllegalArgumentException if {@code k} is negative (from {@code Stream.limit})
     */
    public List<FeatureImportance> getNegativeFeatures(int k) {
        return perFeatureImportance.stream()
                .filter(f -> f.getScore() < 0.0)
                .sorted(BY_ABSOLUTE_SCORE_DESC)
                .limit(k)
                .collect(Collectors.toList());
    }

    @Override
    public String toString() {
        return "Saliency{output=" + this.output + ", perFeatureImportance=" + this.perFeatureImportance + "}";
    }

    /**
     * Merges several saliency maps into one by averaging feature scores per output.
     *
     * <p>All saliencies from all maps are grouped by their {@link Output} (using
     * {@code Output.equals}). Within each group, the scores for each feature are averaged (see
     * {@link #averageImportances(List)}). The resulting map is keyed by {@code Output.getName()}.
     * If two distinct outputs share a name, only one of them survives, and which one depends on
     * {@link HashMap} iteration order.
     *
     * @param saliencies the maps to merge; the map keys are ignored
     * @return a new mutable map from output name to merged saliency
     */
    public static Map<String, Saliency> merge(Collection<Map<String, Saliency>> saliencies) {
        Map<String, Saliency> finalResult = new HashMap<>();
        Map<Output, List<Saliency>> byOutput = saliencies.stream()
                .map(Map::values)
                .flatMap(Collection::stream)
                .collect(Collectors.groupingBy(Saliency::getOutput));
        for (Map.Entry<Output, List<Saliency>> saliencyEntry : byOutput.entrySet()) {
            Output mergedOutput = saliencyEntry.getKey();
            List<FeatureImportance> averaged = averageImportances(saliencyEntry.getValue());
            finalResult.put(mergedOutput.getName(), new Saliency(mergedOutput, averaged));
        }
        return finalResult;
    }

    /**
     * Averages the feature importances of the given saliencies feature by feature, returning them
     * sorted by feature name.
     *
     * <p>Features are grouped by a copy made with a {@code null} {@link Value}. Presumably this is
     * so that the same feature carrying different values is treated as one key; that depends on
     * {@code Feature.equals}, which is not visible here (TODO confirm).
     */
    private static List<FeatureImportance> averageImportances(List<Saliency> saliencies) {
        Map<Feature, List<FeatureImportance>> byFeature = saliencies.stream()
                .map(s -> s.perFeatureImportance)
                .flatMap(Collection::stream)
                .collect(Collectors.groupingBy(
                        fi -> FeatureFactory.copyOf(fi.getFeature(), new Value<Object>(null))));
        List<FeatureImportance> result = new ArrayList<>(byFeature.size());
        for (Map.Entry<Feature, List<FeatureImportance>> entry : byFeature.entrySet()) {
            // groupingBy never produces empty groups; orElse only guards the OptionalDouble API.
            double meanScore = entry.getValue().stream()
                    .mapToDouble(FeatureImportance::getScore)
                    .average()
                    .orElse(0.0);
            result.add(new FeatureImportance(entry.getKey(), meanScore));
        }
        result.sort(Comparator.comparing((FeatureImportance fi) -> fi.getFeature().getName()));
        return result;
    }
}

