/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.xgboost;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.spark.params.GeneralParams;
import org.apache.spark.ml.Model;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.xgboost.Learner;
import org.jpmml.xgboost.XGBoostUtil;

/**
 * Static helpers for encoding an XGBoost4J {@link Booster} as a PMML {@link MiningModel}.
 */
public final class BoosterUtil {

    // Converter option keys consulted by encodeBooster.
    private static final String OPTION_INPUT_FLOAT = "input_float";
    private static final String OPTION_MISSING = "missing";
    private static final String OPTION_COMPACT = "compact";
    private static final String OPTION_NUMERIC = "numeric";
    private static final String OPTION_PRUNE = "prune";
    private static final String OPTION_NTREE_LIMIT = "ntree_limit";

    private BoosterUtil() {
    }

    /**
     * Encodes the given booster as a PMML mining model.
     *
     * <p>The booster is serialized with {@code Booster.toByteArray()} and re-parsed into a
     * JPMML-XGBoost {@link Learner}. If the converter option {@code "input_float"} is
     * {@code Boolean.TRUE}, every {@code DOUBLE}-typed continuous feature in {@code schema} is
     * narrowed to {@code FLOAT}. Note that this changes the data type of the underlying
     * {@link Field} in place.
     *
     * <p>The following options are forwarded to {@link Learner#encodeMiningModel}. A converter
     * option overrides the default listed here:
     * <ul>
     *   <li>{@code "missing"}: defaults to the model's {@code getMissing()} value, or
     *       {@code null} when that value is NaN.</li>
     *   <li>{@code "compact"}: defaults to {@code false}.</li>
     *   <li>{@code "numeric"}: defaults to {@code true}; also passed to
     *       {@link Learner#toXGBoostSchema}.</li>
     *   <li>{@code "prune"}: defaults to {@code false}.</li>
     *   <li>{@code "ntree_limit"}: defaults to {@code null}.</li>
     * </ul>
     *
     * @param converter the Spark ML model converter supplying the model and option values
     * @param booster the trained XGBoost booster to encode
     * @param schema the input/output schema of the model
     * @param <M> the Spark ML model type, which also exposes XGBoost general params
     * @param <C> the converter type
     * @return the encoded mining model
     * @throws RuntimeException if the booster cannot be serialized or the serialized bytes
     *     cannot be parsed into a learner (the original exception is kept as the cause)
     */
    public static <M extends Model<M> & GeneralParams, C extends ModelConverter<M>> MiningModel encodeBooster(C converter, Booster booster, Schema schema) {
        M model = converter.getModel();
        Learner learner = loadLearner(booster);

        Boolean inputFloat = (Boolean) converter.getOption(OPTION_INPUT_FLOAT, null);
        if (Boolean.TRUE.equals(inputFloat)) {
            // A typed local keeps the call fully generic (no raw Function cast).
            Function<Feature, Feature> toFloat = BoosterUtil::toFloatFeature;
            schema = schema.toTransformedSchema(toFloat);
        }

        // A NaN missing value is forwarded as "not set" (null).
        float missingValue = model.getMissing();
        Float missing = Float.isNaN(missingValue) ? null : Float.valueOf(missingValue);

        Map<String, Object> options = new LinkedHashMap<>();
        options.put(OPTION_MISSING, converter.getOption(OPTION_MISSING, missing));
        options.put(OPTION_COMPACT, converter.getOption(OPTION_COMPACT, false));
        options.put(OPTION_NUMERIC, converter.getOption(OPTION_NUMERIC, true));
        options.put(OPTION_PRUNE, converter.getOption(OPTION_PRUNE, false));
        options.put(OPTION_NTREE_LIMIT, converter.getOption(OPTION_NTREE_LIMIT, null));

        Boolean numeric = (Boolean) options.get(OPTION_NUMERIC);
        Schema xgbSchema = learner.toXGBoostSchema(numeric.booleanValue(), schema);
        return learner.encodeMiningModel(options, xgbSchema);
    }

    /**
     * Serializes the booster and parses the bytes into a JPMML-XGBoost learner.
     * Any failure is wrapped in a {@link RuntimeException}.
     */
    private static Learner loadLearner(Booster booster) {
        byte[] bytes;
        try {
            bytes = booster.toByteArray();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        try (InputStream is = new ByteArrayInputStream(bytes)) {
            return XGBoostUtil.loadLearner(is);
        } catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
    }

    /**
     * Narrows a {@code DOUBLE} continuous feature to {@code FLOAT}.
     * All other features are returned unchanged.
     */
    private static Feature toFloatFeature(Feature feature) {
        if (!(feature instanceof ContinuousFeature)) {
            return feature;
        }
        ContinuousFeature continuousFeature = (ContinuousFeature) feature;
        DataType dataType = continuousFeature.getDataType();
        switch (dataType) {
            case DOUBLE: {
                // Mutates the shared field definition so its declared type becomes FLOAT.
                Field<?> field = continuousFeature.getField();
                field.setDataType(DataType.FLOAT);
                return new ContinuousFeature(continuousFeature.getEncoder(), field);
            }
            default:
                // INTEGER and FLOAT (and any other type) need no narrowing.
                return feature;
        }
    }
}

