/*
 * Decompiled with CFR 0.152.
 */
package sklearn.linear_model;

import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SigmoidTransformation;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.sklearn.ClassDictUtil;
import sklearn.Classifier;
import sklearn.EstimatorUtil;
import sklearn.linear_model.BaseLinearUtil;

public abstract class BaseLinearClassifier
extends Classifier {
    public BaseLinearClassifier(String module, String name) {
        super(module, name);
    }

    @Override
    public int getNumberOfFeatures() {
        int[] shape = this.getCoefShape();
        return shape[1];
    }

    public MiningModel encodeModel(Schema schema) {
        int[] shape = this.getCoefShape();
        int numberOfClasses = shape[0];
        int numberOfFeatures = shape[1];
        boolean hasProbabilityDistribution = this.hasProbabilityDistribution();
        List<? extends Number> coefficients = this.getCoef();
        List<? extends Number> intercepts = this.getIntercept();
        Schema segmentSchema = new Schema((Label)new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        if (numberOfClasses == 1) {
            EstimatorUtil.checkSize(2, categoricalLabel);
            RegressionModel regressionModel = BaseLinearUtil.encodeRegressionModel(intercepts.get(0), CMatrixUtil.getRow(coefficients, (int)numberOfClasses, (int)numberOfFeatures, (int)0), segmentSchema).setOutput(ModelUtil.createPredictedOutput((FieldName)FieldName.create((String)("decisionFunction_" + categoricalLabel.getValue(1))), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[0]));
            return MiningModelUtil.createBinaryLogisticClassification((Schema)schema, (Model)regressionModel, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.SOFTMAX, (double)0.0, (double)1.0, (boolean)hasProbabilityDistribution);
        }
        if (numberOfClasses >= 3) {
            EstimatorUtil.checkSize(numberOfClasses, categoricalLabel);
            ArrayList<RegressionModel> regressionModels = new ArrayList<RegressionModel>();
            int rows = categoricalLabel.size();
            for (int i = 0; i < rows; ++i) {
                RegressionModel regressionModel = BaseLinearUtil.encodeRegressionModel(intercepts.get(i), CMatrixUtil.getRow(coefficients, (int)numberOfClasses, (int)numberOfFeatures, (int)i), segmentSchema).setOutput(ModelUtil.createPredictedOutput((FieldName)FieldName.create((String)("decisionFunction_" + categoricalLabel.getValue(i))), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[]{new SigmoidTransformation((Number)-1.0)}));
                regressionModels.add(regressionModel);
            }
            return MiningModelUtil.createClassification((Schema)schema, regressionModels, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.SIMPLEMAX, (boolean)hasProbabilityDistribution);
        }
        throw new IllegalArgumentException();
    }

    public List<? extends Number> getCoef() {
        return ClassDictUtil.getArray(this, "coef_");
    }

    public List<? extends Number> getIntercept() {
        return ClassDictUtil.getArray(this, "intercept_");
    }

    private int[] getCoefShape() {
        return ClassDictUtil.getShape(this, "coef_", 2);
    }
}

