/*
 * Decompiled with CFR 0.152.
 */
package xgboost.sklearn;

import java.nio.ByteOrder;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Value;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.xgboost.ByteOrderUtil;
import org.jpmml.xgboost.HasXGBoostOptions;
import org.jpmml.xgboost.Learner;
import org.jpmml.xgboost.ObjFunction;
import sklearn.Estimator;
import xgboost.sklearn.Booster;
import xgboost.sklearn.HasBooster;

public final class BoosterUtil {

    private BoosterUtil() {
    }

    /**
     * Returns the number of features declared by the estimator's XGBoost learner.
     *
     * @param estimator the estimator; it is expected to also implement {@link HasBooster}
     * @return the value of {@link Learner#num_feature()}
     */
    public static <E extends Estimator & HasXGBoostOptions> int getNumberOfFeatures(E estimator) {
        Learner learner = getLearner(estimator);

        return learner.num_feature();
    }

    /**
     * Returns the objective function of the estimator's XGBoost learner.
     *
     * @param estimator the estimator; it is expected to also implement {@link HasBooster}
     * @return the value of {@link Learner#obj()}
     */
    public static <E extends Estimator & HasXGBoostOptions> ObjFunction getObjFunction(E estimator) {
        Learner learner = getLearner(estimator);

        return learner.obj();
    }

    /**
     * Encodes the estimator's XGBoost booster as a PMML {@link MiningModel}.
     *
     * <p>Conversion options read from the estimator:
     * <ul>
     *   <li>{@code compact}, {@code numeric}, {@code prune}: booleans, each defaulting to {@code true};</li>
     *   <li>{@code ntree_limit}: integer, defaulting to the estimator's {@code best_ntree_limit}
     *       attribute (which may itself be absent, i.e. {@code null}).</li>
     * </ul>
     *
     * <p>If the estimator's {@code missing} attribute is a non-NaN number, the schema is rewritten
     * so that this value is treated as missing (see {@link #toBoosterSchema(Number, Schema)}).
     *
     * @param estimator the estimator; it is expected to also implement {@link HasBooster}
     * @param schema the input schema
     * @return the encoded mining model
     */
    public static <E extends Estimator & HasXGBoostOptions> MiningModel encodeBooster(E estimator, Schema schema) {
        Learner learner = getLearner(estimator);

        Integer bestNTreeLimit = (Integer)estimator.getOptionalScalar("best_ntree_limit");
        Number missing = (Number)estimator.getOptionalScalar("missing");

        Boolean compact = (Boolean)estimator.getOption("compact", Boolean.TRUE);
        Boolean numeric = (Boolean)estimator.getOption("numeric", Boolean.TRUE);
        Boolean prune = (Boolean)estimator.getOption("prune", Boolean.TRUE);
        Integer ntreeLimit = (Integer)estimator.getOption("ntree_limit", bestNTreeLimit);

        // Values are heterogeneous (Boolean and Integer), so the map is typed by Object;
        // insertion order is preserved by LinkedHashMap.
        Map<String, Object> options = new LinkedHashMap<>();
        options.put("compact", compact);
        options.put("numeric", numeric);
        options.put("prune", prune);
        options.put("ntree_limit", ntreeLimit);

        Schema xgbSchema = learner.toXGBoostSchema(numeric.booleanValue(), schema);

        // A NaN "missing" value needs no extra encoding; any other number must be mapped to missing
        if (missing != null && !ValueUtil.isNaN(missing)) {
            xgbSchema = toBoosterSchema(missing, xgbSchema);
        }

        return learner.encodeMiningModel(options, xgbSchema);
    }

    /**
     * Loads the XGBoost learner from the estimator's booster.
     *
     * <p>Reads the {@code byte_order} option (default: the name of the platform's native
     * {@link ByteOrder}) and the {@code charset} option (default: {@code null}).
     *
     * @param estimator the estimator; it is cast to {@link HasBooster}, so a
     *        {@link ClassCastException} is thrown if it does not implement that interface
     * @return the learner
     */
    private static <E extends Estimator & HasXGBoostOptions> Learner getLearner(E estimator) {
        Booster booster = ((HasBooster)estimator).getBooster();

        String byteOrder = (String)estimator.getOption("byte_order", ByteOrder.nativeOrder().toString());
        String charset = (String)estimator.getOption("charset", null);

        return booster.getLearner(ByteOrderUtil.forValue(byteOrder), charset);
    }

    /**
     * Returns a copy of the schema in which every feature treats {@code missing} as a missing value.
     *
     * @param missing the numeric value that denotes a missing value
     * @param schema the schema to transform
     * @return the transformed schema
     */
    private static Schema toBoosterSchema(Number missing, Schema schema) {
        Function<Feature, Feature> function = feature -> toBoosterFeature(missing, feature);

        return schema.toTransformedSchema(function);
    }

    /**
     * Makes a single feature treat {@code missing} as a missing value.
     *
     * <p>For a feature backed by a {@link DataField}, the value is appended to that data field's
     * {@code MISSING} values; this mutates the data field in place and returns the same feature.
     * For any other field, a new continuous derived field {@code booster(...)} is created that
     * evaluates {@code if(isNotMissing(x) and x != missing, x)}. This {@code if} has no else-branch,
     * so other inputs yield a missing result.
     *
     * @param missing the numeric value that denotes a missing value
     * @param feature the feature to transform
     * @return the transformed feature
     */
    private static Feature toBoosterFeature(Number missing, Feature feature) {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        Field<?> field = continuousFeature.getField();
        if (field instanceof DataField) {
            DataField dataField = (DataField)field;

            PMMLUtil.addValues(dataField, Value.Property.MISSING, Collections.singletonList(missing));

            return continuousFeature;
        }

        PMMLEncoder encoder = continuousFeature.getEncoder();

        Apply expression = PMMLUtil.createApply("if",
            PMMLUtil.createApply("and",
                PMMLUtil.createApply("isNotMissing", continuousFeature.ref()),
                PMMLUtil.createApply("notEqual", continuousFeature.ref(), PMMLUtil.createConstant(missing))
            ),
            continuousFeature.ref()
        );

        DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("booster", continuousFeature), OpType.CONTINUOUS, continuousFeature.getDataType(), expression);

        return new ContinuousFeature(encoder, derivedField);
    }
}

