/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelVector;

/**
 * Wraps an underlying {@link Classifier} together with a second classifier that predicts whether
 * the underlying classifier's best label is correct.
 *
 * <p>{@link #classify(Instance)} returns a {@link Classification} whose label vector holds the
 * confidence classifier's {@code "correct"} score at the underlying best-label position and
 * {@code 0.0} everywhere else. Each call also updates running statistics that compare the
 * confidence prediction against {@link Classification#bestLabelIsCorrect()}. These statistics can
 * be reported with {@link #printAverageScores()} and {@link #printConfidenceAccuracy()}.
 *
 * <p>The statistics are plain, unsynchronized fields, so concurrent calls to {@code classify} may
 * lose updates.
 */
public class ConfidencePredictingClassifier
extends Classifier {
    /** Label the confidence classifier uses for "the underlying prediction is right". */
    private static final String CORRECT = "correct";
    /** Label the confidence classifier uses for "the underlying prediction is wrong". */
    private static final String INCORRECT = "incorrect";

    Classifier underlyingClassifier;
    Classifier confidencePredictingClassifier;
    // Sum of the "correct" score over instances the underlying classifier got right.
    double totalCorrect;
    // Sum of the "correct" score over instances the underlying classifier got wrong.
    double totalIncorrect;
    // Sum of the "incorrect" score over instances the underlying classifier got wrong.
    double totalIncorrectIncorrect;
    // Sum of the "incorrect" score over instances the underlying classifier got right.
    double totalIncorrectCorrect;
    int numCorrectInstances;
    int numIncorrectInstances;
    // Instances where the confidence classifier's best label agreed with reality.
    int numConfidenceCorrect;
    // Underlying prediction wrong, but the confidence classifier's best label was not "incorrect".
    int numFalsePositive;
    // Underlying prediction right, but the confidence classifier's best label was not "correct".
    int numFalseNegative;

    /**
     * Creates a confidence-predicting wrapper.
     *
     * @param underlyingClassifier the classifier whose predictions are scored; its instance pipe
     *     becomes this classifier's pipe
     * @param confidencePredictingClassifier a classifier over the labels {@code "correct"} and
     *     {@code "incorrect"}, applied to the underlying classifier's {@link Classification}
     */
    public ConfidencePredictingClassifier(Classifier underlyingClassifier, Classifier confidencePredictingClassifier) {
        super(underlyingClassifier.getInstancePipe());
        this.underlyingClassifier = underlyingClassifier;
        this.confidencePredictingClassifier = confidencePredictingClassifier;
        // All totals and counters start at zero through Java's default field initialization.
    }

    /**
     * Classifies {@code instance} with the underlying classifier, then scores that prediction with
     * the confidence classifier and records statistics about it.
     *
     * @param instance the instance to classify
     * @return a classification over the underlying label alphabet whose only non-zero entry is
     *     the {@code "correct"} confidence score at the underlying best label
     */
    @Override
    public Classification classify(Instance instance) {
        Classification base = this.underlyingClassifier.classify(instance);
        // NOTE(review): the Classification itself is the input to the confidence classifier;
        // presumably its pipe turns it into features. Confirm against how it was trained.
        Classification confidence = this.confidencePredictingClassifier.classify(base);
        LabelVector confidenceScores = confidence.getLabelVector();
        double correctScore = confidenceScores.value(CORRECT);
        double incorrectScore = confidenceScores.value(INCORRECT);
        String predictedLabel = confidenceScores.getBestLabel().toString();

        LabelVector baseScores = base.getLabelVector();
        int bestIndex = baseScores.getBestIndex();
        double[] values = new double[baseScores.numLocations()];
        // Only the underlying best label carries the confidence score. An out-of-range index is
        // skipped, matching the original per-slot loop.
        if (bestIndex >= 0 && bestIndex < values.length) {
            values[bestIndex] = correctScore;
        }

        recordStatistics(base.bestLabelIsCorrect(), correctScore, incorrectScore, predictedLabel);
        return new Classification(instance, this, new LabelVector(baseScores.getLabelAlphabet(), values));
    }

    /**
     * Updates the running totals and counters for one classified instance.
     *
     * @param baseWasCorrect whether the underlying classifier's best label was correct
     * @param correctScore the confidence classifier's score for {@code "correct"}
     * @param incorrectScore the confidence classifier's score for {@code "incorrect"}
     * @param predictedLabel the confidence classifier's best label
     */
    private void recordStatistics(boolean baseWasCorrect, double correctScore, double incorrectScore, String predictedLabel) {
        if (baseWasCorrect) {
            ++this.numCorrectInstances;
            this.totalCorrect += correctScore;
            this.totalIncorrectCorrect += incorrectScore;
            if (CORRECT.equals(predictedLabel)) {
                ++this.numConfidenceCorrect;
            } else {
                ++this.numFalseNegative;
            }
        } else {
            ++this.numIncorrectInstances;
            this.totalIncorrect += correctScore;
            this.totalIncorrectIncorrect += incorrectScore;
            // Any best label other than "incorrect" counts as a false positive.
            if (INCORRECT.equals(predictedLabel)) {
                ++this.numConfidenceCorrect;
            } else {
                ++this.numFalsePositive;
            }
        }
    }

    /**
     * Prints the four mean confidence scores to standard output.
     *
     * <p>A mean over zero instances is reported as {@code 0.0} (as in {@link #meanCorrect()})
     * rather than {@code NaN}.
     */
    public void printAverageScores() {
        System.out.println("Mean score of correct for correct instances = " + this.meanCorrect());
        System.out.println("Mean score of correct for incorrect instances = " + this.meanIncorrect());
        System.out.println("Mean score of incorrect for correct instances = " + mean(this.totalIncorrectCorrect, this.numCorrectInstances));
        System.out.println("Mean score of incorrect for incorrect instances = " + mean(this.totalIncorrectIncorrect, this.numIncorrectInstances));
    }

    /**
     * Prints the confidence classifier's accuracy, false negatives and false positives to
     * standard output.
     *
     * <p>Accuracy is reported as {@code 0.0} when no instances have been classified.
     */
    public void printConfidenceAccuracy() {
        double accuracy = mean(this.numConfidenceCorrect, this.numIncorrectInstances + this.numCorrectInstances);
        System.out.println("Confidence predicting accuracy = " + accuracy + " false negatives: " + this.numFalseNegative + "/" + this.numCorrectInstances + " false positives: " + this.numFalsePositive + " / " + this.numIncorrectInstances);
    }

    /**
     * Returns the mean {@code "correct"} score over instances the underlying classifier got right.
     *
     * @return the mean, or {@code 0.0} if there were no such instances
     */
    public double meanCorrect() {
        return mean(this.totalCorrect, this.numCorrectInstances);
    }

    /**
     * Returns the mean {@code "correct"} score over instances the underlying classifier got wrong.
     *
     * @return the mean, or {@code 0.0} if there were no such instances
     */
    public double meanIncorrect() {
        return mean(this.totalIncorrect, this.numIncorrectInstances);
    }

    /** Returns {@code total / count}, or {@code 0.0} when {@code count} is zero. */
    private static double mean(double total, int count) {
        return count == 0 ? 0.0 : total / (double)count;
    }
}

