/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.xgboost.Classification;
import org.jpmml.xgboost.RegTree;

public class MultinomialLogisticRegression
extends Classification {
    public MultinomialLogisticRegression(String name, int num_class) {
        super(name, num_class);
        if (num_class < 2) {
            throw new IllegalArgumentException("Multi-class classification requires two or more target categories");
        }
    }

    @Override
    public MiningModel encodeMiningModel(List<RegTree> trees, List<Float> weights, float base_score, Integer ntreeLimit, Schema schema) {
        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.FLOAT);
        ArrayList<MiningModel> miningModels = new ArrayList<MiningModel>();
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        int columns = categoricalLabel.size();
        int rows = trees.size() / columns;
        for (int i = 0; i < columns; ++i) {
            MiningModel miningModel = MultinomialLogisticRegression.createMiningModel(CMatrixUtil.getColumn(trees, (int)rows, (int)columns, (int)i), weights != null ? CMatrixUtil.getColumn(weights, (int)rows, (int)columns, (int)i) : null, base_score, ntreeLimit, segmentSchema).setOutput(ModelUtil.createPredictedOutput((String)FieldNameUtil.create((String)"xgbValue", (Object[])new Object[]{categoricalLabel.getValue(i)}), (OpType)OpType.CONTINUOUS, (DataType)DataType.FLOAT, (Transformation[])new Transformation[0]));
            miningModels.add(miningModel);
        }
        return MiningModelUtil.createClassification(miningModels, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.SOFTMAX, (boolean)true, (Schema)schema);
    }
}

