/*
 * Decompiled with CFR 0.152.
 */
package sklearn.naive_bayes;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import sklearn.SkLearnClassifier;

/**
 * Converter for scikit-learn's {@code MultinomialNB} classifier.
 *
 * <p>The fitted estimator is encoded as a PMML classification {@link RegressionModel} with one
 * regression table per class and softmax normalization. Each table's intercept is that class's
 * entry of {@code class_log_prior_}. Its coefficients are that class's row of
 * {@code feature_log_prob_}.
 */
public class MultinomialNB extends SkLearnClassifier {

    /**
     * Maps {@code -Infinity} (and {@code null}) to {@code null}; every other value is returned
     * unchanged.
     *
     * <p>Kept as a static constant so it is created once. It also does not capture an enclosing
     * {@code MultinomialNB} instance, as an anonymous class declared inside an instance method would.
     */
    private static final Function<Number, Number> NEGATIVE_INFINITY_TO_NULL =
            new Function<Number, Number>() {

                @Override
                public Number apply(Number value) {
                    if (value == null || value.doubleValue() == Double.NEGATIVE_INFINITY) {
                        return null;
                    }
                    return value;
                }
            };

    public MultinomialNB(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features, i.e. the second dimension of {@code feature_count_}.
     */
    @Override
    public int getNumberOfFeatures() {
        int[] shape = getFeatureCountShape();
        return shape[1];
    }

    /**
     * Encodes this estimator as a softmax-normalized classification {@link RegressionModel}.
     *
     * @param schema the schema supplying the (categorical) target label and the input features
     * @return the encoded model, with predicted-probability output fields attached
     * @throws IllegalArgumentException if {@code class_log_prior_} does not hold exactly one value
     *     per class, where the class count is the first dimension of {@code feature_count_}
     * @throws ClassCastException if the schema label is not a {@code CategoricalLabel}
     */
    @Override
    public Model encodeModel(Schema schema) {
        int[] shape = getFeatureCountShape();
        int numberOfClasses = shape[0];
        int numberOfFeatures = shape[1];

        List<Number> classLogPrior = getClassLogPrior();
        if (classLogPrior.size() != numberOfClasses) {
            // Fail fast with a descriptive message instead of an IndexOutOfBoundsException
            // (too few values) or silently ignoring surplus values (too many).
            throw new IllegalArgumentException("Attribute 'class_log_prior_' has "
                    + classLogPrior.size() + " elements, expected " + numberOfClasses);
        }

        // Lazy view over feature_log_prob_ in which -Infinity entries become null coefficients.
        // NOTE(review): presumably createRegressionTable omits terms whose coefficient is null,
        // because PMML cannot express -Infinity. Confirm against jpmml-converter.
        List<Number> featureLogProb =
                Lists.transform(getFeatureLogProb(), NEGATIVE_INFINITY_TO_NULL);

        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        List<RegressionTable> regressionTables = new ArrayList<RegressionTable>(numberOfClasses);
        for (int i = 0; i < numberOfClasses; i++) {
            // Row i of the flattened (numberOfClasses x numberOfFeatures) matrix = class i weights.
            List<? extends Number> coefficients =
                    CMatrixUtil.getRow(featureLogProb, numberOfClasses, numberOfFeatures, i);
            Number intercept = classLogPrior.get(i);

            RegressionTable regressionTable = RegressionModelUtil
                    .createRegressionTable(schema.getFeatures(), coefficients, intercept)
                    .setTargetCategory(categoricalLabel.getValue(i));
            regressionTables.add(regressionTable);
        }

        RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION,
                ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
                .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);

        encodePredictProbaOutput(regressionModel, DataType.DOUBLE, categoricalLabel);

        return regressionModel;
    }

    /**
     * Returns the {@code class_log_prior_} attribute as a flat list of numbers, as read by the
     * inherited {@code getNumberArray}.
     */
    public List<Number> getClassLogPrior() {
        return getNumberArray("class_log_prior_");
    }

    /**
     * Returns the shape of the {@code feature_count_} attribute. The {@code 2} argument presumably
     * requests/validates a two-dimensional array (classes x features); see
     * {@code getArrayShape} in the superclass.
     */
    public int[] getFeatureCountShape() {
        return getArrayShape("feature_count_", 2);
    }

    /**
     * Returns the {@code feature_log_prob_} attribute as a flat list of numbers. It is consumed
     * row by row as a (classes x features) matrix.
     */
    public List<Number> getFeatureLogProb() {
        return getNumberArray("feature_log_prob_");
    }
}

