/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.xgboost.Classification;
import org.jpmml.xgboost.ObjFunction;

public class SoftMaxClassification
extends Classification {
    private static final ObjFunction.Transformation TRANSFORMATION = new ObjFunction.Transformation(){

        @Override
        public Expression createExpression(FieldName name) {
            return PMMLUtil.createApply((String)"exp", (Expression[])new Expression[]{new FieldRef(name)});
        }
    };

    public SoftMaxClassification(int num_class) {
        super(num_class);
        if (num_class < 3) {
            throw new IllegalArgumentException("Multi-class classification requires three or more target categories");
        }
    }

    @Override
    public MiningModel encodeMiningModel(Segmentation segmentation, float base_score, Schema schema) {
        Schema segmentSchema = schema.toAnonymousSchema();
        Function<Segment, Model> function = new Function<Segment, Model>(){

            public Model apply(Segment segment) {
                return segment.getModel();
            }
        };
        List models = Lists.transform((List)segmentation.getSegments(), (Function)function);
        ArrayList<MiningModel> miningModels = new ArrayList<MiningModel>();
        CategoricalLabel categoricalLabel = (CategoricalLabel)segmentSchema.getLabel();
        int columns = categoricalLabel.size();
        int rows = models.size() / columns;
        for (int i = 0; i < columns; ++i) {
            MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Schema)segmentSchema)).setSegmentation(MiningModelUtil.createSegmentation((Segmentation.MultipleModelMethod)Segmentation.MultipleModelMethod.SUM, SoftMaxClassification.getColumn(models, i, rows, columns))).setTargets(SoftMaxClassification.createTargets(base_score, segmentSchema)).setOutput(SoftMaxClassification.createOutput(TRANSFORMATION, categoricalLabel.getValue(i)));
            miningModels.add(miningModel);
        }
        return MiningModelUtil.createClassification((Schema)schema, miningModels, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.SIMPLEMAX, (boolean)true);
    }

    private static <E> List<E> getColumn(List<E> values, int index, int rows, int columns) {
        if (values.size() != rows * columns) {
            throw new IllegalArgumentException();
        }
        ArrayList<E> result = new ArrayList<E>();
        for (int row = 0; row < rows; ++row) {
            result.add(values.get(row * columns + index));
        }
        return result;
    }
}

