/*
 * Decompiled with CFR 0.152.
 */
package sklearn.naive_bayes;

import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.naive_bayes.BayesInput;
import org.dmg.pmml.naive_bayes.BayesInputs;
import org.dmg.pmml.naive_bayes.NaiveBayesModel;
import org.dmg.pmml.naive_bayes.PairCounts;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import sklearn.naive_bayes.DiscreteNB;
import sklearn.naive_bayes.DiscreteNBUtil;

/**
 * Converter for scikit-learn's {@code BernoulliNB} estimator.
 *
 * <p>Every input feature must be a categorical feature with exactly two values. For each feature,
 * per-class "non-event" and "event" counts are derived from the fitted {@code feature_count_}
 * matrix and the class counts returned by {@link #getClassCount()}. These counts are then encoded
 * as PMML {@link BayesInput} elements.
 */
public class BernoulliNB
extends DiscreteNB {
    public BernoulliNB(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features the model was fitted on.
     *
     * @return the second dimension (column count) of {@code feature_count_}
     */
    @Override
    public int getNumberOfFeatures() {
        int[] shape = this.getFeatureCountShape();
        return shape[1];
    }

    /**
     * Encodes one {@link BayesInput} per input feature.
     *
     * <p>For feature {@code i} and class {@code j}, the event count is {@code feature_count_[j][i]}.
     * The non-event count is {@code classCount[j] - feature_count_[j][i]}. The first category value
     * of the feature is paired with the non-event counts, and the second with the event counts.
     * NOTE(review): this presumes the first category is the "absent"/false state; confirm against
     * the feature encoder.
     *
     * @param values the target class values, passed through to
     *     {@link DiscreteNBUtil#encodePairCounts}; presumably ordered like the rows of
     *     {@code feature_count_} (TODO confirm)
     * @param features the input features, one per column of {@code feature_count_}
     * @return the encoded Bayes inputs, one per feature, in feature order
     * @throws IllegalArgumentException if the number of features does not match the number of
     *     columns in {@code feature_count_}, or if any feature is not a {@link CategoricalFeature}
     */
    @Override
    public BayesInputs encodeBayesInputs(List<?> values, List<? extends Feature> features) {
        int[] shape = this.getFeatureCountShape();
        int numberOfClasses = shape[0];
        int numberOfFeatures = shape[1];

        // Each column of feature_count_ describes exactly one input feature. A size mismatch means
        // the schema and the fitted model disagree: a shorter list would fail with an opaque
        // IndexOutOfBoundsException, and a longer one would have its tail silently ignored.
        if (features.size() != numberOfFeatures) {
            throw new IllegalArgumentException("Expected " + numberOfFeatures + " feature(s), got " + features.size());
        }

        Number alpha = this.getAlpha();
        List<Integer> classCount = this.getClassCount();
        List<Integer> featureCount = this.getFeatureCount();

        BayesInputs bayesInputs = new BayesInputs();

        for (int i = 0; i < numberOfFeatures; ++i) {
            Feature feature = features.get(i);

            // Column i of the (numberOfClasses x numberOfFeatures) feature_count_ matrix
            List<Integer> featureClassCount = CMatrixUtil.getColumn(featureCount, numberOfClasses, numberOfFeatures, i);

            if (!(feature instanceof CategoricalFeature)) {
                throw new IllegalArgumentException("Expected a categorical feature, got " + feature);
            }

            CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
            SchemaUtil.checkSize(2, categoricalFeature);

            List<?> featureValues = categoricalFeature.getValues();

            List<Integer> nonEventCounts = new ArrayList<>(numberOfClasses);
            List<Number> eventCounts = new ArrayList<>(numberOfClasses);

            for (int j = 0; j < numberOfClasses; ++j) {
                Integer eventCount = featureClassCount.get(j);

                nonEventCounts.add(classCount.get(j) - eventCount);
                eventCounts.add(eventCount);
            }

            List<PairCounts> pairCounts = new ArrayList<>(2);
            pairCounts.add(DiscreteNBUtil.encodePairCounts(featureValues.get(0), values, alpha, nonEventCounts));
            pairCounts.add(DiscreteNBUtil.encodePairCounts(featureValues.get(1), values, alpha, eventCounts));

            BayesInput bayesInput = new BayesInput(categoricalFeature.getName(), null, pairCounts);

            bayesInputs.addBayesInputs(bayesInput);
        }

        return bayesInputs;
    }

    /**
     * Encodes the model after checking that the label is categorical with two categories.
     *
     * <p>NOTE(review): scikit-learn's {@code BernoulliNB} itself accepts multi-class targets, and
     * {@link #encodeBayesInputs} loops over any number of classes. Confirm whether this two-class
     * restriction is intentional.
     *
     * @param schema the conversion schema; its label must be a {@link CategoricalLabel}
     * @return the model produced by the superclass
     */
    @Override
    public NaiveBayesModel encodeModel(Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        // The explicit cast pins the DiscreteLabel overload of SchemaUtil.checkSize
        SchemaUtil.checkSize(2, (DiscreteLabel)categoricalLabel);

        return super.encodeModel(schema);
    }

    /**
     * Returns the fitted {@code feature_count_} attribute, flattened to a list of integers.
     *
     * @return the flattened feature counts
     */
    public List<Integer> getFeatureCount() {
        return this.getIntegerArray("feature_count_");
    }

    /**
     * Returns the shape of the {@code feature_count_} attribute, which is read as a 2-D array.
     *
     * @return a two-element array: {@code [numberOfClasses, numberOfFeatures]}
     */
    public int[] getFeatureCountShape() {
        return this.getArrayShape("feature_count_", 2);
    }
}

