/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.learning;

import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.learning.ACRFEvaluator;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.LabelsSequence;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.text.DecimalFormat;
import java.util.List;
import java.util.logging.Logger;

/**
 * ACRF evaluator that scores a segmentation (start-tag / continue-tag labelling) on a single
 * label "slice" of the ACRF's multi-label output.
 *
 * <p>For every evaluated list it logs token accuracy and, for each start tag and overall, the
 * number of true, predicted and correct segments together with precision, recall and F1. A
 * segment counts as correct when truth and prediction start it with the same start tag at the
 * same position and agree on where the matching continue tag stops.
 */
public class MultiSegmentationEvaluatorACRF
extends ACRFEvaluator {
    private static final Logger logger = MalletLogger.getLogger(MultiSegmentationEvaluatorACRF.class.getName());
    /** Start tag of each segment type; entry {@code n} pairs with {@code segmentContinueTags[n]}. */
    Object[] segmentStartTags;
    /** Continue tag of each segment type, parallel to {@code segmentStartTags}. */
    Object[] segmentContinueTags;
    // NOTE(review): never assigned or read in this class. It is package-private, so it is kept in
    // case other code in the package refers to it.
    Object[] segmentStartOrContinueTags;
    // NOTE(review): unused in this class; kept so the field layout does not change.
    private int evalIterations = 0;
    /** Index of the label taken from each {@code Labels} position when slicing sequences. */
    private int slice = 0;

    /**
     * Creates an evaluator that scores slice 0.
     *
     * @param segmentStartTags start tag of each segment type
     * @param segmentContinueTags continue tag of each segment type; expected to have the same
     *     length as {@code segmentStartTags} (checked only by an {@code assert})
     * @param showViterbi currently ignored
     */
    public MultiSegmentationEvaluatorACRF(Object[] segmentStartTags, Object[] segmentContinueTags, boolean showViterbi) {
        this.segmentStartTags = segmentStartTags;
        this.segmentContinueTags = segmentContinueTags;
        assert (segmentStartTags.length == segmentContinueTags.length);
    }

    /** Creates an evaluator that scores slice 0. */
    public MultiSegmentationEvaluatorACRF(Object[] segmentStartTags, Object[] segmentContinueTags) {
        this(segmentStartTags, segmentContinueTags, true);
    }

    /**
     * Creates an evaluator that scores the given slice.
     *
     * @param slice index of the label to evaluate within each {@code Labels} position
     */
    public MultiSegmentationEvaluatorACRF(Object[] segmentStartTags, Object[] segmentContinueTags, int slice) {
        this(segmentStartTags, segmentContinueTags, true);
        this.slice = slice;
    }

    /**
     * Projects a multi-label sequence onto one of its label positions.
     *
     * @param lseq sequence whose every position holds several labels
     * @param k index of the label to keep at each position
     * @return a sequence made of the {@code k}-th label of every position of {@code lseq}
     */
    private LabelSequence slice(LabelsSequence lseq, int k) {
        Label[] arr = new Label[lseq.size()];
        for (int i = 0; i < lseq.size(); ++i) {
            arr[i] = lseq.getLabels(i).get(k);
        }
        return new LabelSequence(arr);
    }

    /**
     * Evaluates each non-null list (training, validation, testing) when
     * {@code shouldDoEvaluate(iter)} allows it. Each list is passed to the inherited
     * {@code test(ACRF, InstanceList, String)}.
     *
     * @return always {@code true}
     */
    @Override
    public boolean evaluate(ACRF acrf, int iter, InstanceList training, InstanceList validation, InstanceList testing) {
        if (!this.shouldDoEvaluate(iter)) {
            return true;
        }
        InstanceList[] lists = new InstanceList[]{training, validation, testing};
        String[] listnames = new String[]{"Training", "Validation", "Testing"};
        for (int k = 0; k < lists.length; ++k) {
            if (lists[k] != null) {
                this.test(acrf, lists[k], listnames[k]);
            }
        }
        return true;
    }

    /**
     * Scores the predicted label sequences against the gold targets and logs the results.
     *
     * @param gold instances whose targets are the true sequences
     * @param returned predicted {@code LabelsSequence}s, aligned index-by-index with {@code gold}
     * @param description prefix used in the log output (e.g. "Testing")
     */
    @Override
    public void test(InstanceList gold, List returned, String description) {
        TestResults results = new TestResults(this.segmentStartTags, this.segmentContinueTags);
        for (int i = 0; i < gold.size(); ++i) {
            Instance instance = (Instance)gold.get(i);
            Sequence trueOutput = this.processTrueOutput((Sequence)instance.getTarget());
            LabelSequence predOutput = this.slice((LabelsSequence)returned.get(i), this.slice);
            // Token-by-token comparison below requires aligned sequences.
            assert (predOutput.size() == trueOutput.size());
            results.incrementCounts(trueOutput, predOutput);
        }
        results.logResults(description);
    }

    /**
     * Slices multi-label gold targets to the evaluated label; any other sequence is returned
     * unchanged.
     */
    private Sequence processTrueOutput(Sequence sequence) {
        if (sequence instanceof LabelsSequence) {
            return this.slice((LabelsSequence)sequence, this.slice);
        }
        return sequence;
    }

    /**
     * Accumulates token and segment counts across sequences and logs the summary.
     * Index {@code allIndex} (one past the last tag) of each per-tag array holds the overall
     * total.
     */
    public static class TestResults {
        private final Object[] segmentStartTags;
        private final Object[] segmentContinueTags;
        private int numCorrectTokens;
        private int totalTokens;
        private final int[] numTrueSegments;
        private final int[] numPredictedSegments;
        private final int[] numCorrectSegments;
        private final int allIndex;

        /**
         * @param segmentStartTags start tag of each segment type
         * @param segmentContinueTags continue tag of each segment type, parallel to
         *     {@code segmentStartTags}
         */
        public TestResults(Object[] segmentStartTags, Object[] segmentContinueTags) {
            this.segmentStartTags = segmentStartTags;
            this.segmentContinueTags = segmentContinueTags;
            this.allIndex = segmentStartTags.length;
            // Java zero-initialises int arrays, so all counters start at 0.
            this.numTrueSegments = new int[this.allIndex + 1];
            this.numPredictedSegments = new int[this.allIndex + 1];
            this.numCorrectSegments = new int[this.allIndex + 1];
        }

        /**
         * Logs token accuracy, then true/pred/correct counts and precision, recall and F1 for
         * each start tag and for all tags together ("OVERALL").
         *
         * @param description prefix for the token-accuracy line
         */
        public void logResults(String description) {
            DecimalFormat f = new DecimalFormat("0.####");
            // With no tokens the ratio would be NaN. Use the same empty-denominator convention
            // as precision/recall below (1.0).
            double tokenAccuracy = this.totalTokens == 0 ? 1.0 : (double)this.numCorrectTokens / (double)this.totalTokens;
            logger.info(description + " tokenaccuracy=" + f.format(tokenAccuracy));
            for (int n = 0; n < this.numCorrectSegments.length; ++n) {
                logger.info((n < this.allIndex ? this.segmentStartTags[n].toString() : "OVERALL") + ' ');
                double precision = this.numPredictedSegments[n] == 0 ? 1.0 : (double)this.numCorrectSegments[n] / (double)this.numPredictedSegments[n];
                double recall = this.numTrueSegments[n] == 0 ? 1.0 : (double)this.numCorrectSegments[n] / (double)this.numTrueSegments[n];
                double f1 = recall + precision == 0.0 ? 0.0 : 2.0 * recall * precision / (recall + precision);
                logger.info(" segments true=" + this.numTrueSegments[n] + " pred=" + this.numPredictedSegments[n] + " correct=" + this.numCorrectSegments[n] + " misses=" + (this.numTrueSegments[n] - this.numCorrectSegments[n]) + " alarms=" + (this.numPredictedSegments[n] - this.numCorrectSegments[n]));
                logger.info(" precision=" + f.format(precision) + " recall=" + f.format(recall) + " f1=" + f.format(f1));
            }
        }

        /**
         * Adds the token and segment counts of one aligned true/predicted sequence pair.
         * Tokens are compared by their {@code toString()} form on both sides.
         *
         * @param trueOutput gold sequence
         * @param predOutput predicted sequence; assumed at least as long as {@code trueOutput}
         */
        public void incrementCounts(Sequence trueOutput, Sequence predOutput) {
            for (int j = 0; j < trueOutput.size(); ++j) {
                ++this.totalTokens;
                String trueToken = trueOutput.get(j).toString();
                String predToken = predOutput.get(j).toString();
                if (trueToken.equals(predToken)) {
                    ++this.numCorrectTokens;
                }
                int trueStart = this.startTagIndex(trueToken);
                if (trueStart != -1) {
                    ++this.numTrueSegments[trueStart];
                    ++this.numTrueSegments[this.allIndex];
                }
                // Match predictions exactly like the truth (string form, first tag wins), so a
                // token never counts as more than one predicted segment.
                int predStart = this.startTagIndex(predToken);
                if (predStart != -1) {
                    ++this.numPredictedSegments[predStart];
                    ++this.numPredictedSegments[this.allIndex];
                }
                if (trueStart != -1 && trueStart == predStart
                        && this.segmentEndsAgree(trueOutput, predOutput, j + 1, trueStart)) {
                    ++this.numCorrectSegments[trueStart];
                    ++this.numCorrectSegments[this.allIndex];
                }
            }
        }

        /** Returns the index of the first start tag equal to {@code token}, or -1 if none. */
        private int startTagIndex(String token) {
            for (int n = 0; n < this.segmentStartTags.length; ++n) {
                if (this.segmentStartTags[n].equals(token)) {
                    return n;
                }
            }
            return -1;
        }

        /**
         * Checks whether truth and prediction end the segment of type {@code tag} at the same
         * place: scanning from {@code from}, both must keep emitting the continue tag until both
         * stop at the same position, or until the end of the sequence.
         */
        private boolean segmentEndsAgree(Sequence trueOutput, Sequence predOutput, int from, int tag) {
            Object continueTag = this.segmentContinueTags[tag];
            for (int m = from; m < trueOutput.size(); ++m) {
                boolean trueContinue = continueTag.equals(trueOutput.get(m).toString());
                boolean predContinue = continueTag.equals(predOutput.get(m).toString());
                if (trueContinue != predContinue) {
                    return false;
                }
                if (!trueContinue) {
                    // Both segments stop at the same position.
                    return true;
                }
            }
            // Both segments run to the end of the sequence (or the segment starts on the last token).
            return true;
        }
    }
}

