/*
 * Decompiled with CFR 0.152.
 */
package sklearn.linear_model;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.MatrixUtil;
import sklearn.Classifier;
import sklearn.EstimatorUtil;
import sklearn.linear_model.RegressionModelUtil;

/**
 * Base converter for scikit-learn linear classifiers.
 *
 * <p>The fitted estimator is expected to expose a two-dimensional {@code coef_} attribute
 * (rows are read as decision functions, columns as features) and an {@code intercept_}
 * attribute. {@link #encodeModel(Schema)} turns each row into a PMML {@link RegressionModel}
 * and combines them into a {@link MiningModel}.
 */
public abstract class BaseLinearClassifier
extends Classifier {
    public BaseLinearClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features, i.e. the second dimension of {@code coef_}.
     *
     * @return the column count of {@code coef_}
     */
    @Override
    public int getNumberOfFeatures() {
        int[] shape = this.getCoefShape();
        return shape[1];
    }

    /**
     * Returns {@code false}: this converter does not require its input features to be continuous.
     *
     * @return always {@code false}
     */
    @Override
    public boolean requiresContinuousInput() {
        return false;
    }

    /**
     * Encodes the fitted coefficients as a PMML classification model.
     *
     * <p>If {@code coef_} has a single row, exactly two target categories are required. One
     * regression model is built for the second category (index 1) from row 0 of {@code coef_} and
     * {@code intercept_[0]}. It is then passed to
     * {@link MiningModelUtil#createBinaryLogisticClassification} with a coefficient of {@code -1.0}.
     *
     * <p>If {@code coef_} has two or more rows, the number of target categories must equal the row
     * count. One regression model is built per category (row {@code i}, {@code intercept_[i]}),
     * each additionally exposing a {@code "logit"}-transformed decision function. The models are
     * combined via {@link MiningModelUtil#createClassification} with {@code SIMPLEMAX}
     * normalization.
     *
     * @param schema schema providing the target categories and the features
     * @return the encoded mining model
     * @throws IllegalArgumentException if the number of target categories does not match the shape
     *     of {@code coef_}, or if {@code coef_} has no rows
     */
    public MiningModel encodeModel(Schema schema) {
        int[] shape = this.getCoefShape();
        int numberOfClasses = shape[0];
        int numberOfFeatures = shape[1];

        boolean hasProbabilityDistribution = this.hasProbabilityDistribution();

        List<? extends Number> coefficients = this.getCoef();
        List<? extends Number> intercepts = this.getIntercept();

        // No element type is assumed here; each category is cast to String where it is consumed.
        List<?> targetCategories = schema.getTargetCategories();

        Schema segmentSchema = schema.toAnonymousSchema();

        if (numberOfClasses == 1) {
            if (targetCategories.size() != 2) {
                throw new IllegalArgumentException("Expected 2 target categories for a single-row coef_, got " + targetCategories.size());
            }

            // The single decision function is attached to the second category.
            RegressionModel regressionModel = encodeCategoryRegressor((String) targetCategories.get(1),
                    MatrixUtil.getRow(coefficients, numberOfClasses, numberOfFeatures, 0), intercepts.get(0), null, segmentSchema);

            return MiningModelUtil.createBinaryLogisticClassification(schema, regressionModel, -1.0, hasProbabilityDistribution);
        }

        if (numberOfClasses >= 2) {
            if (targetCategories.size() != numberOfClasses) {
                throw new IllegalArgumentException("Expected " + numberOfClasses + " target categories to match the rows of coef_, got " + targetCategories.size());
            }

            List<RegressionModel> regressionModels = new ArrayList<>(numberOfClasses);

            // The guard above makes numberOfClasses equal to targetCategories.size().
            for (int i = 0; i < numberOfClasses; i++) {
                RegressionModel regressionModel = encodeCategoryRegressor((String) targetCategories.get(i),
                        MatrixUtil.getRow(coefficients, numberOfClasses, numberOfFeatures, i), intercepts.get(i), "logit", segmentSchema);

                regressionModels.add(regressionModel);
            }

            return MiningModelUtil.createClassification(schema, regressionModels, RegressionModel.NormalizationMethod.SIMPLEMAX, hasProbabilityDistribution);
        }

        throw new IllegalArgumentException("Expected coef_ to have at least 1 row, got " + numberOfClasses);
    }

    /**
     * Returns a singleton set holding the function built by
     * {@link EstimatorUtil#encodeLogitFunction()}.
     *
     * <p>NOTE(review): presumably this defines the {@code "logit"} function that
     * {@link #encodeModel(Schema)} applies to multiclass decision functions. Confirm against
     * {@code EstimatorUtil}.
     *
     * @return a singleton set with the logit function definition
     */
    @Override
    public Set<DefineFunction> encodeDefineFunctions() {
        return Collections.singleton(EstimatorUtil.encodeLogitFunction());
    }

    /**
     * Returns the {@code coef_} attribute as a list of numbers.
     *
     * <p>It is indexed row-by-row through {@link MatrixUtil#getRow}, so it is presumably the
     * flattened matrix (TODO confirm the ordering in {@code ClassDictUtil.getArray}).
     *
     * @return the coefficient values
     */
    public List<? extends Number> getCoef() {
        return ClassDictUtil.getArray(this, "coef_");
    }

    /**
     * Returns the {@code intercept_} attribute as a list of numbers; element {@code i} pairs with
     * row {@code i} of {@code coef_}.
     *
     * @return the intercept values
     */
    public List<? extends Number> getIntercept() {
        return ClassDictUtil.getArray(this, "intercept_");
    }

    /**
     * Returns the shape of {@code coef_}: index 0 is read as the row count and index 1 as the
     * feature count.
     *
     * <p>The {@code 2} argument presumably asserts a two-dimensional array.
     */
    private int[] getCoefShape() {
        return ClassDictUtil.getShape(this, "coef_", 2);
    }

    /**
     * Builds the regression model for a single target category.
     *
     * <p>The model's {@link Output} always contains a non-final, continuous
     * {@code "decisionFunction_<category>"} field holding the predicted value. When
     * {@code outputTransformation} is non-null, it also contains a non-final
     * {@code "<transformation>DecisionFunction_<category>"} field. That field applies the named
     * function to the decision function field.
     *
     * @param targetCategory category the model scores; used in output field names
     * @param coefficients one coefficient per feature
     * @param intercept the model intercept
     * @param outputTransformation name of a function to apply to the decision function, or
     *     {@code null} for none
     * @param schema schema the regression model is encoded against
     * @return the regression model with its output fields attached
     */
    private static RegressionModel encodeCategoryRegressor(String targetCategory, List<? extends Number> coefficients, Number intercept, String outputTransformation, Schema schema) {
        OutputField decisionFunction = new OutputField(FieldName.create("decisionFunction_" + targetCategory), DataType.DOUBLE)
                .setOpType(OpType.CONTINUOUS)
                .setResultFeature(ResultFeature.PREDICTED_VALUE)
                .setFinalResult(false);

        Output output = new Output().addOutputFields(new OutputField[]{decisionFunction});

        if (outputTransformation != null) {
            // Refers back to the raw decision function field by name.
            Expression transformation = PMMLUtil.createApply(outputTransformation, new Expression[]{new FieldRef(decisionFunction.getName())});

            OutputField transformedDecisionFunction = new OutputField(FieldName.create(outputTransformation + "DecisionFunction_" + targetCategory), DataType.DOUBLE)
                    .setOpType(OpType.CONTINUOUS)
                    .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                    .setFinalResult(false)
                    .setExpression(transformation);

            output.addOutputFields(new OutputField[]{transformedDecisionFunction});
        }

        return RegressionModelUtil.encodeRegressionModel(coefficients, intercept, schema).setOutput(output);
    }
}

