/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.xgboost.Classification;
import org.jpmml.xgboost.FeatureMap;
import org.jpmml.xgboost.GBTree;
import org.jpmml.xgboost.LinearRegression;
import org.jpmml.xgboost.LogisticClassification;
import org.jpmml.xgboost.LogisticRegression;
import org.jpmml.xgboost.ObjFunction;
import org.jpmml.xgboost.Regression;
import org.jpmml.xgboost.SoftMaxClassification;
import org.jpmml.xgboost.XGBoostDataInput;

/**
 * Learner state read from an {@link XGBoostDataInput} and convertible to a PMML document.
 *
 * <p>Holds the base score, the feature and class counts, the objective function, and the tree
 * booster that {@link #load(XGBoostDataInput)} reads from the input.
 */
public class Learner {
    private float base_score;
    private int num_features;
    private int num_class;
    private ObjFunction obj;
    private GBTree gbtree;

    /**
     * Reads the learner state from {@code input}.
     *
     * <p>Reads the following, in order:
     * <ol>
     *   <li>the base score (float);</li>
     *   <li>the feature count (int);</li>
     *   <li>the class count (int);</li>
     *   <li>31 reserved units;</li>
     *   <li>the objective name;</li>
     *   <li>the booster name;</li>
     *   <li>the booster itself.</li>
     * </ol>
     *
     * <p>Fields are only assigned after every step has succeeded. A failed load therefore leaves
     * this instance unchanged.
     *
     * @param input source to read from
     * @throws IOException if reading from {@code input} fails
     * @throws IllegalArgumentException if the objective or booster name is not supported; the
     *     message is the offending name
     */
    public void load(XGBoostDataInput input) throws IOException {
        float baseScore = input.readFloat();
        int numFeatures = input.readInt();
        int numClass = input.readInt();
        // NOTE(review): the unit of "31" is defined by XGBoostDataInput.readReserved - confirm.
        input.readReserved(31);

        ObjFunction objFunction = createObjFunction(input.readString(), numClass);

        String nameGbm = input.readString();
        // Only the tree booster is supported.
        if (!"gbtree".equals(nameGbm)) {
            throw new IllegalArgumentException(nameGbm);
        }

        GBTree loadedTree = new GBTree();
        loadedTree.load(input);

        // Commit only once everything has been read successfully.
        this.base_score = baseScore;
        this.num_features = numFeatures;
        this.num_class = numClass;
        this.obj = objFunction;
        this.gbtree = loadedTree;
    }

    /**
     * Maps an objective name to its {@link ObjFunction} implementation.
     *
     * @param name objective name as read from the input
     * @param numClass class count, passed to the soft-max objectives
     * @return a new objective function instance
     * @throws IllegalArgumentException if {@code name} is not supported; the message is
     *     {@code name}
     */
    private static ObjFunction createObjFunction(String name, int numClass) {
        switch (name) {
            case "reg:linear":
                return new LinearRegression();
            case "reg:logistic":
                return new LogisticRegression();
            case "binary:logistic":
                return new LogisticClassification();
            case "multi:softmax":
            case "multi:softprob":
                return new SoftMaxClassification(numClass);
            default:
                throw new IllegalArgumentException(name);
        }
    }

    /**
     * Encodes the loaded learner as a PMML document.
     *
     * @param targetName optional name for the target field; when {@code null}, the name supplied
     *     by the objective function's data field is kept
     * @param targetCategories optional category labels. They are only allowed for classification
     *     objectives, where they are forwarded to {@code Classification.updateTargetCategories}.
     * @param featureMap supplies the input data fields and is passed to the booster encoder
     * @return a PMML document. Its data dictionary lists the target field first, followed by
     *     the fields of {@code featureMap}. Its single model is the mining model produced by the
     *     booster.
     * @throws IllegalStateException if {@link #load(XGBoostDataInput)} has not completed
     * @throws IllegalArgumentException if {@code targetCategories} is given for a regression
     *     objective
     */
    public PMML encodePMML(String targetName, List<String> targetCategories, FeatureMap featureMap) {
        if (this.obj == null || this.gbtree == null) {
            throw new IllegalStateException("Learner has not been loaded");
        }

        DataField dataField = this.obj.getDataField();
        if (targetName != null) {
            dataField.setName(FieldName.create(targetName));
        }

        if (this.obj instanceof Regression) {
            if (targetCategories != null) {
                throw new IllegalArgumentException("Target categories are not supported for regression objectives");
            }
        } else if (this.obj instanceof Classification) {
            Classification classification = (Classification) this.obj;
            if (targetCategories != null) {
                classification.updateTargetCategories(targetCategories);
            }

            List<String> categories = classification.getTargetCategories();
            if (categories != null && !categories.isEmpty()) {
                // Replace any values already on the field with the effective category list.
                dataField.getValues().clear();
                dataField.getValues().addAll(PMMLUtil.createValues(categories));
            }
        }

        MiningModel miningModel = this.gbtree.encodeMiningModel(this.obj, this.base_score, featureMap);

        List<DataField> dataFields = new ArrayList<>();
        dataFields.add(dataField);
        dataFields.addAll(featureMap.getDataFields());

        DataDictionary dataDictionary = new DataDictionary(dataFields);

        return new PMML("4.2", PMMLUtil.createHeader("JPMML-XGBoost", "1.0-SNAPSHOT"), dataDictionary)
            .addModels(new Model[]{miningModel});
    }
}

