/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.ClassificationModelConverter;
import org.jpmml.sparkml.model.HasRegressionTableOptions;
import org.jpmml.sparkml.model.LinearModelUtil;

/**
 * Converts a Spark ML {@link LogisticRegressionModel} into a PMML {@link RegressionModel}.
 *
 * <p>The encoding is chosen from the number of categories of the schema's label:
 * exactly two categories use the binary logistic encoding, more than two use the
 * softmax encoding.
 */
public class LogisticRegressionModelConverter
extends ClassificationModelConverter<LogisticRegressionModel>
implements HasRegressionTableOptions {
    /**
     * Creates a converter for the given fitted model.
     *
     * @param model the Spark ML logistic regression model to convert
     */
    public LogisticRegressionModelConverter(LogisticRegressionModel model) {
        super(model);
    }

    /**
     * Encodes the wrapped logistic regression model as a PMML regression model.
     *
     * @param schema the schema to encode against; its label is cast to {@link CategoricalLabel}
     * @return the regression model built by {@link LinearModelUtil}, with its output set to {@code null}
     * @throws IllegalArgumentException if the label has fewer than two categories
     */
    public RegressionModel encodeModel(Schema schema) {
        LogisticRegressionModel model = (LogisticRegressionModel)this.getTransformer();
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        int numCategories = categoricalLabel.size();

        // The output element is cleared on both branches; presumably the enclosing
        // converter attaches its own output fields later - TODO confirm.
        if (numCategories == 2) {
            // NOTE(review): Spark documents coefficients()/intercept() as binomial-only
            // accessors; confirm that a two-category model fitted with the multinomial
            // family cannot reach this branch.
            return LinearModelUtil.createBinaryLogisticClassification(this, model.coefficients(), model.intercept(), schema).setOutput(null);
        }
        if (numCategories > 2) {
            return LinearModelUtil.createSoftmaxClassification(this, model.coefficientMatrix(), model.interceptVector(), schema).setOutput(null);
        }

        // Fewer than two categories cannot be expressed by either encoding.
        throw new IllegalArgumentException("Expected a categorical label with two or more categories, got " + numCategories);
    }
}

