/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.MultipleModelMethodType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.Segment;
import org.dmg.pmml.Segmentation;
import org.jpmml.converter.MiningModelUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.xgboost.Classification;
import org.jpmml.xgboost.FeatureMap;

/**
 * Multi-class classification encoder for XGBoost models whose per-class margins are combined with a softmax.
 *
 * <p>The trees of the incoming {@link Segmentation} are expected to be interleaved by class: segment
 * {@code row * categoryCount + i} contributes to the {@code i}-th target category (presumably one row per
 * boosting round — confirm against the XGBoost model loader).
 */
public class SoftMaxClassification
extends Classification {

    /**
     * Maps a {@link DataField} to its name. Held in a static constant so that, unlike an anonymous class created
     * inside an instance method, it does not capture a reference to an enclosing {@code SoftMaxClassification}.
     */
    private static final Function<DataField, FieldName> DATA_FIELD_NAME = new Function<DataField, FieldName>() {

        @Override
        public FieldName apply(DataField dataField) {
            return dataField.getName();
        }
    };

    /**
     * @param num_class the number of target classes; forwarded unchanged to {@link Classification}
     */
    public SoftMaxClassification(int num_class) {
        super(num_class);
    }

    /**
     * Builds one summing regression sub-model per target category and combines them into a classification model.
     *
     * <p>For the target category {@code c} at position {@code i}, the sub-model sums the {@code i}-th column of
     * segments and exposes two output fields: {@code xgbValue_c} (the raw sum) and {@code transformedValue_c}
     * ({@code exp(xgbValue_c + base_score)}). The {@code transformedValue_*} fields are handed to
     * {@link MiningModelUtil#createClassification}, which presumably normalizes them into class probabilities
     * (TODO confirm against MiningModelUtil).
     *
     * @param segmentation the XGBoost trees, interleaved by class as described in the class documentation
     * @param base_score the bias added to every class's summed value before exponentiation
     * @param featureMap supplies the data fields used for the sub-models' mining schema and the active fields
     * @return the combined classification model
     * @throws IllegalArgumentException if the number of segments is not a multiple of the number of target
     *         categories
     */
    @Override
    public MiningModel encodeMiningModel(Segmentation segmentation, float base_score, FeatureMap featureMap) {
        DataField dataField = getDataField();
        List<Segment> segments = segmentation.getSegments();

        // NOTE(review): this single MiningSchema instance is attached to every per-category sub-model,
        // so a later mutation of one sub-model's schema is visible in all of them.
        MiningSchema valueMiningSchema = ModelUtil.createMiningSchema(null, featureMap.getDataFields());

        List<String> targetCategories = getTargetCategories();
        int categoryCount = targetCategories.size();

        // Fail fast with a descriptive message on a malformed layout. The guard keeps an empty category
        // list from dividing by zero (in that case no sub-models are built, as before).
        if (categoryCount > 0 && segments.size() % categoryCount != 0) {
            throw new IllegalArgumentException("Expected the number of segments (" + segments.size()
                + ") to be a multiple of the number of target categories (" + categoryCount + ")");
        }
        int rows = (categoryCount > 0) ? segments.size() / categoryCount : 0;

        // Typed as Object to keep the same PMMLUtil.createConstant overload the original code selected.
        Object baseScoreValue = Float.valueOf(base_score);

        List<MiningModel> models = new ArrayList<MiningModel>(categoryCount);
        List<FieldName> inputFields = new ArrayList<FieldName>(categoryCount);

        for (int i = 0; i < categoryCount; ++i) {
            String targetCategory = targetCategories.get(i);

            OutputField xgbValue = createPredictedField(FieldName.create("xgbValue_" + targetCategory));

            // exp(xgbValue + base_score)
            Expression expression = PMMLUtil.createApply("exp",
                PMMLUtil.createApply("+", new FieldRef(xgbValue.getName()), PMMLUtil.createConstant(baseScoreValue)));

            OutputField transformedValue = createTransformedField(
                FieldName.create("transformedValue_" + targetCategory), expression);
            inputFields.add(transformedValue.getName());

            List<Segment> valueSegments = getColumn(segments, i, rows, categoryCount);
            Segmentation valueSegmentation = new Segmentation(MultipleModelMethodType.SUM, valueSegments);

            Output valueOutput = new Output().addOutputFields(xgbValue, transformedValue);

            MiningModel valueMiningModel = new MiningModel(MiningFunctionType.REGRESSION, valueMiningSchema)
                .setSegmentation(valueSegmentation)
                .setOutput(valueOutput);

            models.add(valueMiningModel);
        }

        List<FieldName> activeFields = Lists.transform(featureMap.getDataFields(), DATA_FIELD_NAME);

        return MiningModelUtil.createClassification(dataField.getName(), targetCategories, activeFields, models,
            inputFields, true);
    }

    /**
     * Extracts one column from a list laid out row-major as a {@code rows x columns} matrix.
     *
     * @param values the flattened matrix; element {@code (row, col)} is at {@code row * columns + col}
     * @param index the zero-based column to extract
     * @param rows the number of rows
     * @param columns the number of columns
     * @return a new list holding {@code values.get(row * columns + index)} for each row, in row order
     * @throws IllegalArgumentException if {@code values} does not hold exactly {@code rows * columns} elements,
     *         or {@code index} is not a valid column
     */
    private static <E> List<E> getColumn(List<E> values, int index, int rows, int columns) {
        if (values.size() != rows * columns) {
            throw new IllegalArgumentException(
                "Expected " + rows + " x " + columns + " = " + (rows * columns) + " values, got " + values.size());
        }
        // Without this check an out-of-range index would silently read from the neighbouring row.
        if (index < 0 || index >= columns) {
            throw new IllegalArgumentException("Column index " + index + " is out of range [0, " + columns + ")");
        }

        List<E> result = new ArrayList<E>(rows);
        for (int row = 0; row < rows; ++row) {
            result.add(values.get(row * columns + index));
        }
        return result;
    }
}

