/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.util;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.maochen.nlp.ml.IClassifier;
import org.maochen.nlp.ml.Tuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs k-fold cross validation of an {@link IClassifier} and records, for every fold and every
 * label, one-vs-rest confusion counts in a {@link Score}.
 *
 * <p>The input is copied (and optionally shuffled), split into {@code round} equally sized folds,
 * and any tail tuples that do not fill a complete fold are dropped from both training and testing.
 * Each fold is used exactly once as the test set while the remaining folds form the training set.
 *
 * <p>Instances keep mutable state between calls and are not safe for concurrent use.
 */
public class CrossValidation {
    private static final Logger LOG = LoggerFactory.getLogger(CrossValidation.class);
    private final int round;
    private final IClassifier classifier;
    private Set<String> labels;
    private final Set<Score> scores = new HashSet<Score>();
    private final boolean shuffledata;

    /**
     * Performs {@code round}-fold cross validation over {@code data}.
     *
     * <p>The caller's list is never modified. Scores from any previous run are discarded.
     *
     * @param data labelled tuples to split into folds
     * @throws IllegalArgumentException if {@code round} is not positive or {@code data} has fewer
     *     than {@code round} tuples (a fold would otherwise be empty)
     */
    public void run(List<Tuple> data) {
        if (this.round <= 0) {
            throw new IllegalArgumentException("round must be positive, got " + this.round);
        }
        if (data.size() < this.round) {
            throw new IllegalArgumentException("need at least " + this.round
                    + " tuples for " + this.round + "-fold cross validation, got " + data.size());
        }
        // Work on a private copy so shuffling and splitting never touch the caller's list.
        List<Tuple> working = new ArrayList<Tuple>(data);
        if (this.shuffledata) {
            Collections.shuffle(working);
        }
        this.labels = working.stream().map(x -> x.label).collect(Collectors.toSet());
        this.scores.clear();

        int chunkSize = working.size() / this.round;
        // Only complete folds are used; the leftover (size % round) tail is ignored entirely.
        int usable = chunkSize * this.round;
        for (int i = usable; i < working.size(); ++i) {
            LOG.info("Dropping the tail id: {}", working.get(i).id);
        }

        for (int fold = 0; fold < this.round; ++fold) {
            int from = fold * chunkSize;
            int to = from + chunkSize;
            // Copy out of the sub-list views so later mutation cannot invalidate them.
            List<Tuple> testing = new ArrayList<Tuple>(working.subList(from, to));
            List<Tuple> training = new ArrayList<Tuple>(usable - chunkSize);
            training.addAll(working.subList(0, from));
            training.addAll(working.subList(to, usable));
            this.eval(training, testing, fold);
        }
    }

    /**
     * Returns a read-only view of the per-fold, per-label scores from the most recent
     * {@link #run(List)}.
     *
     * @return unmodifiable view of the collected scores; empty before the first run
     */
    public Set<Score> getScores() {
        return Collections.unmodifiableSet(this.scores);
    }

    /**
     * Trains the classifier on {@code training}, predicts every tuple in {@code testing}, and
     * records the outcome in a fresh set of per-label scores for {@code fold}.
     */
    private void eval(List<Tuple> training, List<Tuple> testing, int fold) {
        this.classifier.train(training);
        // Exactly one Score per label for this fold, created up front.
        Map<String, Score> foldScores = this.labels.stream()
                .collect(Collectors.toMap(label -> label, label -> new Score(fold, label)));
        this.scores.addAll(foldScores.values());
        for (Tuple tuple : testing) {
            // The predicted label is the one with the highest score; "" if nothing was predicted.
            String predicted = this.classifier.predict(tuple).entrySet().stream()
                    .max((e1, e2) -> ((Double) e1.getValue()).compareTo((Double) e2.getValue()))
                    .map(Map.Entry::getKey)
                    .orElse("");
            this.updateScore(foldScores, tuple.label, predicted);
        }
    }

    /**
     * Updates the one-vs-rest confusion counts of every label for a single prediction.
     *
     * @param foldScores scores of the current fold, keyed by label
     * @param expected the gold label of the test tuple
     * @param predicted the label the classifier chose (may be outside the known label set)
     */
    private void updateScore(Map<String, Score> foldScores, String expected, String predicted) {
        for (Score score : foldScores.values()) {
            boolean isExpected = score.label.equals(expected);
            boolean isPredicted = score.label.equals(predicted);
            if (isExpected && isPredicted) {
                ++score.truePos;
            } else if (isExpected) {
                // The gold label was missed: a false negative for that label only.
                ++score.falseNeg;
            } else if (isPredicted) {
                ++score.falsePos;
            } else {
                ++score.trueNeg;
            }
        }
    }

    /**
     * @param round number of folds; must be positive when {@link #run(List)} is called
     * @param classifier classifier retrained on each fold's training set
     * @param shuffleData whether to shuffle a copy of the data before splitting into folds
     */
    public CrossValidation(int round, IClassifier classifier, boolean shuffleData) {
        this.round = round;
        this.classifier = classifier;
        this.shuffledata = shuffleData;
    }

    /**
     * One-vs-rest confusion counts for a single label in a single fold.
     *
     * <p>All ratio getters return {@code 0.0} instead of NaN when their denominator is zero.
     */
    static class Score {
        int round;
        String label;
        int truePos = 0;
        int trueNeg = 0;
        int falsePos = 0;
        int falseNeg = 0;

        Score() {
        }

        Score(int round, String label) {
            this.round = round;
            this.label = label;
        }

        /** Harmonic mean of precision and recall. */
        public double getF1() {
            double precision = this.getPrecision();
            double recall = this.getRecall();
            double sum = precision + recall;
            return sum == 0.0 ? 0.0 : 2.0 * precision * recall / sum;
        }

        /** TP / (TP + FP). */
        public double getPrecision() {
            return ratio(this.truePos, this.truePos + this.falsePos);
        }

        /** TP / (TP + FN). */
        public double getRecall() {
            return ratio(this.truePos, this.truePos + this.falseNeg);
        }

        /** (TP + TN) / (TP + TN + FP + FN). */
        public double getAccuracy() {
            return ratio(this.truePos + this.trueNeg,
                    this.truePos + this.trueNeg + this.falseNeg + this.falsePos);
        }

        /**
         * Misspelled alias kept for backward compatibility.
         *
         * @deprecated use {@link #getAccuracy()}
         */
        @Deprecated
        public double getAccurancy() {
            return this.getAccuracy();
        }

        /** Divides, returning 0.0 when the denominator is zero. */
        private static double ratio(int numerator, int denominator) {
            return denominator == 0 ? 0.0 : (double) numerator / (double) denominator;
        }

        @Override
        public String toString() {
            return "Score{round=" + round + ", label=" + label + ", tp=" + truePos + ", tn=" + trueNeg
                    + ", fp=" + falsePos + ", fn=" + falseNeg + "}";
        }
    }
}

