/*
 * Decompiled with CFR 0.152.
 */
package sklearn.discriminant_analysis;

import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.sklearn.SkLearnUtil;
import sklearn.linear_model.LinearClassifier;

/**
 * Converter for a linear discriminant analysis classifier.
 *
 * <p>Encoding is delegated to {@link LinearClassifier} unless the coefficient
 * matrix has more than one row and the reported scikit-learn version is
 * {@code 0.21} or newer; in that case one regression model per class is built
 * and the per-class scores are combined with softmax normalization.
 */
public class LinearDiscriminantAnalysis
extends LinearClassifier {
    public LinearDiscriminantAnalysis(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes this estimator, dispatching on the number of rows in the
     * coefficient matrix.
     *
     * @param schema the schema supplying the label and the features
     * @return the encoded model
     */
    @Override
    public Model encodeModel(Schema schema) {
        int[] shape = this.getCoefShape();
        int numberOfClasses = shape[0];
        // A single coefficient row is handled as the binary case.
        if (numberOfClasses == 1) {
            return this.encodeBinaryModel(schema);
        }
        return this.encodeMultinomialModel(schema);
    }

    /**
     * Encodes the single-coefficient-row case using the inherited linear
     * classifier encoding.
     */
    private Model encodeBinaryModel(Schema schema) {
        return super.encodeModel(schema);
    }

    /**
     * Encodes the case where the coefficient matrix does not have exactly one row.
     *
     * <p>When the scikit-learn version is unknown or older than {@code 0.21}, the
     * inherited encoding is used unchanged. Otherwise a regression model is
     * created for every class value, each exposing its raw score as a
     * {@code decisionFunction(<class value>)} output field, and the models are
     * combined into a softmax-normalized classification.
     *
     * @param schema the schema supplying the categorical label and the features
     * @return the encoded model
     * @throws IllegalArgumentException if the version is {@code 0.21} or newer
     *         and the coefficient matrix has fewer than three rows
     */
    private Model encodeMultinomialModel(Schema schema) {
        String sklearnVersion = this.getSkLearnVersion();
        int[] shape = this.getCoefShape();
        int numberOfClasses = shape[0];
        int numberOfFeatures = shape[1];
        List<? extends Number> coef = this.getCoef();
        List<? extends Number> intercept = this.getIntercept();
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        // NOTE(review): presumably scikit-learn 0.21 changed how multiclass LDA
        // scores map to probabilities, hence the separate encoding - confirm
        // against the scikit-learn release notes.
        boolean corrected = sklearnVersion != null && SkLearnUtil.compareVersion(sklearnVersion, "0.21") >= 0;
        if (!corrected) {
            return super.encodeModel(schema);
        }
        if (numberOfClasses < 3) {
            throw new IllegalArgumentException("Expected a coefficient matrix with at least 3 rows, got " + numberOfClasses);
        }

        SchemaUtil.checkSize(numberOfClasses, categoricalLabel);
        // Each per-class segment is a regressor producing an anonymous double score.
        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE).toEmptySchema();
        List<RegressionModel> regressionModels = new ArrayList<RegressionModel>();
        int rows = categoricalLabel.size();
        for (int i = 0; i < rows; ++i) {
            // Row i of the coefficient matrix and element i of the intercept belong to class value i.
            List<? extends Number> rowCoef = CMatrixUtil.getRow(coef, numberOfClasses, numberOfFeatures, i);
            FieldName decisionFunctionName = FieldNameUtil.create("decisionFunction", new Object[]{categoricalLabel.getValue(i)});
            RegressionModel regressionModel = RegressionModelUtil.createRegression(schema.getFeatures(), rowCoef, intercept.get(i), RegressionModel.NormalizationMethod.NONE, segmentSchema)
                .setOutput(ModelUtil.createPredictedOutput(decisionFunctionName, OpType.CONTINUOUS, DataType.DOUBLE, new Transformation[0]));
            regressionModels.add(regressionModel);
        }
        return MiningModelUtil.createClassification(regressionModels, RegressionModel.NormalizationMethod.SOFTMAX, true, schema);
    }
}

