/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.util;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.stream.Collectors;
import org.maochen.nlp.ml.IClassifier;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.classifier.maxent.MaxEntClassifier;
import org.maochen.nlp.ml.util.dataio.CSVDataReader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs n-fold cross validation of an {@link IClassifier} over a list of {@link Tuple}s and
 * collects one-vs-rest confusion counts per (fold, label).
 *
 * <p>The input is copied (and optionally shuffled) and then cut into {@code nfold} contiguous
 * chunks of {@code size / nfold} tuples. The {@code size % nfold} trailing tuples that do not fill
 * a whole chunk are logged and left out of both training and testing. In round {@code k}, chunk
 * {@code k} is held out for testing and the classifier is trained on all other chunks.
 *
 * <p>Counts in {@link #scores} accumulate across calls to {@link #run}. Instances hold mutable
 * state without synchronization.
 */
public class CrossValidation {
    private static final Logger LOG = LoggerFactory.getLogger(CrossValidation.class);
    private final int nfold;
    private final IClassifier classifier;
    private Set<String> labels;
    /** Confusion counts keyed by (zero-based fold index, label); filled by {@link #run}. */
    public Table<Integer, String, Score> scores = HashBasedTable.create();
    private final boolean shuffleData;

    /**
     * Runs all {@code nfold} rounds over {@code data} and records the results in {@link #scores}.
     *
     * @param data labelled tuples; the caller's list is not modified
     * @throws IllegalArgumentException if {@code data} holds fewer tuples than there are folds
     */
    public void run(List<Tuple> data) {
        if (data.size() < nfold) {
            throw new IllegalArgumentException(
                    "CV expects at least " + nfold + " tuples, got " + data.size() + ".");
        }
        List<Tuple> dataCopy = new ArrayList<>(data);
        labels = dataCopy.stream().map(x -> x.label).collect(Collectors.toSet());
        if (shuffleData) {
            // Shuffle (and later split) the copy, so the shuffle actually affects the folds.
            Collections.shuffle(dataCopy);
        }

        int chunkSize = dataCopy.size() / nfold;
        // Only nfold * chunkSize tuples fit into equally sized folds; the tail is excluded.
        int usable = chunkSize * nfold;
        for (int i = usable; i < dataCopy.size(); i++) {
            LOG.info("Dropping the tail id: {}", dataCopy.get(i).id);
        }

        for (int fold = 0; fold < nfold; fold++) {
            LOG.info("Cross validation round {}/{}", fold + 1, nfold);
            int testStart = fold * chunkSize;
            int testEnd = testStart + chunkSize;
            List<Tuple> testing = new ArrayList<>(dataCopy.subList(testStart, testEnd));
            List<Tuple> training = new ArrayList<>(dataCopy.subList(0, testStart));
            training.addAll(dataCopy.subList(testEnd, usable));
            eval(training, testing, fold);
        }
    }

    /**
     * Trains the classifier on {@code training}, then predicts every tuple of {@code testing}
     * and records each outcome under {@code fold}.
     */
    private void eval(List<Tuple> training, List<Tuple> testing, int fold) {
        classifier.train(training);
        for (Tuple tuple : testing) {
            // Pick the label with the highest predicted value; an empty prediction yields "".
            String predicted = classifier.predict(tuple).entrySet().stream()
                    .max(Map.Entry.comparingByValue())
                    .map(Map.Entry::getKey)
                    .orElse("");
            updateScore(tuple, predicted, fold);
        }
    }

    /**
     * Updates the one-vs-rest confusion counts of {@code fold} for a single prediction.
     *
     * <p>For every known label: gold and predicted gives tp; gold only gives fn; predicted only
     * gives fp; neither gives tn. A predicted label that is not a known label contributes no fp.
     *
     * @param testingTuple the classified tuple; its {@code label} is the gold label
     * @param predicted    label returned by the classifier
     * @param fold         zero-based fold index
     */
    private void updateScore(Tuple testingTuple, String predicted, int fold) {
        for (String label : labels) {
            if (!scores.contains(fold, label)) {
                Score score = new Score();
                score.nfold = fold;
                score.label = label;
                scores.put(fold, label, score);
            }
        }

        String gold = testingTuple.label;
        Map<String, Score> row = scores.row(fold);
        for (Map.Entry<String, Score> entry : row.entrySet()) {
            boolean isGold = entry.getKey().equals(gold);
            boolean isPredicted = entry.getKey().equals(predicted);
            Score score = entry.getValue();
            if (isGold && isPredicted) {
                score.tp++;
            } else if (isGold) {
                score.fn++;
            } else if (isPredicted) {
                score.fp++;
            } else {
                score.tn++;
            }
        }
    }

    /**
     * Aggregates the counts of every (fold, label) cell into one micro-averaged score.
     *
     * <p>Counts are summed, not divided by the number of cells. All ratios reported by
     * {@link Score} are unchanged by a common divisor, and integer division would truncate them.
     *
     * @return the summed score; all counts are zero if {@link #run} has not been called
     */
    public Score getResult() {
        Score result = new Score();
        for (Score s : scores.values()) {
            result.tp += s.tp;
            result.tn += s.tn;
            result.fp += s.fp;
            result.fn += s.fn;
        }
        return result;
    }

    /**
     * @param nfold       number of folds; must be at least 2
     * @param classifier  classifier retrained in every round; must not be null
     * @param shuffleData whether to shuffle a copy of the data before splitting it into folds
     * @throws IllegalArgumentException if {@code nfold < 2}
     * @throws NullPointerException     if {@code classifier} is null
     */
    public CrossValidation(int nfold, IClassifier classifier, boolean shuffleData) {
        if (nfold < 2) {
            throw new IllegalArgumentException("CV expects n-fold greater than 1.");
        }
        if (classifier == null) {
            throw new NullPointerException("classifier");
        }
        this.nfold = nfold;
        this.classifier = classifier;
        this.shuffleData = shuffleData;
    }

    /**
     * Example driver: 10-fold CV (with shuffling) of a {@link MaxEntClassifier} configured with
     * {@code iter=500} on a CSV file.
     *
     * @param args optional; {@code args[0]} overrides the default CSV path
     */
    public static void main(String[] args) throws IOException {
        MaxEntClassifier maxEntClassifier = new MaxEntClassifier();
        Properties properties = new Properties();
        properties.put("iter", "500");
        maxEntClassifier.setParameter(properties);
        String fileName = args.length > 0 ? args[0] : "/Users/mguan/Desktop/train.balanced.csv";
        CSVDataReader dataReader = new CSVDataReader(fileName, -1, ",", null, -1);
        List<Tuple> data = dataReader.read();
        CrossValidation cv = new CrossValidation(10, maxEntClassifier, true);
        cv.run(data);
        Score score = cv.getResult();
        System.out.println("Precision: " + score.getPrecision());
        System.out.println("Recall: " + score.getRecall());
        System.out.println("F1: " + score.getF1());
        System.out.println("F2: " + score.getF2());
    }

    /**
     * One-vs-rest confusion counts (tp/tn/fp/fn) plus the metrics derived from them.
     * A metric whose denominator is zero is reported as 0.0 rather than NaN.
     */
    public static class Score {
        int nfold;
        String label;
        int tp = 0;
        int tn = 0;
        int fp = 0;
        int fn = 0;

        /** Returns the harmonic mean of precision and recall. */
        public double getF1() {
            return getFBeta(1.0);
        }

        /** Returns the F-score that weights recall twice as much as precision. */
        public double getF2() {
            return getFBeta(2.0);
        }

        /**
         * Returns {@code (1 + beta^2) * P * R / (beta^2 * P + R)}.
         *
         * @param beta relative weight of recall versus precision
         */
        public double getFBeta(double beta) {
            double precision = getPrecision();
            double recall = getRecall();
            double betaSq = beta * beta;
            return safeDivide((1.0 + betaSq) * precision * recall, betaSq * precision + recall);
        }

        /** Returns {@code tp / (tp + fp)}. */
        public double getPrecision() {
            return safeDivide(tp, tp + fp);
        }

        /** Returns {@code tp / (tp + fn)}. */
        public double getRecall() {
            return safeDivide(tp, tp + fn);
        }

        /** Returns {@code (tp + tn) / (tp + tn + fp + fn)}. */
        public double getAccuracy() {
            return safeDivide(tn + tp, tp + tn + fn + fp);
        }

        /** Divides, mapping a zero denominator to 0.0. */
        private static double safeDivide(double numerator, double denominator) {
            return denominator == 0.0 ? 0.0 : numerator / denominator;
        }

        @Override
        public String toString() {
            return "P: " + String.format("%.2f", this.getPrecision()) + "\tR: " + String.format("%.2f", this.getRecall()) + "\tA: " + String.format("%.2f", this.getAccuracy()) + "\tF1: " + String.format("%.2f", this.getF1());
        }
    }
}

