/*
 * Decompiled with CFR 0.152.
 */
package sklearn.linear_model.logistic;

import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import sklearn.ClassifierUtil;
import sklearn.linear_model.LinearClassifier;

/**
 * Converter for a logistic regression classifier.
 *
 * <p>The encoding strategy is selected by the estimator's {@code multi_class}
 * attribute: {@code "multinomial"} produces a softmax-normalized
 * {@link RegressionModel} with one regression table per target category, while
 * {@code "ovr"} defers to the encoding inherited from {@link LinearClassifier}.
 */
public class LogisticRegression extends LinearClassifier {

    /**
     * Creates the converter; both arguments are forwarded unchanged to the
     * {@link LinearClassifier} constructor.
     *
     * @param module forwarded to the superclass constructor
     * @param name forwarded to the superclass constructor
     */
    public LogisticRegression(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes this estimator according to its {@code multi_class} attribute.
     *
     * @param schema the schema supplying the label and the features
     * @return the encoded model
     * @throws IllegalArgumentException if {@code multi_class} is neither
     *         {@code "multinomial"} nor {@code "ovr"} (including when it is absent)
     */
    @Override
    public Model encodeModel(Schema schema) {
        String multiClass = this.getMultiClass();
        if ("multinomial".equals(multiClass)) {
            return this.encodeMultinomialModel(schema);
        }
        if ("ovr".equals(multiClass)) {
            return this.encodeOvRModel(schema);
        }
        throw new IllegalArgumentException("Unsupported multi_class value: " + multiClass);
    }

    /**
     * Encodes the multinomial variant.
     *
     * <p>A coefficient matrix with a single row is handed to the one-vs-rest
     * encoding (presumably the binary classification case; confirm against
     * {@link LinearClassifier}). With three or more rows, every row becomes one
     * regression table targeting the label category at the same index, and the
     * tables are combined using softmax normalization.
     *
     * @param schema the schema supplying the label and the features
     * @return the encoded model
     * @throws IllegalArgumentException if the coefficient matrix has neither
     *         exactly one row nor at least three rows
     */
    private Model encodeMultinomialModel(Schema schema) {
        int[] shape = this.getCoefShape();
        int numberOfClasses = shape[0];
        int numberOfFeatures = shape[1];

        if (numberOfClasses == 1) {
            return this.encodeOvRModel(schema);
        }
        if (numberOfClasses < 3) {
            throw new IllegalArgumentException(
                "Expected a coefficient matrix with 1 or at least 3 rows, got " + numberOfClasses);
        }

        // Coefficient data and the label are only needed on the softmax path,
        // so they are loaded after the delegation/validation guards above.
        List<? extends Number> coef = this.getCoef();
        List<? extends Number> intercepts = this.getIntercept();
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        ClassifierUtil.checkSize(numberOfClasses, categoricalLabel);

        List<RegressionTable> regressionTables = new ArrayList<>(numberOfClasses);
        for (int i = 0; i < categoricalLabel.size(); ++i) {
            // Row i of the coefficient matrix and intercept i belong to category i.
            List<Double> rowCoefficients =
                ValueUtil.asDoubles(CMatrixUtil.getRow(coef, numberOfClasses, numberOfFeatures, i));
            RegressionTable regressionTable = RegressionModelUtil
                .createRegressionTable(schema.getFeatures(), rowCoefficients, ValueUtil.asDouble(intercepts.get(i)))
                .setTargetCategory(categoricalLabel.getValue(i));
            regressionTables.add(regressionTable);
        }

        return new RegressionModel(
                MiningFunction.CLASSIFICATION,
                ModelUtil.createMiningSchema(categoricalLabel),
                regressionTables)
            .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX)
            .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
    }

    /**
     * Encodes the one-vs-rest variant by delegating to the superclass encoding.
     *
     * @param schema the schema supplying the label and the features
     * @return the model produced by {@code super.encodeModel(schema)}
     */
    private Model encodeOvRModel(Schema schema) {
        return super.encodeModel(schema);
    }

    /**
     * Returns the value stored under the {@code multi_class} attribute.
     *
     * @return the attribute value cast to {@link String}
     */
    public String getMultiClass() {
        return (String) this.get("multi_class");
    }
}

