/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.TokenAccuracyEvaluator;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerEvaluator;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.text.DecimalFormat;
import java.util.logging.Logger;

/**
 * Transducer evaluator that logs, for each label in the model's target
 * alphabet, token-level precision, recall and F1, followed by the
 * macro-averaged F1 over all labels.
 */
public class PerClassAccuracyEvaluator
extends TransducerEvaluator {
    // Named after this class so that its output can be configured and filtered
    // independently of other evaluators.
    private static final Logger logger = MalletLogger.getLogger(PerClassAccuracyEvaluator.class.getName());

    /**
     * Creates an evaluator over several instance lists.
     *
     * @param instanceLists the instance lists handed to {@link TransducerEvaluator}
     * @param descriptions  the descriptions handed to {@link TransducerEvaluator};
     *                      presumably one per instance list, in the same order
     */
    public PerClassAccuracyEvaluator(InstanceList[] instanceLists, String[] descriptions) {
        super(instanceLists, descriptions);
    }

    /**
     * Creates an evaluator over a single instance list.
     *
     * @param i1 the instance list to evaluate
     * @param d1 the description used for {@code i1}
     */
    public PerClassAccuracyEvaluator(InstanceList i1, String d1) {
        this(new InstanceList[]{i1}, new String[]{d1});
    }

    /**
     * Creates an evaluator over two instance lists.
     *
     * @param i1 the first instance list to evaluate
     * @param d1 the description used for {@code i1}
     * @param i2 the second instance list to evaluate
     * @param d2 the description used for {@code i2}
     */
    public PerClassAccuracyEvaluator(InstanceList i1, String d1, InstanceList i2, String d2) {
        this(new InstanceList[]{i1, i2}, new String[]{d1, d2});
    }

    /**
     * Transduces every instance in {@code data} with the trainer's current
     * transducer, tallies per-label token counts (true, predicted, correct)
     * and logs precision, recall and F1 per label plus the macro-average F1.
     *
     * <p>A label whose F1 is undefined (NaN, e.g. the label never occurs and
     * is never predicted) contributes 0 to the macro-average. Labels that are
     * not in the target alphabet are not counted and are reported in a single
     * warning; the alphabet itself is never modified.
     *
     * @param tt          trainer whose transducer is evaluated
     * @param data        instances whose data and target are both sequences
     * @param description text included in the log lines to identify {@code data}
     */
    @Override
    public void evaluateInstanceList(TransducerTrainer tt, InstanceList data, String description) {
        Transducer model = tt.getTransducer();
        Alphabet dict = model.getInputPipe().getTargetAlphabet();
        int numLabels = dict.size();
        int[] numCorrectTokens = new int[numLabels];
        int[] numPredTokens = new int[numLabels];
        int[] numTrueTokens = new int[numLabels];
        int numUnknownLabelTokens = 0;
        logger.info("Per-token results for " + description);
        for (int i = 0; i < data.size(); ++i) {
            Instance instance = (Instance)data.get(i);
            Sequence input = (Sequence)instance.getData();
            Sequence trueOutput = (Sequence)instance.getTarget();
            assert (input.size() == trueOutput.size());
            Sequence predOutput = model.transduce(input);
            assert (predOutput.size() == trueOutput.size());
            for (int j = 0; j < trueOutput.size(); ++j) {
                Object trueLabel = trueOutput.get(j);
                Object predLabel = predOutput.get(j);
                // Non-mutating lookups: evaluation must not grow the target
                // alphabet, and an unknown label must not index past the
                // count arrays sized from the alphabet above.
                int trueIdx = dict.lookupIndex(trueLabel, false);
                int predIdx = dict.lookupIndex(predLabel, false);
                boolean trueKnown = trueIdx >= 0 && trueIdx < numLabels;
                boolean predKnown = predIdx >= 0 && predIdx < numLabels;
                if (!trueKnown || !predKnown) {
                    ++numUnknownLabelTokens;
                }
                if (trueKnown) {
                    ++numTrueTokens[trueIdx];
                }
                if (predKnown) {
                    ++numPredTokens[predIdx];
                }
                if (trueKnown && trueLabel.equals(predLabel)) {
                    ++numCorrectTokens[trueIdx];
                }
            }
        }
        if (numUnknownLabelTokens > 0) {
            logger.warning(description + ": " + numUnknownLabelTokens
                    + " token(s) had a true or predicted label missing from the target alphabet and were not fully counted");
        }
        DecimalFormat f = new DecimalFormat("0.####");
        // Entries stay 0.0 for labels whose F1 is undefined.
        double[] allf = new double[numLabels];
        for (int k = 0; k < numLabels; ++k) {
            Object label = dict.lookupObject(k);
            // Floating-point division: 0/0 yields NaN rather than throwing.
            double precision = (double)numCorrectTokens[k] / (double)numPredTokens[k];
            double recall = (double)numCorrectTokens[k] / (double)numTrueTokens[k];
            double f1 = 2.0 * precision * recall / (precision + recall);
            if (!Double.isNaN(f1)) {
                allf[k] = f1;
            }
            logger.info(description + " label " + label + " P " + f.format(precision) + " R " + f.format(recall) + " F1 " + f.format(f1));
        }
        logger.info("Macro-average F1 " + f.format(MatrixOps.mean(allf)));
    }
}

