package org.deeplearning4j.eval;

import java.util.Iterator;
import java.util.SortedSet;
import org.deeplearning4j.berkeley.Counter;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/eval/Evaluation.class */
/**
 * Accumulates classification results and reports per-class and macro-averaged precision,
 * recall and F1, plus an overall (one-vs-rest) accuracy figure.
 *
 * <p>Each call to {@link #eval(INDArray, INDArray)} treats every row as one example and takes
 * the arg-max column of each row as that example's class index. Counts accumulate across calls.
 * This class performs no synchronization.
 */
public class Evaluation {
    private final Counter<Integer> truePositives = new Counter<>();
    private final Counter<Integer> falsePositives = new Counter<>();
    private final Counter<Integer> trueNegative = new Counter<>();
    private final Counter<Integer> falseNegatives = new Counter<>();
    /** Counts keyed by (actual class, predicted class), as read back by {@link #stats()}. */
    private final ConfusionMatrix<Integer> confusion = new ConfusionMatrix<>();

    /**
     * Evaluates a batch of predictions against the expected outcomes.
     *
     * <p>Row {@code r} of each matrix is one example; its class is the index of the row's largest
     * value (ties resolve to the lowest index). For the example's actual class {@code a} and
     * predicted class {@code p}:
     * <ul>
     *   <li>{@code a == p}: one true positive for {@code a};</li>
     *   <li>{@code a != p}: one false negative for {@code a} and one false positive for {@code p};</li>
     *   <li>every other class counts the example as a true negative.</li>
     * </ul>
     *
     * @param realOutcomes expected outcomes, one example per row (e.g. one-hot labels)
     * @param guesses      model outputs, one example per row
     * @throws IllegalArgumentException if the two matrices do not have the same length
     */
    public void eval(INDArray realOutcomes, INDArray guesses) {
        if (realOutcomes.length() != guesses.length()) {
            throw new IllegalArgumentException("Unable to evaluate. Outcome matrices not same length");
        }
        for (int r = 0; r < realOutcomes.rows(); r++) {
            INDArray actualRow = realOutcomes.getRow(r);
            INDArray guessRow = guesses.getRow(r);
            int actual = argMax(actualRow);
            int predicted = argMax(guessRow);

            addToConfusion(actual, predicted);
            if (actual == predicted) {
                incrementTruePositives(predicted);
            } else {
                incrementFalseNegatives(actual);
                incrementFalsePositives(predicted);
            }

            // One-vs-rest: every class that is neither the actual nor the predicted one correctly
            // rejected this example. Iterating over the full row width (instead of only classes
            // seen so far) keeps the count independent of the order examples arrive in, and also
            // covers misclassified examples.
            int numClasses = Math.max(actualRow.columns(), guessRow.columns());
            for (int c = 0; c < numClasses; c++) {
                if (c != actual && c != predicted) {
                    this.trueNegative.incrementCount(Integer.valueOf(c), 1.0d);
                }
            }
        }
    }

    /**
     * Returns the column index of the largest value in {@code row}; ties resolve to the lowest index.
     */
    private static int argMax(INDArray row) {
        double best = row.getDouble(0);
        int bestIndex = 0;
        for (int col = 1; col < row.columns(); col++) {
            double value = row.getDouble(col);
            if (value > best) {
                best = value;
                bestIndex = col;
            }
        }
        return bestIndex;
    }

    /**
     * Builds a human-readable report: every non-zero confusion-matrix cell followed by the
     * macro-averaged F1 score.
     *
     * @return the report text
     */
    public String stats() {
        StringBuilder report = new StringBuilder("\n");
        SortedSet<Integer> classes = this.confusion.getClasses();
        for (Integer actual : classes) {
            for (Integer predicted : classes) {
                int count = this.confusion.getCount(actual, predicted);
                if (count != 0) {
                    report.append("\nActual Class ").append(actual)
                            .append(" was predicted with Predicted ").append(predicted)
                            .append(" with count ").append(count).append(" times\n");
                }
            }
        }
        report.append("\n==========================F1 Scores========================================");
        report.append("\n ").append(f1());
        report.append("\n===========================================================================");
        return report.toString();
    }

    /**
     * Records one (actual, predicted) pair in the confusion matrix.
     *
     * @param i  the actual class index
     * @param i2 the predicted class index
     */
    public void addToConfusion(int i, int i2) {
        this.confusion.add(Integer.valueOf(i), Integer.valueOf(i2));
    }

    /**
     * @param i a class index
     * @return the confusion matrix's actual-class total for class {@code i}
     */
    public int classCount(int i) {
        return this.confusion.getActualTotal(Integer.valueOf(i));
    }

