/*
 * Copyright 2020 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.kie.kogito.explainability.model;

import java.util.List;
import java.util.stream.Collectors;

/**
 * The saliency generated by an explanation algorithm.
 * A saliency contains a feature importance for each explained feature, together with the
 * {@link Output} it was constructed with.
 */
public class Saliency {

    private final Output output;
    private final List<FeatureImportance> perFeatureImportance;

    /**
     * Creates a saliency from an output and its feature importances.
     * The list is stored by reference; no copy is made.
     *
     * @param output the output stored in this saliency
     * @param perFeatureImportance the feature importances stored in this saliency
     */
    public Saliency(Output output, List<FeatureImportance> perFeatureImportance) {
        this.output = output;
        this.perFeatureImportance = perFeatureImportance;
    }

    /**
     * Returns the feature importances held by this saliency.
     *
     * @return the same list instance that was passed to the constructor (not a copy)
     */
    public List<FeatureImportance> getPerFeatureImportance() {
        return perFeatureImportance;
    }

    /**
     * Returns the output held by this saliency.
     *
     * @return the output passed to the constructor
     */
    public Output getOutput() {
        return output;
    }

    /**
     * Returns up to {@code k} feature importances, ordered by decreasing absolute score.
     *
     * @param k the maximum number of feature importances to return
     * @return a newly collected list holding at most {@code k} elements
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public List<FeatureImportance> getTopFeatures(int k) {
        return perFeatureImportance.stream()
                .sorted(Saliency::byDescendingMagnitude)
                .limit(k)
                .collect(Collectors.toList());
    }

    /**
     * Returns up to {@code k} feature importances whose score is greater than or equal to zero,
     * ordered by decreasing absolute score.
     *
     * @param k the maximum number of feature importances to return
     * @return a newly collected list holding at most {@code k} elements
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public List<FeatureImportance> getPositiveFeatures(int k) {
        // Filtering before sorting yields the same result (the sort is stable and the filter
        // preserves encounter order) while sorting only the elements that can be returned.
        return perFeatureImportance.stream()
                .filter(f -> f.getScore() >= 0)
                .sorted(Saliency::byDescendingMagnitude)
                .limit(k)
                .collect(Collectors.toList());
    }

    /**
     * Returns up to {@code k} feature importances whose score is strictly less than zero,
     * ordered by decreasing absolute score.
     *
     * @param k the maximum number of feature importances to return
     * @return a newly collected list holding at most {@code k} elements
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public List<FeatureImportance> getNegativeFeatures(int k) {
        // Same reasoning as in getPositiveFeatures: filter first, then sort the remainder.
        return perFeatureImportance.stream()
                .filter(f -> f.getScore() < 0)
                .sorted(Saliency::byDescendingMagnitude)
                .limit(k)
                .collect(Collectors.toList());
    }

    @Override
    public String toString() {
        return "Saliency{" +
                "output=" + output +
                ", perFeatureImportance=" + perFeatureImportance +
                '}';
    }

    /**
     * Comparator ordering feature importances by decreasing absolute score.
     * The arguments to {@link Double#compare(double, double)} are swapped on purpose so that
     * larger magnitudes come first.
     */
    private static int byDescendingMagnitude(FeatureImportance first, FeatureImportance second) {
        return Double.compare(Math.abs(second.getScore()), Math.abs(first.getScore()));
    }
}

