/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.sparkml.MatrixUtil;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.VectorUtil;
import org.jpmml.sparkml.model.RegressionTableUtil;

/**
 * Static helpers that turn the coefficients and intercept(s) of a linear Spark ML model into a
 * PMML {@link RegressionModel}.
 *
 * <p>Every method copies the schema's features and the model's coefficients into fresh lists
 * before handing them to {@link RegressionTableUtil#simplify}. The lists passed to that helper are
 * therefore private to the call and never the schema's own feature list (presumably because
 * {@code simplify} may modify them in place — confirm against RegressionTableUtil).
 */
public class LinearModelUtil {

    /**
     * Creates a regression model with normalization method {@code NONE}.
     *
     * @param converter the owning converter, forwarded to {@link RegressionTableUtil#simplify}
     * @param coefficients the coefficient vector; expected to line up with {@code schema}'s
     *     features (assumption — the alignment is not checked here)
     * @param intercept the model intercept
     * @param schema the model schema; its label must be a {@link ContinuousLabel}
     * @return the PMML regression model
     * @throws ClassCastException if the schema's label is not a {@link ContinuousLabel}
     */
    public static <C extends ModelConverter<?>> RegressionModel createRegression(C converter, Vector coefficients, double intercept, Schema schema) {
        // Precondition only: fails fast with ClassCastException unless the label is continuous.
        ContinuousLabel continuousLabel = (ContinuousLabel)schema.getLabel();

        List<Feature> features = new ArrayList<Feature>(schema.getFeatures());
        List<Double> featureCoefficients = new ArrayList<Double>(VectorUtil.toList(coefficients));

        // A null target category is passed for a single-table model.
        RegressionTableUtil.simplify(converter, null, features, featureCoefficients);

        return RegressionModelUtil.createRegression(features, featureCoefficients, (Number)intercept, RegressionModel.NormalizationMethod.NONE, schema);
    }

    /**
     * Creates a binary logistic classification model with normalization method {@code LOGIT}.
     *
     * @param converter the owning converter, forwarded to {@link RegressionTableUtil#simplify}
     * @param coefficients the coefficient vector; expected to line up with {@code schema}'s
     *     features (assumption — the alignment is not checked here)
     * @param intercept the model intercept
     * @param schema the model schema; its label must be a {@link CategoricalLabel}
     * @return the PMML classification model
     * @throws ClassCastException if the schema's label is not a {@link CategoricalLabel}
     */
    public static <C extends ModelConverter<?>> RegressionModel createBinaryLogisticClassification(C converter, Vector coefficients, double intercept, Schema schema) {
        // Precondition only: fails fast with ClassCastException unless the label is categorical.
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        List<Feature> features = new ArrayList<Feature>(schema.getFeatures());
        List<Double> featureCoefficients = new ArrayList<Double>(VectorUtil.toList(coefficients));

        RegressionTableUtil.simplify(converter, null, features, featureCoefficients);

        // NOTE(review): the boolean argument is presumably the "has probability distribution"
        // flag of RegressionModelUtil.createBinaryLogisticClassification — confirm there.
        return RegressionModelUtil.createBinaryLogisticClassification(features, featureCoefficients, (Number)intercept, RegressionModel.NormalizationMethod.LOGIT, true, schema);
    }

    /**
     * Creates a multinomial classification model with normalization method {@code SOFTMAX}.
     * The model has one regression table per target category of the schema's label.
     *
     * <p>Row {@code i} of {@code coefficients} and element {@code i} of {@code intercepts} feed the
     * regression table for the label's {@code i}-th category value.
     *
     * @param converter the owning converter, forwarded to {@link RegressionTableUtil#simplify}
     * @param coefficients the coefficient matrix, one row per target category
     * @param intercepts the intercept vector, one element per target category
     * @param schema the model schema; its label must be a {@link CategoricalLabel}
     * @return the PMML classification model
     * @throws ClassCastException if the schema's label is not a {@link CategoricalLabel}
     * @throws IllegalArgumentException if {@code intercepts} does not have exactly one element per
     *     target category. A matrix row-count mismatch is rejected by {@code MatrixUtil.checkRows},
     *     whose exception type is defined by that helper.
     */
    public static <C extends ModelConverter<?>> RegressionModel createSoftmaxClassification(C converter, Matrix coefficients, Vector intercepts, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        int numCategories = categoricalLabel.size();

        MatrixUtil.checkRows(numCategories, coefficients);
        // Validate the intercepts up front. Otherwise surplus entries would be silently ignored
        // and missing ones would surface as an opaque out-of-range error inside the loop.
        if (intercepts.size() != numCategories) {
            throw new IllegalArgumentException("Expected " + numCategories + " intercepts, got " + intercepts.size());
        }

        List<RegressionTable> regressionTables = new ArrayList<RegressionTable>(numCategories);
        for (int i = 0; i < numCategories; i++) {
            Object targetCategory = categoricalLabel.getValue(i);

            // Fresh copies per category, so each table is simplified independently.
            List<Feature> features = new ArrayList<Feature>(schema.getFeatures());
            List<Double> featureCoefficients = new ArrayList<Double>(MatrixUtil.getRow(coefficients, i));

            RegressionTableUtil.simplify(converter, targetCategory, features, featureCoefficients);

            double intercept = intercepts.apply(i);
            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(features, featureCoefficients, (Number)intercept)
                .setTargetCategory(targetCategory);
            regressionTables.add(regressionTable);
        }

        return new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema((Label)categoricalLabel), regressionTables)
            .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);
    }
}