    /**
     * @param i a class index
     * @return the confusion matrix's predicted-class total for class {@code i}
     */
    public int numtimesPredicted(int i) {
        return this.confusion.getPredictedTotal(Integer.valueOf(i));
    }

    /**
     * @param i  the actual class index
     * @param i2 the predicted class index
     * @return how many examples of actual class {@code i} were predicted as class {@code i2}
     */
    public int numTimesPredicted(int i, int i2) {
        return this.confusion.getCount(Integer.valueOf(i), Integer.valueOf(i2));
    }

    /**
     * Macro-averaged precision over every class in the confusion matrix.
     *
     * @return the mean of {@link #precision(int)}; NaN if no class has been recorded yet
     */
    public double precision() {
        SortedSet<Integer> classes = this.confusion.getClasses();
        double sum = 0.0d;
        for (Integer clazz : classes) {
            sum += precision(clazz.intValue());
        }
        return sum / classes.size();
    }

    /** @return the total number of true negatives across all classes */
    public double trueNegatives() {
        return this.trueNegative.totalCount();
    }

    /** @return the total number of false positives across all classes */
    public double falsePositive() {
        return this.falsePositives.totalCount();
    }

    /**
     * Total actual negatives across all classes: true negatives plus false positives.
     *
     * @return TN + FP
     */
    public double negative() {
        return trueNegatives() + this.falsePositives.totalCount();
    }

    /**
     * Total actual positives across all classes: true positives plus false negatives.
     *
     * @return TP + FN
     */
    public double positive() {
        return this.truePositives.totalCount() + this.falseNegatives.totalCount();
    }

    /**
     * One-vs-rest accuracy pooled over all classes: (TP + TN) / (TP + FN + TN + FP).
     *
     * @return the pooled accuracy; NaN if nothing has been evaluated yet
     */
    public double accuracy() {
        return (this.truePositives.totalCount() + trueNegatives()) / (positive() + negative());
    }

    /**
     * Macro-averaged F1: the harmonic mean of {@link #precision()} and {@link #recall()}.
     *
     * @return the F1 score, or 0.0 if either precision or recall is 0.0
     */
    public double f1() {
        return harmonicMean(precision(), recall());
    }

    /**
     * F1 score for a single class: the harmonic mean of {@link #precision(int)} and
     * {@link #recall(int)} for that class.
     *
     * @param i the class index
     * @return the class's F1 score, or 0.0 if its precision or recall is 0.0
     */
    public double f1(int i) {
        // Uses the per-class recall; mixing in the macro-averaged recall would not be a per-class F1.
        return harmonicMean(precision(i), recall(i));
    }

    /** Returns 2pr / (p + r), or 0.0 when either input is 0.0 (avoids 0/0). */
    private static double harmonicMean(double precision, double recall) {
        if (precision == 0.0d || recall == 0.0d) {
            return 0.0d;
        }
        return 2.0d * ((precision * recall) / (precision + recall));
    }

    /**
     * Macro-averaged recall over every class in the confusion matrix.
     *
     * @return the mean of {@link #recall(int)}; NaN if no class has been recorded yet
     */
    public double recall() {
        SortedSet<Integer> classes = this.confusion.getClasses();
        double sum = 0.0d;
        for (Integer clazz : classes) {
            sum += recall(clazz.intValue());
        }
        return sum / classes.size();
    }

    /**
     * Recall for one class: TP / (TP + FN).
     *
     * @param i the class index
     * @return the class's recall, or 0.0 if it has no true positives
     */
    public double recall(int i) {
        Integer key = Integer.valueOf(i);
        double tp = this.truePositives.getCount(key);
        if (tp == 0.0d) {
            return 0.0d;
        }
        return tp / (tp + this.falseNegatives.getCount(key));
    }

    /**
     * Precision for one class: TP / (TP + FP).
     *
     * @param i the class index
     * @return the class's precision, or 0.0 if it has no true positives
     */
    public double precision(int i) {
        Integer key = Integer.valueOf(i);
        double tp = this.truePositives.getCount(key);
        if (tp == 0.0d) {
            return 0.0d;
        }
        return tp / (tp + this.falsePositives.getCount(key));
    }

    /** Adds one true positive for class {@code i}. */
    public void incrementTruePositives(int i) {
        this.truePositives.incrementCount(Integer.valueOf(i), 1.0d);
    }

    /** Adds one false negative for class {@code i}. */
    public void incrementFalseNegatives(int i) {
        this.falseNegatives.incrementCount(Integer.valueOf(i), 1.0d);
    }

    /** Adds one false positive for class {@code i}. */
    public void incrementFalsePositives(int i) {
        this.falsePositives.incrementCount(Integer.valueOf(i), 1.0d);
    }
}
