/*
 * Decompiled with CFR 0.152.
 */
package sklearn.naive_bayes;

import java.util.List;
import org.dmg.pmml.ContinuousDistribution;
import org.dmg.pmml.DataType;
import org.dmg.pmml.GaussianDistribution;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.naive_bayes.BayesInput;
import org.dmg.pmml.naive_bayes.BayesInputs;
import org.dmg.pmml.naive_bayes.BayesOutput;
import org.dmg.pmml.naive_bayes.NaiveBayesModel;
import org.dmg.pmml.naive_bayes.TargetValueCount;
import org.dmg.pmml.naive_bayes.TargetValueCounts;
import org.dmg.pmml.naive_bayes.TargetValueStat;
import org.dmg.pmml.naive_bayes.TargetValueStats;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.ClassDictUtil;
import sklearn.Classifier;

/**
 * Converter for a pickled scikit-learn {@code GaussianNB} estimator.
 *
 * <p>The fitted per-class feature means ({@code theta_}), per-class feature variances
 * ({@code var_}, or {@code sigma_} in older pickles) and per-class counts
 * ({@code class_count_}) are encoded as a PMML {@link NaiveBayesModel} in which every
 * input field is described by one {@link GaussianDistribution} per target class.
 */
public class GaussianNB
extends Classifier {
    public GaussianNB(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features.
     *
     * @return the second dimension of the two-dimensional {@code theta_} array
     */
    @Override
    public int getNumberOfFeatures() {
        int[] shape = this.getThetaShape();
        return shape[1];
    }

    /**
     * Encodes this estimator as a PMML naive Bayes classification model.
     *
     * <p>For every feature {@code i} of the schema, column {@code i} of {@code theta_} and of
     * the variance matrix supplies one (mean, variance) pair per class. Those pairs become a
     * {@link BayesInput}. The per-class counts become the {@link BayesOutput}.
     *
     * @param schema the conversion schema; its label is expected to be a {@link CategoricalLabel}
     * @return the encoded model, with a probability output attached
     */
    @Override
    public NaiveBayesModel encodeModel(Schema schema) {
        int[] shape = this.getThetaShape();
        int numberOfClasses = shape[0];
        int numberOfFeatures = shape[1];

        List<? extends Number> theta = this.getTheta();
        List<? extends Number> sigma = this.getSigma();

        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        BayesInputs bayesInputs = new BayesInputs();
        for (int i = 0; i < numberOfFeatures; i++) {
            Feature feature = schema.getFeature(i);
            ContinuousFeature continuousFeature = feature.toContinuousFeature();

            // Both matrices are (classes x features); column i holds one value per class
            List<? extends Number> means = CMatrixUtil.getColumn(theta, numberOfClasses, numberOfFeatures, i);
            List<? extends Number> variances = CMatrixUtil.getColumn(sigma, numberOfClasses, numberOfFeatures, i);

            BayesInput bayesInput = new BayesInput(continuousFeature.getName())
                .setTargetValueStats(encodeTargetValueStats(categoricalLabel.getValues(), means, variances));

            bayesInputs.addBayesInputs(bayesInput);
        }

        // Read class_count_ as raw numbers rather than via getClassCount(). Estimators fitted
        // with sample_weight accumulate fractional (weighted) counts. Forcing these through an
        // int conversion would distort the class priors.
        List<? extends Number> classCount = ClassDictUtil.getArray(this, "class_count_");

        BayesOutput bayesOutput = new BayesOutput(categoricalLabel.getName(), null)
            .setTargetValueCounts(encodeTargetValueCounts(categoricalLabel.getValues(), classCount));

        NaiveBayesModel naiveBayesModel = new NaiveBayesModel(0.0, MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), bayesInputs, bayesOutput)
            .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));

        return naiveBayesModel;
    }

    /**
     * Returns the per-class sample counts as integers.
     *
     * @return the {@code class_count_} array converted via {@link ValueUtil#asIntegers}
     */
    public List<Integer> getClassCount() {
        return ValueUtil.asIntegers(ClassDictUtil.getArray(this, "class_count_"));
    }

    /**
     * Returns the per-class feature means.
     *
     * @return the flattened {@code theta_} array
     */
    public List<? extends Number> getTheta() {
        return ClassDictUtil.getArray(this, "theta_");
    }

    /**
     * Returns the per-class feature variances.
     *
     * @return the flattened {@code var_} array, or {@code sigma_} for pickles that predate it
     */
    public List<? extends Number> getSigma() {
        // scikit-learn 1.0 renamed sigma_ to var_. From then on only var_ is stored in the
        // pickled instance state, so prefer it and fall back for older pickles.
        if (this.containsKey("var_")) {
            return ClassDictUtil.getArray(this, "var_");
        }

        return ClassDictUtil.getArray(this, "sigma_");
    }

    /** Returns the shape of {@code theta_}, which must be two-dimensional. */
    private int[] getThetaShape() {
        return ClassDictUtil.getShape(this, "theta_", 2);
    }

    /**
     * Builds one {@link TargetValueStat} per class value, each carrying a Gaussian distribution.
     *
     * @param values the class values
     * @param means the per-class means, aligned with {@code values}
     * @param variances the per-class variances, aligned with {@code values}
     * @return the encoded statistics
     */
    private static TargetValueStats encodeTargetValueStats(List<String> values, List<? extends Number> means, List<? extends Number> variances) {
        ClassDictUtil.checkSize(values, means, variances);

        TargetValueStats targetValueStats = new TargetValueStats();
        for (int i = 0; i < values.size(); i++) {
            double mean = ValueUtil.asDouble(means.get(i)).doubleValue();
            double variance = ValueUtil.asDouble(variances.get(i)).doubleValue();

            TargetValueStat targetValueStat = new TargetValueStat(values.get(i))
                .setContinuousDistribution(new GaussianDistribution(mean, variance));

            targetValueStats.addTargetValueStats(targetValueStat);
        }

        return targetValueStats;
    }

    /**
     * Builds one {@link TargetValueCount} per class value.
     *
     * <p>Counts are kept as {@code double} so that weighted (non-integer) class counts are
     * preserved exactly.
     *
     * @param values the class values
     * @param counts the per-class counts, aligned with {@code values}
     * @return the encoded counts
     */
    private static TargetValueCounts encodeTargetValueCounts(List<String> values, List<? extends Number> counts) {
        ClassDictUtil.checkSize(values, counts);

        TargetValueCounts targetValueCounts = new TargetValueCounts();
        for (int i = 0; i < values.size(); i++) {
            TargetValueCount targetValueCount = new TargetValueCount(values.get(i), ValueUtil.asDouble(counts.get(i)).doubleValue());

            targetValueCounts.addTargetValueCounts(targetValueCount);
        }

        return targetValueCounts;
    }
}

