/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.ml.util;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.vector.FeatNamedVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ChiSquare {
    private static final Logger LOG = LoggerFactory.getLogger(ChiSquare.class);
    public static final double EMPIRICAL_P_VALUE = 0.05;
    protected Table<String, String, Integer> dataTable = HashBasedTable.create();
    protected Table<String, String, Double> chiSquareTable = HashBasedTable.create();
    protected int df;
    protected int total;
    public double totalChiSquare;
    public double totalPVal;

    public void loadTrainingData(List<Tuple> trainingData) {
        for (int i = 0; i < trainingData.size(); ++i) {
            if (i % 1000 == 0) {
                LOG.debug("Processed " + i + " of " + trainingData.size());
            }
            Tuple t = trainingData.get(i);
            for (String featName : ((FeatNamedVector)t.vector).featsName) {
                Integer count = (Integer)this.dataTable.get((Object)featName, (Object)t.label);
                count = count == null ? 1 : count + 1;
                this.dataTable.put((Object)featName, (Object)t.label, (Object)count);
            }
        }
    }

    public void calculateChiSquare() {
        this.df = this.dataTable.rowKeySet().size() - 1;
        this.df = this.df == 0 ? (this.df = 1) : this.df;
        this.df *= this.dataTable.columnKeySet().size() - 1;
        this.df = this.df == 0 ? (this.df = 1) : this.df;
        this.total = this.dataTable.rowMap().values().stream().map(Map::values).map(lst -> lst.stream().mapToInt(num -> num).sum()).mapToInt(num -> num).sum();
        this.dataTable.cellSet().forEach(cell -> {
            String feat = (String)cell.getRowKey();
            String label = (String)cell.getColumnKey();
            int count = cell.getValue() == null ? 0 : (Integer)cell.getValue();
            int c_feat = this.dataTable.row((Object)feat).values().stream().mapToInt(x -> x).sum();
            int c_label = this.dataTable.column((Object)label).values().stream().mapToInt(x -> x).sum();
            double e_xi_yi = (double)(c_feat * c_label) / (double)this.total;
            this.chiSquareTable.put((Object)feat, (Object)label, (Object)(Math.pow((double)count - e_xi_yi, 2.0) / e_xi_yi));
        });
        this.totalChiSquare = this.chiSquareTable.cellSet().parallelStream().mapToDouble(cell -> cell.getValue() == null ? 0.0 : (Double)cell.getValue()).sum();
        this.totalPVal = ChiSquare.getPValue(this.totalChiSquare, this.df);
    }

    protected static double getPValue(double chiSquare, double df) {
        GammaDistribution gamma = new GammaDistribution(df / 2.0, 2.0);
        double gammaVal = gamma.cumulativeProbability(chiSquare);
        return 1.0 - gammaVal;
    }

    public void printPTable() {
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append("Greater than 0.05 might be independent.").append(System.lineSeparator());
        stringBuilder.append("Total P Value: ").append(String.format("%.5f", this.totalPVal)).append(System.lineSeparator());
        System.out.println(stringBuilder.toString());
    }
}

