/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.util;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.maochen.nlp.ml.IClassifier;
import org.maochen.nlp.ml.Tuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs n-fold cross validation for an {@link IClassifier} and keeps, for every fold and
 * every label, a one-vs-rest confusion count ({@link Score}).
 *
 * <p>Scores are never cleared by this class, so repeated calls to {@link #run(List)} on the
 * same instance keep accumulating into the same per-fold, per-label counters.
 */
public class CrossValidation {
    private static final Logger LOG = LoggerFactory.getLogger(CrossValidation.class);
    private int nfold;
    private IClassifier classifier;
    private Set<String> labels;
    private Set<Score> scores = new HashSet<Score>();
    private boolean shuffleData;

    /**
     * Splits {@code data} into {@code nfold} equally sized, non-overlapping chunks and, for each
     * chunk, trains the classifier on the other chunks and evaluates it on that chunk.
     *
     * <p>The caller's list is never modified: all work happens on a private copy, which is
     * shuffled first when shuffling was requested. Tuples that do not fit into the equally sized
     * chunks (the tail of the working copy) are logged and left out of both training and testing.
     *
     * @param data the labelled tuples to cross validate on; must hold at least {@code nfold} items
     * @throws IllegalArgumentException if the fold count is not positive or exceeds the data size
     */
    public void run(List<Tuple> data) {
        if (this.nfold < 1) {
            throw new IllegalArgumentException("nfold must be at least 1 but was " + this.nfold);
        }
        if (data.size() < this.nfold) {
            throw new IllegalArgumentException(
                    "Need at least " + this.nfold + " tuples for " + this.nfold
                            + "-fold cross validation but got " + data.size());
        }

        // Work on a copy so that shuffling and slicing never touch the caller's list.
        List<Tuple> dataCopy = new ArrayList<>(data);
        this.labels = data.stream().map(x -> x.label).collect(Collectors.toSet());
        if (this.shuffleData) {
            Collections.shuffle(dataCopy);
        }

        int chunkSize = dataCopy.size() / this.nfold;
        // Everything at or beyond this index is the left-over tail that fits in no chunk.
        int usableSize = chunkSize * this.nfold;
        for (Tuple dropped : dataCopy.subList(usableSize, dataCopy.size())) {
            LOG.info("Dropping the tail id: {}", dropped.id);
        }

        for (int fold = 0; fold < this.nfold; ++fold) {
            int from = fold * chunkSize;
            int to = from + chunkSize;
            // subList returns views backed by dataCopy; copy them into independent lists so the
            // classifier can do whatever it wants with them without corrupting other folds.
            List<Tuple> testing = new ArrayList<>(dataCopy.subList(from, to));
            List<Tuple> training = new ArrayList<>(usableSize - chunkSize);
            training.addAll(dataCopy.subList(0, from));
            training.addAll(dataCopy.subList(to, usableSize));
            this.eval(training, testing, fold);
        }
    }

    /**
     * Trains the classifier on {@code training}, predicts every tuple in {@code testing} and
     * records the outcome for the given fold.
     *
     * <p>The predicted label is the key with the highest value in the map returned by
     * {@code predict}; an empty map yields the empty string as prediction.
     *
     * @param training tuples used for training in this fold
     * @param testing  tuples used for evaluation in this fold
     * @param nfold    zero-based index of the fold being evaluated
     */
    private void eval(List<Tuple> training, List<Tuple> testing, int nfold) {
        this.initScores(nfold);
        this.classifier.train(training);
        for (Tuple tuple : testing) {
            String predicted = this.classifier.predict(tuple).entrySet().stream()
                    .max((e1, e2) -> ((Double) e1.getValue()).compareTo((Double) e2.getValue()))
                    .map(Map.Entry::getKey)
                    .orElse("");
            this.updateScore(tuple, predicted, nfold);
        }
    }

    /**
     * Makes sure exactly one {@link Score} exists for every known label in the given fold.
     *
     * <p>{@link Score} relies on identity equality, so the set itself cannot de-duplicate;
     * the lookup here is what prevents duplicate counters for the same fold and label.
     *
     * @param nfold zero-based index of the fold to prepare
     */
    private void initScores(int nfold) {
        for (String label : this.labels) {
            boolean exists = this.scores.stream()
                    .anyMatch(s -> s.nfold == nfold && s.label.equals(label));
            if (!exists) {
                Score score = new Score();
                score.nfold = nfold;
                score.label = label;
                this.scores.add(score);
            }
        }
    }

    /**
     * Updates the one-vs-rest confusion counts of the given fold for a single prediction.
     *
     * <p>For each label of the fold: it is a true positive when it is both the expected and the
     * predicted label, a false positive when it was predicted but not expected, a false negative
     * when it was expected but not predicted, and a true negative otherwise.
     *
     * @param testingTuple the evaluated tuple, carrying the expected label
     * @param actual       the label the classifier predicted for the tuple
     * @param nfold        zero-based index of the fold being evaluated
     */
    private void updateScore(Tuple testingTuple, String actual, int nfold) {
        String expected = testingTuple.label;
        for (Score score : this.scores) {
            if (score.nfold != nfold) {
                continue;
            }
            boolean isExpected = score.label.equals(expected);
            boolean isPredicted = score.label.equals(actual);
            if (isExpected && isPredicted) {
                ++score.tp;
            } else if (isPredicted) {
                ++score.fp;
            } else if (isExpected) {
                ++score.fn;
            } else {
                ++score.tn;
            }
        }
    }

    /**
     * Creates a cross validation run configuration.
     *
     * @param nfold       number of folds to split the data into
     * @param classifier  classifier that is trained and evaluated once per fold
     * @param shuffleData whether the data is shuffled before it is split into folds
     */
    public CrossValidation(int nfold, IClassifier classifier, boolean shuffleData) {
        this.nfold = nfold;
        this.classifier = classifier;
        this.shuffleData = shuffleData;
    }

    /**
     * One-vs-rest confusion counts of a single label within a single fold, plus the metrics
     * derived from them. Every metric returns {@code 0.0} instead of {@code NaN} when its
     * denominator is zero.
     */
    static class Score {
        /** Zero-based index of the fold these counts belong to. */
        int nfold;
        /** Label these counts belong to. */
        String label;
        /** True positives. */
        int tp = 0;
        /** True negatives. */
        int tn = 0;
        /** False positives. */
        int fp = 0;
        /** False negatives. */
        int fn = 0;

        Score() {
        }

        /**
         * @return the harmonic mean of precision and recall, or {@code 0.0} when both are zero
         */
        public double getF1() {
            double precision = this.getPrecision();
            double recall = this.getRecall();
            double sum = precision + recall;
            return sum == 0.0 ? 0.0 : 2.0 * precision * recall / sum;
        }

        /**
         * @return {@code tp / (tp + fp)}, or {@code 0.0} when nothing was predicted as this label
         */
        public double getPrecision() {
            return Score.ratio(this.tp, this.tp + this.fp);
        }

        /**
         * @return {@code tp / (tp + fn)}, or {@code 0.0} when this label never was the expected one
         */
        public double getRecall() {
            return Score.ratio(this.tp, this.tp + this.fn);
        }

        /**
         * @return {@code (tp + tn) / (tp + tn + fp + fn)}, or {@code 0.0} when nothing was counted
         */
        public double getAccuracy() {
            return Score.ratio(this.tn + this.tp, this.tp + this.tn + this.fn + this.fp);
        }

        /** Divides as doubles, mapping a zero denominator to {@code 0.0} instead of NaN. */
        private static double ratio(int numerator, int denominator) {
            return denominator == 0 ? 0.0 : (double) numerator / (double) denominator;
        }
    }
}

