/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.LabelUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.xgboost.ObjFunction;
import org.jpmml.xgboost.RegTree;

/**
 * Base class for objective functions that predict one of a fixed number of
 * target categories.
 *
 * <p>Holds the category count and provides the shared label encoding and the
 * per-target post-processing of probability output fields.
 */
public abstract class Classification
extends ObjFunction {
    /** Number of target categories; fixed at construction time. */
    private final int num_class;

    /**
     * @param name objective function name, passed through to {@link ObjFunction}
     * @param num_class number of target categories
     */
    public Classification(String name, int num_class) {
        super(name);
        this.num_class = num_class;
    }

    /**
     * Creates the categorical target field and wraps it into a label.
     *
     * <p>When no categories are supplied, they are generated via
     * {@code LabelUtil.createTargetCategories(num_class)} and the field is
     * declared with the {@code INTEGER} data type. When categories are
     * supplied, the field is declared with the {@code STRING} data type.
     *
     * @param targetName name of the target field
     * @param targetCategories user-supplied categories, or {@code null} to generate them
     * @param encoder encoder used for creating the data field
     * @return a {@link CategoricalLabel} backed by the newly created data field
     * @throws IllegalArgumentException if the number of supplied categories
     *         differs from {@link #num_class()}
     */
    @Override
    public Label encodeLabel(String targetName, List<?> targetCategories, PMMLEncoder encoder) {
        if (targetCategories == null) {
            List<?> generatedCategories = LabelUtil.createTargetCategories(this.num_class);
            DataField dataField = encoder.createDataField(targetName, OpType.CATEGORICAL, DataType.INTEGER, generatedCategories);
            return new CategoricalLabel(dataField);
        }
        if (targetCategories.size() != this.num_class) {
            throw new IllegalArgumentException("Expected " + this.num_class + " target categories, got " + targetCategories.size() + " target categories");
        }
        DataField dataField = encoder.createDataField(targetName, OpType.CATEGORICAL, DataType.STRING, targetCategories);
        return new CategoricalLabel(dataField);
    }

    /**
     * Encodes the trees into a mining model and, when a target index is given,
     * renames the probability output fields of the final model after the label.
     *
     * <p>For any {@code targetIndex} other than {@code -1}, every output field
     * whose result feature is {@code PROBABILITY} is removed from the final
     * model, and one {@code FLOAT} probability field per label value is added,
     * named via {@code FieldNameUtil.create("probability", labelName, value)}.
     * NOTE(review): {@code -1} presumably means "no specific target" (i.e. a
     * single-target model) - confirm against the callers.
     *
     * @param targetIndex index of the target, or {@code -1} to skip the output rewrite
     * @return the mining model produced by the six-argument {@code encodeMiningModel} overload
     * @throws IllegalArgumentException if the final model has no output fields
     * @throws ClassCastException if the schema label is not a {@link CategoricalLabel}
     */
    @Override
    public MiningModel encodeMiningModel(int targetIndex, List<RegTree> trees, List<Float> weights, float base_score, Integer ntreeLimit, boolean numeric, Schema schema) {
        MiningModel miningModel = this.encodeMiningModel(trees, weights, base_score, ntreeLimit, numeric, schema);
        if (targetIndex == -1) {
            return miningModel;
        }

        Model finalModel = MiningModelUtil.getFinalModel(miningModel);
        Output output = finalModel.getOutput();
        if (output == null || !output.hasOutputFields()) {
            throw new IllegalArgumentException("Final model does not define any output fields");
        }

        // Operate on the typed list returned by getOutputFields() directly (no raw
        // List variable), so that the lambda parameter keeps its element type.
        output.getOutputFields().removeIf(outputField -> outputField.getResultFeature() == ResultFeature.PROBABILITY);

        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        List<?> values = categoricalLabel.getValues();
        values.stream()
            .map(value -> ModelUtil.createProbabilityField(FieldNameUtil.create("probability", categoricalLabel.getName(), value), DataType.FLOAT, value))
            .forEach(output.getOutputFields()::add);

        return miningModel;
    }

    /**
     * @return the number of target categories
     */
    public int num_class() {
        return this.num_class;
    }
}

