/*
 * Decompiled with CFR 0.152.
 */
package org.maochen.nlp.classifier.naivebayes;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import org.maochen.nlp.classifier.naivebayes.NaiveBayesModel;
import org.maochen.nlp.datastructure.LabelIndexer;
import org.maochen.nlp.datastructure.Tuple;
import org.maochen.nlp.utils.VectorUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Computes the per-label statistics of a naive Bayes model from labelled feature vectors.
 *
 * <p>For every label it computes the per-dimension sample mean and sample variance of the
 * feature vectors carrying that label, and assigns all labels the same prior
 * ({@code 1 / labelCount}). Exact zeros in the means and variances are replaced by
 * {@link Double#MIN_VALUE} with a warning.
 *
 * <p>The engine holds mutable state without synchronization; use one instance per training run
 * from a single thread.
 */
final class NBTrainingEngine {
    private static final Logger LOG = LoggerFactory.getLogger(NBTrainingEngine.class);

    /** Labelled samples the statistics are computed from. */
    private final List<Tuple> trainingData;

    /** Model being populated; returned by {@link #train()}. */
    private final NaiveBayesModel model;

    /** Number of training samples seen per label index; filled in by {@link #calculateMean()}. */
    private final int[] count;

    /**
     * Set after the first {@link #train()} so a repeated call does not re-accumulate sums into the
     * already-finalized model vectors.
     */
    private boolean trained;

    /**
     * Fills {@code model.meanVectors} with the per-label, per-dimension mean of the feature
     * vectors and records the per-label sample counts in {@link #count}.
     */
    private void calculateMean() {
        double[][] means = this.model.meanVectors;
        for (Tuple t : this.trainingData) {
            int index = this.model.labelIndexer.getIndex(t.label);
            this.count[index]++;
            means[index] = VectorUtils.zip(means[index], t.featureVector, (x, y) -> x + y);
        }
        for (int i = 0; i < means.length; ++i) {
            // A label with no samples keeps its all-zero sum; scaling by 1.0 / 0 (= Infinity)
            // would otherwise turn it into NaN (0 * Infinity).
            if (this.count[i] > 0) {
                means[i] = VectorUtils.scale(means[i], 1.0 / this.count[i]);
            }
        }
        this.replaceZeros(means, "mean");
    }

    /**
     * Fills {@code model.varianceVectors} with the per-label, per-dimension sample variance
     * (sum of squared deviations divided by {@code n - 1}). Must run after
     * {@link #calculateMean()}, whose means and counts it reads.
     */
    private void calculateVariance() {
        double[][] variances = this.model.varianceVectors;
        for (Tuple t : this.trainingData) {
            int index = this.model.labelIndexer.getIndex(t.label);
            double[] diff = VectorUtils.zip(t.featureVector, this.model.meanVectors[index], (x, y) -> x - y);
            double[] squared = Arrays.stream(diff).map(x -> x * x).toArray();
            variances[index] = VectorUtils.zip(variances[index], squared, (x, y) -> x + y);
        }
        for (int i = 0; i < variances.length; ++i) {
            // Unbiased (n - 1) estimator. For a label seen only once, n - 1 is 0 and 1.0 / 0 is
            // Infinity; scaling the all-zero deviation sum by it would give NaN, which the zero
            // check below cannot catch. Clamping to 1 yields 0, which replaceZeros then handles.
            int degreesOfFreedom = Math.max(this.count[i] - 1, 1);
            variances[i] = VectorUtils.scale(variances[i], 1.0 / degreesOfFreedom);
        }
        this.replaceZeros(variances, "variance");
    }

    /**
     * Replaces every entry of {@code vectors} that equals {@code 0.0} with
     * {@link Double#MIN_VALUE}, logging a warning that names the label and dimension.
     *
     * <p>NOTE(review): presumably this keeps the downstream likelihood from dividing by (or taking
     * the log of) an exact zero. Confirm against the classifier, because {@code MIN_VALUE} is a
     * subnormal and {@code 1 / MIN_VALUE} overflows to Infinity.
     *
     * @param vectors   per-label vectors, modified in place
     * @param statistic name of the statistic used as the start of the log message
     */
    private void replaceZeros(double[][] vectors, String statistic) {
        for (int i = 0; i < vectors.length; ++i) {
            double[] vector = vectors[i];
            for (int j = 0; j < vector.length; ++j) {
                if (vector[j] != 0.0) {
                    continue;
                }
                LOG.warn("{} is 0 for label {} at dimension {}", statistic,
                        this.model.labelIndexer.labelIndexer.inverse().get(i), j);
                vector[j] = Double.MIN_VALUE;
            }
        }
    }

    /**
     * Assigns every known label the same prior probability, {@code 1 / labelCount}. Label
     * frequencies in the training data are intentionally not used.
     */
    public void calculateLabelPrior() {
        double prior = 1.0 / (double) this.model.labelIndexer.getLabelSize();
        this.model.labelIndexer.getIndexSet().forEach(labelIndex -> this.model.labelPrior.put((Integer) labelIndex, prior));
    }

    /**
     * Computes the means, variances and priors and returns the populated model.
     *
     * <p>Only the first call does any work; later calls return the same model instead of
     * re-accumulating statistics into it.
     *
     * @return the trained model
     */
    public NaiveBayesModel train() {
        if (!this.trained) {
            this.calculateMean();
            this.calculateVariance();
            this.calculateLabelPrior();
            this.trained = true;
        }
        return this.model;
    }

    /**
     * Prepares an empty model sized from {@code trainingData}.
     *
     * <p>The vector length is taken from the first sample. All samples are assumed to have the
     * same length (NOTE(review): not validated here).
     *
     * @param trainingData labelled samples; also used to build the label index
     */
    public NBTrainingEngine(List<Tuple> trainingData) {
        this.trainingData = trainingData;
        this.model = new NaiveBayesModel();
        this.model.labelIndexer = new LabelIndexer(trainingData);
        int labelSize = this.model.labelIndexer.getLabelSize();
        int vectorLength = trainingData.isEmpty() ? 0 : trainingData.get(0).featureVector.length;
        this.count = new int[labelSize];
        // A two-dimensional array creation already allocates a zero-filled row per label.
        this.model.meanVectors = new double[labelSize][vectorLength];
        this.model.varianceVectors = new double[labelSize][vectorLength];
        this.model.labelPrior = new HashMap<>();
    }
}

