/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.sparkml.ClassificationModelConverter;
import org.jpmml.sparkml.MatrixUtil;
import org.jpmml.sparkml.VectorUtil;

/**
 * Converts a Spark ML {@link LogisticRegressionModel} into a PMML {@link RegressionModel}.
 *
 * <p>Models with a single coefficient row (binomial family, two target categories) are encoded as a
 * LOGIT-normalized binary logistic classification. Models with one coefficient row per target
 * category (multinomial family, including the two-category case) are encoded as one
 * {@link RegressionTable} per category with SOFTMAX normalization.
 */
public class LogisticRegressionModelConverter
extends ClassificationModelConverter<LogisticRegressionModel> {
    public LogisticRegressionModelConverter(LogisticRegressionModel model) {
        super(model);
    }

    /**
     * Encodes the wrapped logistic regression model as a PMML regression model.
     *
     * @param schema the conversion schema; its label is expected to be a {@link CategoricalLabel}
     * @return a PMML {@link RegressionModel} in classification mode
     * @throws IllegalArgumentException if the label has fewer than two categories, or if the number
     *     of coefficient rows in the Spark model does not match the label's category count
     */
    public RegressionModel encodeModel(Schema schema) {
        LogisticRegressionModel model = (LogisticRegressionModel)this.getTransformer();
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        int numCategories = categoricalLabel.size();
        if (numCategories < 2) {
            throw new IllegalArgumentException("Expected a target with at least 2 categories, got " + numCategories);
        }

        // Spark stores one coefficient row for binomial models and one row per class for
        // multinomial models. A model trained with family="multinomial" on a two-class label
        // therefore has 2 rows, and its coefficients() accessor throws, so the dispatch must
        // look at the matrix shape rather than only at the category count.
        Matrix coefficientMatrix = model.coefficientMatrix();
        int numCoefficientRows = coefficientMatrix.numRows();

        if (numCategories == 2 && numCoefficientRows == 1) {
            return encodeBinomialModel(model, schema);
        }
        if (numCoefficientRows == numCategories) {
            return encodeMultinomialModel(coefficientMatrix, model.interceptVector(), categoricalLabel, schema);
        }
        throw new IllegalArgumentException("Expected 1 or " + numCategories + " coefficient rows for a target with " + numCategories + " categories, got " + numCoefficientRows);
    }

    /**
     * Encodes a binomial model (single coefficient vector plus scalar intercept) as a
     * LOGIT-normalized binary logistic classification.
     */
    private static RegressionModel encodeBinomialModel(LogisticRegressionModel model, Schema schema) {
        // NOTE(review): the boolean argument is presumably "has probability distribution" in
        // RegressionModelUtil; the Output element built by the helper is then cleared here,
        // presumably because output fields are supplied elsewhere by the converter framework.
        return RegressionModelUtil.createBinaryLogisticClassification(schema.getFeatures(), VectorUtil.toList(model.coefficients()), model.intercept(), RegressionModel.NormalizationMethod.LOGIT, true, schema)
            .setOutput(null);
    }

    /**
     * Encodes a multinomial model as one regression table per target category (row {@code i} of
     * the coefficient matrix and element {@code i} of the intercept vector map to category
     * {@code i}), combined with SOFTMAX normalization.
     */
    private static RegressionModel encodeMultinomialModel(Matrix coefficientMatrix, Vector interceptVector, CategoricalLabel categoricalLabel, Schema schema) {
        int numCategories = categoricalLabel.size();
        List<RegressionTable> regressionTables = new ArrayList<>(numCategories);
        for (int i = 0; i < numCategories; i++) {
            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(schema.getFeatures(), MatrixUtil.getRow(coefficientMatrix, i), interceptVector.apply(i))
                .setTargetCategory(categoricalLabel.getValue(i));
            regressionTables.add(regressionTable);
        }
        return new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
            .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);
    }
}

