/*
 * Decompiled with CFR 0.152.
 */
package sklearn.linear_model.logistic;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import net.razorvine.pickle.objects.ClassDict;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.python.AttributeException;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.PythonFormatterUtil;
import sklearn.VersionUtil;
import sklearn.linear_model.LinearClassifier;

/**
 * Converter for scikit-learn's {@code LogisticRegression} estimator.
 *
 * <p>The effective {@code multi_class} strategy is resolved first. One-vs-rest models are
 * delegated to the {@link LinearClassifier} encoding. Multinomial models are encoded here as
 * softmax-normalized regression models.
 */
public class LogisticRegression
extends LinearClassifier {
    private static final String MULTICLASS_AUTO = "auto";
    private static final String MULTICLASS_DEPRECATED = "deprecated";
    private static final String MULTICLASS_MULTINOMIAL = "multinomial";
    private static final String MULTICLASS_OVR = "ovr";
    private static final String MULTICLASS_WARN = "warn";
    private static final String SOLVER_LIBLINEAR = "liblinear";

    public LogisticRegression(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes this estimator as a PMML model.
     *
     * <p>The {@code "auto"} value (recognized for SkLearn 0.22+) and the {@code "deprecated"}
     * value (recognized for SkLearn 1.5.0+) of {@code multi_class} are both resolved to
     * {@code "ovr"} or {@code "multinomial"} by the same rule. That rule looks at the solver and
     * at the number of coefficient rows.
     *
     * @param schema the schema whose label is expected to be categorical
     * @return a multinomial (softmax) model or a one-vs-rest model
     * @throws AttributeException if {@code multi_class} cannot be resolved to a supported
     *     strategy, e.g. when the SkLearn version is unknown or too old for the stored value
     */
    @Override
    public Model encodeModel(Schema schema) {
        String sklearnVersion = this.getSkLearnVersion();
        String multiClass = this.getMultiClass();

        if (Objects.equals(MULTICLASS_AUTO, multiClass)) {
            int[] shape = this.getCoefShape();
            String solver = this.getSolver();
            multiClass = isVersionAtLeast(sklearnVersion, "0.22") ? getAutoMultiClass(solver, shape) : null;
        } else if (Objects.equals(MULTICLASS_DEPRECATED, multiClass)) {
            int[] shape = this.getCoefShape();
            // "deprecated" falls back to the "auto" behavior, so the solver must be consulted
            // here as well: a multiclass liblinear model is still one-vs-rest, not multinomial.
            String solver = this.getSolver();
            multiClass = isVersionAtLeast(sklearnVersion, "1.5.0") ? getAutoMultiClass(solver, shape) : null;
        }

        if (multiClass == null) {
            throw new AttributeException("Attribute '" + ClassDictUtil.formatMember((ClassDict)this, (String)"multi_class") + "' must be set to one of " + PythonFormatterUtil.formatValue((Object)MULTICLASS_OVR) + " or " + PythonFormatterUtil.formatValue((Object)MULTICLASS_MULTINOMIAL) + " values");
        }

        switch (multiClass) {
            case MULTICLASS_MULTINOMIAL:
                return this.encodeMultinomialModel(schema);
            case MULTICLASS_OVR:
                return this.encodeOvRModel(schema);
            default:
                throw new IllegalArgumentException(multiClass);
        }
    }

    /**
     * Encodes a multinomial (softmax-normalized) logistic regression.
     *
     * <p>A single coefficient row denotes a binary problem. For SkLearn 0.20+ it is encoded as a
     * model chain. The first model computes the decision function {@code d}. The second model
     * applies softmax over {@code (-d, +d)} for the two class values. Three or more coefficient
     * rows yield one regression table per class, also softmax-normalized.
     *
     * @param schema the schema whose label is expected to be categorical
     * @return the encoded model
     * @throws IllegalArgumentException if the coefficient matrix has neither 1 nor at least 3 rows
     */
    private Model encodeMultinomialModel(Schema schema) {
        String sklearnVersion = this.getSkLearnVersion();
        int[] shape = this.getCoefShape();
        int numberOfClasses = shape[0];
        int numberOfFeatures = shape[1];
        List<Number> coef = this.getCoef();
        List<Number> intercept = this.getIntercept();
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        List features = schema.getFeatures();

        if (numberOfClasses == 1) {
            SchemaUtil.checkSize((int)2, (DiscreteLabel)categoricalLabel);

            // Versions before 0.20 fall back to the one-vs-rest encoding for binary problems
            // (presumably because they computed binary probabilities the OvR way; verify).
            if (!isVersionAtLeast(sklearnVersion, "0.20")) {
                return this.encodeOvRModel(schema);
            }

            PMMLEncoder encoder = schema.getEncoder();
            Schema segmentSchema = schema.toRelabeledSchema(null);

            // Stage 1: raw decision function over all input features.
            RegressionModel firstModel = RegressionModelUtil.createRegression((List)features, (List)CMatrixUtil.getRow(coef, (int)1, (int)numberOfFeatures, (int)0), (Number)intercept.get(0), null, (Schema)segmentSchema).setOutput(ModelUtil.createPredictedOutput((String)"decisionFunction", (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[0]));

            // Stage 2: softmax over (-d, +d) for the first and second class value respectively.
            ContinuousFeature feature = new ContinuousFeature(encoder, "decisionFunction", DataType.DOUBLE);
            RegressionTable passiveRegressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(feature), Collections.singletonList(-1.0), (Number)0.0).setTargetCategory(categoricalLabel.getValue(0));
            RegressionTable activeRegressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(feature), Collections.singletonList(1.0), (Number)0.0).setTargetCategory(categoricalLabel.getValue(1));

            List<RegressionTable> regressionTables = new ArrayList<RegressionTable>();
            regressionTables.add(passiveRegressionTable);
            regressionTables.add(activeRegressionTable);

            RegressionModel secondModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema((Label)categoricalLabel), regressionTables).setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);

            MiningModel miningModel = MiningModelUtil.createModelChain(Arrays.asList(firstModel, secondModel), (Segmentation.MissingPredictionTreatment)Segmentation.MissingPredictionTreatment.RETURN_MISSING);
            this.encodePredictProbaOutput((Model)miningModel, DataType.DOUBLE, (DiscreteLabel)categoricalLabel);
            return miningModel;
        }

        if (numberOfClasses >= 3) {
            SchemaUtil.checkSize((int)numberOfClasses, (DiscreteLabel)categoricalLabel);

            // One regression table per class: row i of coef_ plus intercept_[i].
            List<RegressionTable> regressionTables = new ArrayList<RegressionTable>();
            for (int i = 0; i < categoricalLabel.size(); ++i) {
                RegressionTable regressionTable = RegressionModelUtil.createRegressionTable((List)features, (List)CMatrixUtil.getRow(coef, (int)numberOfClasses, (int)numberOfFeatures, (int)i), (Number)intercept.get(i)).setTargetCategory(categoricalLabel.getValue(i));
                regressionTables.add(regressionTable);
            }

            RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema((Label)categoricalLabel), regressionTables).setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);
            this.encodePredictProbaOutput((Model)regressionModel, DataType.DOUBLE, (DiscreteLabel)categoricalLabel);
            return regressionModel;
        }

        throw new IllegalArgumentException("Expected 1 or at least 3 coefficient rows, got " + numberOfClasses);
    }

    /**
     * Encodes a one-vs-rest model using the generic {@link LinearClassifier} encoding.
     */
    private Model encodeOvRModel(Schema schema) {
        return super.encodeModel(schema);
    }

    /**
     * Returns the {@code multi_class} attribute, checked against the known values.
     *
     * <p>The legacy {@code "warn"} value is reported as {@code "ovr"}.
     *
     * @return one of {@code "auto"}, {@code "deprecated"}, {@code "multinomial"} or {@code "ovr"}
     */
    public String getMultiClass() {
        String multiClass = (String)this.getEnum("multi_class", name -> this.getString(name), Arrays.asList(MULTICLASS_AUTO, MULTICLASS_DEPRECATED, MULTICLASS_MULTINOMIAL, MULTICLASS_OVR, MULTICLASS_WARN));
        if (Objects.equals(MULTICLASS_WARN, multiClass)) {
            multiClass = MULTICLASS_OVR;
        }
        return multiClass;
    }

    /**
     * Returns the {@code solver} attribute.
     */
    public String getSolver() {
        return this.getString("solver");
    }

    /**
     * Resolves an automatic {@code multi_class} setting to a concrete strategy.
     *
     * <p>The liblinear solver always resolves to one-vs-rest. So does a single coefficient row,
     * which denotes a binary problem. Everything else resolves to multinomial.
     *
     * @param solver the solver name; may be {@code null}
     * @param shape the coefficient matrix shape, {@code [rows, features]}
     * @return {@code "ovr"} or {@code "multinomial"}
     */
    private static String getAutoMultiClass(String solver, int[] shape) {
        int numberOfClasses = shape[0];
        if (Objects.equals(SOLVER_LIBLINEAR, solver) || numberOfClasses == 1) {
            return MULTICLASS_OVR;
        }
        return MULTICLASS_MULTINOMIAL;
    }

    /**
     * Returns true if {@code version} is known (non-null) and is at least {@code minimum}.
     */
    private static boolean isVersionAtLeast(String version, String minimum) {
        return version != null && VersionUtil.compareVersion(version, minimum) >= 0;
    }
}

