/*
 * Decompiled with CFR 0.152.
 */
package de.unidue.ltl.evaluation.visualization;

import de.tudarmstadt.ukp.dkpro.core.api.frequency.util.ConditionalFrequencyDistribution;
import de.tudarmstadt.ukp.dkpro.core.api.frequency.util.FrequencyDistribution;
import de.unidue.ltl.evaluation.core.AbstractConfusionMatrix;
import de.unidue.ltl.evaluation.core.EvaluationData;
import de.unidue.ltl.evaluation.core.EvaluationEntry;
import de.vandermeer.asciitable.v2.V2_AsciiTable;
import de.vandermeer.asciitable.v2.render.V2_AsciiTableRenderer;
import de.vandermeer.asciitable.v2.render.V2_Width;
import de.vandermeer.asciitable.v2.render.WidthLongestWord;
import de.vandermeer.asciitable.v2.themes.V2_E_TableThemes;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class ConfusionMatrix<T>
extends AbstractConfusionMatrix<T> {
    /** Every label observed as either a gold or a predicted value. */
    Set<T> allLabels = new HashSet<T>();
    /** Counts of predicted labels (samples) conditioned on the gold label (condition). */
    ConditionalFrequencyDistribution<T, T> cfd = new ConditionalFrequencyDistribution<T, T>();

    /**
     * Builds the confusion matrix by counting every (gold, predicted) pair in {@code data}.
     *
     * @param data the evaluation entries to count; each entry contributes exactly one sample
     */
    public ConfusionMatrix(EvaluationData<T> data) {
        for (EvaluationEntry<T> e : data) {
            T gold = e.getGold();
            T predicted = e.getPredicted();
            this.cfd.addSample(gold, predicted, 1L);
            this.allLabels.add(gold);
            this.allLabels.add(predicted);
        }
    }

    /**
     * Renders the matrix as a plain-text table: rows are gold labels, columns are predicted
     * labels, both in the order returned by {@link #getLabels()}.
     *
     * @return the rendered table
     */
    public String toText() {
        List<T> labels = this.getLabels();
        // Build the cell values from the same label list used for the headers so that row and
        // column order are guaranteed to match what is printed.
        int[][] array = this.toArray(labels);
        int columns = labels.size() + 1;
        V2_AsciiTable at = new V2_AsciiTable();
        at.addStrongRule();
        at.addRow(this.getTableHeader("Predicted", columns));
        at.addRow(this.getLabelHeader(labels));
        at.addStrongRule();
        for (int i = 0; i < labels.size(); ++i) {
            Object[] values = new Object[columns];
            values[0] = labels.get(i);
            for (int j = 0; j < labels.size(); ++j) {
                values[j + 1] = array[i][j];
            }
            at.addRow(values);
            at.addRule();
        }
        at.addStrongRule();
        V2_AsciiTableRenderer rend = new V2_AsciiTableRenderer();
        rend.setTheme(V2_E_TableThemes.NO_BORDERS.get());
        rend.setWidth(new WidthLongestWord());
        return rend.render(at).toString();
    }

    /**
     * Builds the title row: an empty first cell, then {@code null} cells, then {@code title} in
     * the last cell.
     *
     * @param title   text for the last cell
     * @param columns total number of cells in the row (label column + one per label)
     */
    private Object[] getTableHeader(String title, int columns) {
        // Array elements default to null; presumably the table library merges null cells into
        // the following cell so the title spans the label columns — confirm against asciitable v2.
        Object[] header = new Object[columns];
        header[0] = "";
        header[header.length - 1] = title;
        return header;
    }

    /**
     * Builds the column-header row: an empty corner cell followed by each label.
     *
     * @param labels the labels in display order
     */
    private Object[] getLabelHeader(List<T> labels) {
        List<Object> values = new ArrayList<Object>(labels.size() + 1);
        values.add("");
        values.addAll(labels);
        return values.toArray();
    }

    /** @return the total number of counted (gold, predicted) samples */
    public long getNumberOfEntries() {
        return this.cfd.getN();
    }

    /**
     * @param goldLabel     the gold (reference) label
     * @param confusedLabel the predicted label
     * @return how often {@code goldLabel} was predicted as {@code confusedLabel}
     */
    public long getNumberOfConfusions(T goldLabel, T confusedLabel) {
        return this.cfd.getCount(goldLabel, confusedLabel);
    }

    /** @return how often {@code label} was both the gold and the predicted label */
    public long getTruePositives(T label) {
        return this.cfd.getCount(label, label);
    }

    /** @return how often gold {@code label} was predicted as some other label; 0 if unseen as gold */
    public long getFalseNegatives(T label) {
        FrequencyDistribution<T> fd = this.cfd.getFrequencyDistribution(label);
        long total = 0L;
        if (fd == null) {
            return total;
        }
        for (T key : fd.getKeys()) {
            if (key.equals(label)) continue;
            total += fd.getCount(key);
        }
        return total;
    }

    /** @return how often {@code label} was predicted for an entry whose gold label differs */
    public long getFalsePositives(T label) {
        long total = 0L;
        for (T c : this.cfd.getConditions()) {
            if (c.equals(label)) continue;
            total += this.cfd.getFrequencyDistribution(c).getCount(label);
        }
        return total;
    }

    /** @return how often neither the gold nor the predicted label was {@code label} */
    public long getTrueNegatives(T label) {
        long total = 0L;
        for (T c : this.cfd.getConditions()) {
            if (c.equals(label)) continue;
            FrequencyDistribution<T> fd = this.cfd.getFrequencyDistribution(c);
            for (T key : fd.getKeys()) {
                if (key.equals(label)) continue;
                total += fd.getCount(key);
            }
        }
        return total;
    }

    /**
     * @return a new list of all observed labels, sorted by their {@code toString()} value
     */
    public List<T> getLabels() {
        ArrayList<T> labels = new ArrayList<T>(this.allLabels);
        Collections.sort(labels, new Comparator<T>(){

            @Override
            public int compare(T o1, T o2) {
                if (o1.equals(o2)) {
                    return 0;
                }
                return o1.toString().compareTo(o2.toString());
            }
        });
        return labels;
    }

    /**
     * Returns the counts as a square matrix. {@code [i][j]} is how often gold label {@code i} was
     * predicted as label {@code j}, with indices following {@link #getLabels()}.
     *
     * @throws ArithmeticException if a single cell count does not fit into an {@code int}
     */
    public int[][] getTwoDimensionalArray() {
        return this.toArray(this.getLabels());
    }

    /** Fills an n-by-n count matrix using {@code labels} as the row and column order. */
    private int[][] toArray(List<T> labels) {
        int n = labels.size();
        int[][] array = new int[n][n];
        for (int i = 0; i < n; ++i) {
            FrequencyDistribution<T> fd = this.cfd.getFrequencyDistribution(labels.get(i));
            if (fd == null) {
                // Label only ever appeared as a prediction: its row stays all zeros.
                continue;
            }
            for (int j = 0; j < n; ++j) {
                // Fail loudly instead of silently wrapping counts larger than Integer.MAX_VALUE.
                array[i][j] = Math.toIntExact(fd.getCount(labels.get(j)));
            }
        }
        return array;
    }
}

