/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import java.util.List;
import java.util.NoSuchElementException;
import org.apache.spark.ml.classification.NaiveBayesModel;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.sparkml.ClassificationModelConverter;
import org.jpmml.sparkml.VectorUtil;
import scala.collection.Iterator;

/**
 * Converts a Spark ML {@link NaiveBayesModel} into a PMML {@link RegressionModel}.
 *
 * <p>Only the {@code "multinomial"} model type is supported. Each label category becomes one
 * {@link RegressionTable}: its intercept is the matching entry of the model's {@code pi}
 * vector and its coefficients are the matching row of the model's {@code theta} matrix.
 * The per-class scores are normalized with {@code SOFTMAX}.
 */
public class NaiveBayesModelConverter
extends ClassificationModelConverter<NaiveBayesModel> {
    public NaiveBayesModelConverter(NaiveBayesModel model) {
        super(model);
    }

    /**
     * Encodes the wrapped Naive Bayes model as a PMML classification {@link RegressionModel}.
     *
     * @param schema schema whose label is expected to be a {@link CategoricalLabel} and whose
     *     features line up with the columns of {@code theta}
     * @return a regression model with one regression table per label category
     * @throws IllegalArgumentException if the model type is not {@code "multinomial"}, if any
     *     configured threshold is non-zero, or if the sizes of {@code pi} / {@code theta} do not
     *     match the number of label categories
     */
    public RegressionModel encodeModel(Schema schema) {
        NaiveBayesModel model = (NaiveBayesModel)this.getTransformer();

        String modelType = model.getModelType();
        switch (modelType) {
            case "multinomial":
                break;
            default:
                throw new IllegalArgumentException(modelType);
        }

        checkThresholds(model);

        Vector pi = model.pi();
        Matrix theta = model.theta();
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        // Each label category consumes one entry of pi and one row of theta; fail loudly on a
        // mismatch instead of running off the end of the row iterator or silently dropping classes.
        int numClasses = categoricalLabel.size();
        if (pi.size() != numClasses || theta.numRows() != numClasses) {
            throw new IllegalArgumentException("Expected " + numClasses + " classes, got pi of size " + pi.size() + " and theta with " + theta.numRows() + " rows");
        }

        List<Double> intercepts = VectorUtil.toList(pi);
        List features = schema.getFeatures();
        Iterator thetaRows = theta.rowIter();

        RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema((Label)categoricalLabel), null)
            .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);

        for (int i = 0; i < numClasses; ++i) {
            List<Double> coefficients = VectorUtil.toList((Vector)thetaRows.next());

            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable((List)features, coefficients, intercepts.get(i))
                .setTargetCategory(categoricalLabel.getValue(i));

            regressionModel.addRegressionTables(new RegressionTable[]{regressionTable});
        }

        return regressionModel;
    }

    /**
     * Rejects models configured with any non-zero class threshold.
     *
     * @param model the Spark model whose thresholds are inspected
     * @throws IllegalArgumentException if any threshold differs from {@code 0.0}
     */
    private static void checkThresholds(NaiveBayesModel model) {
        double[] thresholds;
        try {
            thresholds = model.getThresholds();
        } catch (NoSuchElementException ignored) {
            // Presumably thrown when the thresholds param is not set on the model; in that case
            // there is nothing to validate.
            return;
        }

        for (int i = 0; i < thresholds.length; ++i) {
            double threshold = thresholds[i];
            // NOTE(review): non-zero thresholds presumably alter Spark's class selection in a way
            // the emitted PMML does not encode, hence the rejection — confirm against Spark docs.
            if (threshold != 0.0) {
                throw new IllegalArgumentException("Threshold at index " + i + " is " + threshold + "; only zero thresholds are supported");
            }
        }
    }
}

