/*
 * Decompiled with CFR 0.152.
 */
package sklearn.naive_bayes;

import java.util.List;
import org.dmg.pmml.BayesInput;
import org.dmg.pmml.BayesInputs;
import org.dmg.pmml.BayesOutput;
import org.dmg.pmml.ContinuousDistribution;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.GaussianDistribution;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NaiveBayesModel;
import org.dmg.pmml.TargetValueCount;
import org.dmg.pmml.TargetValueCounts;
import org.dmg.pmml.TargetValueStat;
import org.dmg.pmml.TargetValueStats;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.MatrixUtil;
import sklearn.Classifier;

/**
 * Converts a fitted scikit-learn {@code GaussianNB} estimator into a PMML {@link NaiveBayesModel}.
 *
 * <p>The conversion reads three estimator attributes:
 * <ul>
 *   <li>{@code theta_}: per-class feature means, a 2-D (classes x features) array.</li>
 *   <li>{@code sigma_}: per-class feature variances, with the same layout.</li>
 *   <li>{@code class_count_}: per-class training counts.</li>
 * </ul>
 */
public class GaussianNB extends Classifier {

    public GaussianNB(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features.
     *
     * <p>This is the second dimension of {@code theta_}.
     */
    @Override
    public int getNumberOfFeatures() {
        int[] shape = this.getThetaShape();
        return shape[1];
    }

    /**
     * Encodes this estimator as a PMML naive Bayes classification model.
     *
     * <p>Each feature becomes a {@link BayesInput}, and each target category gets one
     * {@link GaussianDistribution} per input. The target field becomes a {@link BayesOutput}
     * whose {@link TargetValueCounts} come from {@code class_count_}.
     *
     * @param schema the feature/target schema; it must declare a target field and target categories
     * @return the encoded model, with a probability output attached
     * @throws IllegalArgumentException if the schema has no target field or no target categories,
     *         or if the number of target categories disagrees with the fitted per-class arrays
     */
    public NaiveBayesModel encodeModel(Schema schema) {
        // Validate the target side up front, before doing any per-feature work.
        FieldName targetField = schema.getTargetField();
        if (targetField == null) {
            throw new IllegalArgumentException("Expected a target field");
        }
        List<String> targetCategories = schema.getTargetCategories();
        if (targetCategories == null) {
            throw new IllegalArgumentException("Expected target categories");
        }

        int[] shape = this.getThetaShape();
        int numberOfClasses = shape[0];
        int numberOfFeatures = shape[1];

        List<? extends Number> theta = this.getTheta();
        List<? extends Number> sigma = this.getSigma();

        BayesInputs bayesInputs = new BayesInputs();
        for (int i = 0; i < numberOfFeatures; i++) {
            Feature feature = schema.getFeature(i);

            // Column i of the (classes x features) matrices holds the per-class statistics of feature i.
            List<? extends Number> means = MatrixUtil.getColumn(theta, numberOfClasses, numberOfFeatures, i);
            List<? extends Number> variances = MatrixUtil.getColumn(sigma, numberOfClasses, numberOfFeatures, i);

            BayesInput bayesInput = new BayesInput(feature.getName())
                .setTargetValueStats(encodeTargetValueStats(targetCategories, means, variances));
            bayesInputs.addBayesInputs(bayesInput);
        }

        // Use the raw class counts rather than getClassCount(). NOTE(review): class_count_ is
        // float64 in scikit-learn and is fractional when fitted with sample_weight, so forcing it to
        // integers can fail or lose precision. Integral counts encode to the same values either way.
        List<? extends Number> classCount = this.getClassCountValues();

        BayesOutput bayesOutput = new BayesOutput(targetField, null)
            .setTargetValueCounts(encodeTargetValueCounts(targetCategories, classCount));

        return new NaiveBayesModel(0.0, MiningFunctionType.CLASSIFICATION, ModelUtil.createMiningSchema(schema), bayesInputs, bayesOutput)
            .setOutput(ModelUtil.createProbabilityOutput(schema));
    }

    /**
     * Returns the {@code class_count_} attribute converted to integers.
     *
     * <p>The conversion is done by {@link ValueUtil#asIntegers}.
     */
    public List<Integer> getClassCount() {
        return ValueUtil.asIntegers(ClassDictUtil.getArray(this, "class_count_"));
    }

    /** Returns the {@code theta_} attribute (per-class feature means) as a flat list. */
    public List<? extends Number> getTheta() {
        return ClassDictUtil.getArray(this, "theta_");
    }

    /** Returns the {@code sigma_} attribute (per-class feature variances) as a flat list. */
    public List<? extends Number> getSigma() {
        return ClassDictUtil.getArray(this, "sigma_");
    }

    /** Returns the {@code class_count_} attribute as-is, without integer conversion. */
    private List<? extends Number> getClassCountValues() {
        return ClassDictUtil.getArray(this, "class_count_");
    }

    /** Returns the shape of {@code theta_}, which is required to be two-dimensional. */
    private int[] getThetaShape() {
        return ClassDictUtil.getShape(this, "theta_", 2);
    }

    /**
     * Builds one Gaussian {@link TargetValueStat} per target category for a single feature.
     *
     * @param targetCategories the target category labels, in class order
     * @param means the per-class means of the feature, in the same order
     * @param variances the per-class variances of the feature, in the same order
     * @return the encoded statistics
     * @throws IllegalArgumentException if the three lists differ in size
     */
    private static TargetValueStats encodeTargetValueStats(List<String> targetCategories, List<? extends Number> means, List<? extends Number> variances) {
        int size = targetCategories.size();
        if (means.size() != size || variances.size() != size) {
            throw new IllegalArgumentException("Expected " + size + " means and variances, got " + means.size() + " means and " + variances.size() + " variances");
        }

        TargetValueStats targetValueStats = new TargetValueStats();
        for (int i = 0; i < size; i++) {
            GaussianDistribution gaussianDistribution = new GaussianDistribution(
                ValueUtil.asDouble(means.get(i)).doubleValue(),
                ValueUtil.asDouble(variances.get(i)).doubleValue());

            TargetValueStat targetValueStat = new TargetValueStat(targetCategories.get(i))
                .setContinuousDistribution(gaussianDistribution);
            targetValueStats.addTargetValueStats(targetValueStat);
        }
        return targetValueStats;
    }

    /**
     * Builds the {@link TargetValueCounts} element pairing each target category with its count.
     *
     * <p>Counts may be any {@link Number}. Integer counts encode to the same values as before, and
     * fractional counts (e.g. from weighted training) are kept without truncation.
     *
     * @param targetCategories the target category labels, in class order
     * @param counts the per-class counts, in the same order
     * @return the encoded counts
     * @throws IllegalArgumentException if the two lists differ in size
     */
    public static TargetValueCounts encodeTargetValueCounts(List<String> targetCategories, List<? extends Number> counts) {
        int size = targetCategories.size();
        if (counts.size() != size) {
            throw new IllegalArgumentException("Expected " + size + " class counts, got " + counts.size());
        }

        TargetValueCounts targetValueCounts = new TargetValueCounts();
        for (int i = 0; i < size; i++) {
            TargetValueCount targetValueCount = new TargetValueCount(targetCategories.get(i), counts.get(i).doubleValue());
            targetValueCounts.addTargetValueCounts(targetValueCount);
        }
        return targetValueCounts;
    }
}

