/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import com.devsmart.ubjson.GsonUtil;
import com.devsmart.ubjson.UBObject;
import com.devsmart.ubjson.UBReader;
import com.devsmart.ubjson.UBValue;
import com.google.common.collect.Iterables;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import java.io.DataInput;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Value;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.MissingValueFeature;
import org.jpmml.converter.MultiLabel;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ThresholdFeature;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.visitors.TreeModelPruner;
import org.jpmml.xgboost.AFT;
import org.jpmml.xgboost.BinaryLoadable;
import org.jpmml.xgboost.BinomialLogisticRegression;
import org.jpmml.xgboost.Dart;
import org.jpmml.xgboost.FeatureMap;
import org.jpmml.xgboost.GBTree;
import org.jpmml.xgboost.GeneralizedLinearRegression;
import org.jpmml.xgboost.HingeClassification;
import org.jpmml.xgboost.JSONLoadable;
import org.jpmml.xgboost.LambdaMART;
import org.jpmml.xgboost.LinearRegression;
import org.jpmml.xgboost.LogisticRegression;
import org.jpmml.xgboost.MultinomialLogisticRegression;
import org.jpmml.xgboost.ObjFunction;
import org.jpmml.xgboost.PoissonRegression;
import org.jpmml.xgboost.UBJSONLoadable;
import org.jpmml.xgboost.UBJSONUtil;
import org.jpmml.xgboost.XGBoostDataInput;
import org.jpmml.xgboost.XGBoostEncoder;
import org.jpmml.xgboost.visitors.TreeModelCompactor;

public class Learner
implements BinaryLoadable,
JSONLoadable,
UBJSONLoadable {
    // Base score as read from the model. For major_version >= 1 it is replaced by
    // obj.probToMargin(base_score) right after the objective has been parsed.
    private float base_score;
    // Read from the model; exposed through num_feature().
    private int num_feature;
    // Read from the model; handed to MultinomialLogisticRegression for the "multi:*" objectives.
    private int num_class;
    // Binary format only: a non-zero value means an attribute map follows the booster.
    private int contain_extra_attrs;
    // Binary format only: a non-zero value means a metrics vector is read (major_version < 1 only).
    private int contain_eval_metrics;
    // Model format version; the loaders reject anything above 2.x.
    private int major_version;
    private int minor_version;
    // Number of targets: clamped to at least 1 by the binary loader, defaults to 1 when absent in (UB)JSON.
    private int num_target;
    // Binary format only: read from the header; not used anywhere else in this class.
    private int base_score_estimated;
    // Objective function selected by name in parseObjective(...).
    private ObjFunction obj;
    // Gradient booster selected by name in parseGradientBooster(...).
    private GBTree gbtree;
    // Stays null unless the model carries attributes. The (UB)JSON loader keeps only
    // "best_iteration" and "best_score"; the binary loader stores whatever readStringMap() returns.
    private Map<String, String> attributes = null;
    // Set only by the (UB)JSON loader, and only when the corresponding key is present.
    private String[] feature_names = null;
    private String[] feature_types = null;
    // Set only by the binary loader for major_version < 1 models.
    private String[] metrics = null;

    /**
     * Populates this learner from the binary model layout.
     *
     * <p>Reads the fixed header fields, the objective name, the booster name and the
     * booster itself, then the optional attribute map. Models with a major version
     * below 1 carry additional trailing sections, which are read last.
     *
     * @param input the data input positioned at the start of the learner header
     * @throws IOException if reading from {@code input} fails
     * @throws IllegalArgumentException if the major version is outside 0..2, or the
     *         objective or booster name is not recognised
     */
    @Override
    public void loadBinary(XGBoostDataInput input) throws IOException {
        base_score = input.readFloat();
        num_feature = input.readInt();
        num_class = input.readInt();
        contain_extra_attrs = input.readInt();
        contain_eval_metrics = input.readInt();
        major_version = input.readInt();
        minor_version = input.readInt();

        boolean supportedVersion = (major_version >= 0) && (major_version <= 2);
        if (!supportedVersion) {
            throw new IllegalArgumentException(major_version + "." + minor_version);
        }

        num_target = Math.max(input.readInt(), 1);
        base_score_estimated = input.readInt();
        input.readReserved(25);

        obj = parseObjective(input.readString());
        if (major_version >= 1) {
            // Adding 0.0f turns a negative zero into a positive zero; every other value is unchanged.
            base_score = obj.probToMargin(base_score) + 0.0f;
        }

        gbtree = parseGradientBooster(input.readString());
        gbtree.loadBinary(input);

        if (contain_extra_attrs != 0) {
            attributes = input.readStringMap();
        }

        // Everything below exists only in models with a major version below 1.
        if (major_version < 1) {
            if (obj instanceof PoissonRegression) {
                try {
                    // The value is read only to advance the input; it is not stored.
                    input.readString();
                } catch (EOFException ignored) {
                    // The trailing string is optional; reaching the end of input here is tolerated.
                }
            }
            if (contain_eval_metrics != 0) {
                metrics = input.readStringVector();
            }
        }
    }

    /**
     * Loads the learner from a Gson object by converting it to its UBJSON
     * representation and delegating to {@link #loadUBJSON(UBObject)}.
     *
     * @param root the JSON object holding the model
     */
    @Override
    public void loadJSON(JsonObject root) {
        loadUBJSON(GsonUtil.toUBValue(root).asObject());
    }

    /**
     * Populates this learner from a UBJSON model document.
     *
     * <p>Reads the "version" array, then the "learner" object with its
     * "learner_model_param", "objective" and "gradient_booster" children, and finally
     * the optional "attributes", "feature_names" and "feature_types" entries.
     * From "attributes" only "best_iteration" and "best_score" are retained.
     *
     * @param root the root object of the model document
     * @throws IllegalArgumentException if a required property is missing, the version
     *         array has fewer than two elements, the major version is not 1 or 2, or the
     *         objective or booster name is not recognised
     */
    @Override
    public void loadUBJSON(UBObject root) {
        int[] version = UBJSONUtil.toIntArray(Learner.requireProperty(root, "version"));
        if (version.length < 2) {
            throw new IllegalArgumentException("Property \"version\" must hold the major and minor version, got " + version.length + " element(s)");
        }
        this.major_version = version[0];
        this.minor_version = version[1];
        if (this.major_version < 1 || this.major_version > 2) {
            throw new IllegalArgumentException(this.major_version + "." + this.minor_version);
        }

        UBObject learner = Learner.requireProperty(root, "learner").asObject();

        UBObject learnerModelParam = Learner.requireProperty(learner, "learner_model_param").asObject();
        this.base_score = Learner.requireProperty(learnerModelParam, "base_score").asFloat32();
        this.num_feature = Learner.requireProperty(learnerModelParam, "num_feature").asInt();
        this.num_class = Learner.requireProperty(learnerModelParam, "num_class").asInt();
        // "num_target" is optional; a model without it is treated as single-target.
        this.num_target = learnerModelParam.containsKey("num_target") ? learnerModelParam.get("num_target").asInt() : 1;

        UBObject objective = Learner.requireProperty(learner, "objective").asObject();
        String name_obj = Learner.requireProperty(objective, "name").asString();
        // num_class must already be set here: parseObjective reads it for the "multi:*" objectives.
        this.obj = this.parseObjective(name_obj);
        // Adding 0.0f turns a negative zero into a positive zero; every other value is unchanged.
        this.base_score = this.obj.probToMargin(this.base_score) + 0.0f;

        UBObject gradientBooster = Learner.requireProperty(learner, "gradient_booster").asObject();
        String name_gbm = Learner.requireProperty(gradientBooster, "name").asString();
        this.gbtree = this.parseGradientBooster(name_gbm);
        this.gbtree.loadUBJSON(gradientBooster);

        if (learner.containsKey("attributes")) {
            UBObject learnerAttributes = learner.get("attributes").asObject();
            this.attributes = new HashMap<String, String>();
            for (String key : new String[]{"best_iteration", "best_score"}) {
                if (learnerAttributes.containsKey(key)) {
                    this.attributes.put(key, learnerAttributes.get(key).asString());
                }
            }
        }
        if (learner.containsKey("feature_names")) {
            this.feature_names = UBJSONUtil.toStringArray(learner.get("feature_names"));
        }
        if (learner.containsKey("feature_types")) {
            this.feature_types = UBJSONUtil.toStringArray(learner.get("feature_types"));
        }
    }

    /**
     * Returns the value stored under {@code key}, failing with a descriptive message
     * (listing the keys that are present) instead of a bare NullPointerException.
     */
    private static UBValue requireProperty(UBObject parent, String key) {
        UBValue value = parent.get(key);
        if (value == null) {
            throw new IllegalArgumentException("Property \"" + key + "\" not found among " + parent.keySet());
        }
        return value;
    }

    /**
     * Loads a binary model from a stream, skipping the optional leading markers.
     *
     * <p>If the stream starts with {@code "CONFIG-offset:"}, a {@code long} offset is
     * read next and must be non-negative. An optional {@code "binf"} marker is then
     * consumed. The remainder is parsed by {@link #loadBinary(XGBoostDataInput)}.
     * When no {@code "CONFIG-offset:"} header was present, the stream must be exhausted
     * after the model has been read.
     *
     * <p>The stream is closed when this method returns.
     *
     * @param is the model stream; NOTE(review): the body casts it to {@link DataInput},
     *        so it apparently must also implement that interface (and support
     *        mark/reset for the header probing) — confirm against callers
     * @param charset the charset name handed to {@link XGBoostDataInput}
     * @throws IOException if reading fails, the offset is negative, or unexpected
     *         data follows the model
     */
    public <DIS extends InputStream> void loadBinary(DIS is, String charset) throws IOException {
        boolean hasSerializationHeader = Learner.consumeHeader(is, "CONFIG-offset:");
        if (hasSerializationHeader) {
            long offset = ((DataInput) is).readLong();
            if (offset < 0L) {
                throw new IOException("Expected a non-negative config offset, got " + offset);
            }
        }

        // The "binf" marker is optional; it is skipped when present and changes nothing else.
        Learner.consumeHeader(is, "binf");

        try (XGBoostDataInput input = new XGBoostDataInput(is, charset)) {
            this.loadBinary(input);

            // With a "CONFIG-offset:" header the end-of-stream check is skipped;
            // presumably further serialized data follows the model — TODO confirm.
            if (!hasSerializationHeader) {
                int eof = is.read();
                if (eof != -1) {
                    throw new IOException("Expected end of stream after the binary model, but found trailing data");
                }
            }
        }
    }

    /**
     * Parses a JSON document from a stream, descends to the object addressed by
     * {@code jsonPath}, and loads the learner from it.
     *
     * <p>The stream is closed when this method returns.
     *
     * @param is the stream holding the JSON document
     * @param charset the charset name used for decoding; {@code null} selects UTF-8
     * @param jsonPath dot-separated property names; a leading {@code "$"} segment
     *        denotes the document root and is skipped
     * @throws IOException if reading fails or data remains in the stream after parsing
     * @throws IllegalArgumentException if a path segment is not a property of the
     *         current object
     */
    public void loadJSON(InputStream is, String charset, String jsonPath) throws IOException {
        JsonParser parser = new JsonParser();
        if (charset == null) {
            charset = "UTF-8";
        }
        try (InputStreamReader reader = new InputStreamReader(is, charset)) {
            JsonElement element = parser.parse(reader);
            JsonObject object = element.getAsJsonObject();

            String[] names = jsonPath.split("\\.");
            for (int i = 0; i < names.length; ++i) {
                String name = names[i];
                // A leading "$" stands for the root object itself.
                if (i == 0 && "$".equals(name)) {
                    continue;
                }
                JsonElement childElement = object.get(name);
                if (childElement == null) {
                    throw new IllegalArgumentException("Property \"" + name + "\" not among " + object.keySet());
                }
                object = childElement.getAsJsonObject();
            }

            this.loadJSON(object);

            int eof = is.read();
            if (eof != -1) {
                throw new IOException("Expected end of stream after the JSON document, but found trailing data");
            }
        }
    }

    /**
     * Reads a UBJSON document from a stream, descends to the object addressed by
     * {@code jsonPath}, and loads the learner from it.
     *
     * <p>The reader (and with it, presumably, the stream) is closed when this method returns.
     *
     * @param is the stream holding the UBJSON document
     * @param jsonPath dot-separated property names; a leading {@code "$"} segment
     *        denotes the document root and is skipped
     * @throws IOException if reading fails or data remains in the stream after the document
     * @throws IllegalArgumentException if a path segment is not a property of the
     *         current object
     */
    public void loadUBJSON(InputStream is, String jsonPath) throws IOException {
        try (UBReader reader = new UBReader(is)) {
            UBObject object = reader.read().asObject();

            String[] names = jsonPath.split("\\.");
            for (int i = 0; i < names.length; ++i) {
                String name = names[i];
                // A leading "$" stands for the root object itself.
                if (i == 0 && "$".equals(name)) {
                    continue;
                }
                UBValue childValue = object.get(name);
                if (childValue == null) {
                    throw new IllegalArgumentException("Property \"" + name + "\" not among " + object.keySet());
                }
                object = childValue.asObject();
            }

            this.loadUBJSON(object);

            int eof = is.read();
            if (eof != -1) {
                throw new IOException("Expected end of stream after the UBJSON document, but found trailing data");
            }
        }
    }

    /**
     * Builds a feature map from the feature names and types embedded in the model.
     *
     * @return a feature map with one entry per feature name, or {@code null} when the
     *         model carries no feature names or no feature types
     * @throws IllegalArgumentException if there are more feature names than feature types
     */
    public FeatureMap encodeFeatureMap() {
        if (this.feature_names == null || this.feature_types == null) {
            return null;
        }
        // Names are paired with types by position; fail up front rather than with an
        // index error halfway through. Surplus types are ignored.
        if (this.feature_names.length > this.feature_types.length) {
            throw new IllegalArgumentException("Expected a feature type for each of the " + this.feature_names.length + " feature names, got " + this.feature_types.length + " feature types");
        }
        FeatureMap result = new FeatureMap();
        for (int i = 0; i < this.feature_names.length; ++i) {
            result.addEntry(this.feature_names[i], this.feature_types[i]);
        }
        return result;
    }

    /**
     * Assembles a schema from the encoded label and the features of the feature map.
     *
     * @param targetName the target field name; {@code null} selects {@code "_target"}
     * @param targetCategories the target categories, passed through to the label encoding
     * @param featureMap the source of the features
     * @param encoder the encoder that owns the resulting fields
     * @return the schema made of the label and the features
     */
    public Schema encodeSchema(String targetName, List<String> targetCategories, FeatureMap featureMap, XGBoostEncoder encoder) {
        String labelName = (targetName != null) ? targetName : "_target";

        // The label is encoded before the features, as field creation order may matter to the encoder.
        Label label = encodeLabel(labelName, targetCategories, encoder);
        List<Feature> features = featureMap.encodeFeatures(encoder);

        return new Schema(encoder, label, features);
    }

    /**
     * Encodes the label through the objective function.
     *
     * <p>A single-target model yields the objective's label for {@code targetName}.
     * A multi-target model yields a {@link MultiLabel} whose members are named
     * {@code targetName} followed by a 1-based index.
     *
     * @param targetName the target field name (or name prefix for multi-target models)
     * @param targetCategories the target categories, passed through to the objective
     * @param encoder the encoder that owns the resulting fields
     * @return the encoded label
     * @throws IllegalArgumentException if the number of targets is less than one
     */
    public Label encodeLabel(String targetName, List<String> targetCategories, XGBoostEncoder encoder) {
        if (num_target < 1) {
            throw new IllegalArgumentException();
        }
        if (num_target == 1) {
            return obj.encodeLabel(targetName, targetCategories, encoder);
        }

        List<Label> memberLabels = new ArrayList<Label>();
        for (int position = 1; position <= num_target; position++) {
            memberLabels.add(obj.encodeLabel(targetName + position, targetCategories, encoder));
        }
        return new MultiLabel(memberLabels);
    }

    /**
     * Returns a schema whose features have been adapted according to the split type
     * that the booster reports for each of them.
     *
     * <p>Numerically split features are passed through unchanged when they are binary
     * or missing-value features (or threshold features, if {@code numeric} is false);
     * otherwise they are converted to continuous features, with double narrowed to
     * float. Categorically split features must already be categorical features.
     *
     * @param numeric whether threshold features should be converted like any other feature
     * @param schema the schema to transform
     * @return the transformed schema
     */
    public Schema toXGBoostSchema(final boolean numeric, final Schema schema) {
        FeatureTransformer transformer = new FeatureTransformer() {
            // The split index of a feature is its position in the original schema.
            private List<? extends Feature> schemaFeatures = schema.getFeatures();

            @Override
            public int getSplitIndex(Feature feature) {
                return schemaFeatures.indexOf(feature);
            }

            @Override
            public Feature transformNumerical(Feature feature) {
                boolean keepAsIs = (feature instanceof BinaryFeature)
                    || (feature instanceof MissingValueFeature)
                    || (feature instanceof ThresholdFeature && !numeric);
                if (keepAsIs) {
                    return feature;
                }

                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                DataType dataType = continuousFeature.getDataType();
                switch (dataType) {
                    case INTEGER:
                    case FLOAT:
                        return continuousFeature;
                    case DOUBLE:
                        return continuousFeature.toContinuousFeature(DataType.FLOAT);
                    default:
                        throw new IllegalArgumentException("Expected integer, float or double data type for continuous feature " + continuousFeature.getName() + ", got " + dataType.value() + " data type");
                }
            }

            @Override
            public Feature transformCategorical(Feature feature) {
                if (!(feature instanceof CategoricalFeature)) {
                    throw new IllegalArgumentException();
                }
                return feature;
            }
        };

        return schema.toTransformedSchema(transformer);
    }

    /**
     * Returns a schema in which numerically split features treat the value
     * {@code missing} as a missing value.
     *
     * <p>For features backed by a data field, {@code missing} is registered as a
     * missing value on that field (skipped for non-floating-point fields when
     * {@code missing} is NaN). For other features a derived field is created that
     * yields the original value only when it is present and not equal to
     * {@code missing}; nothing is derived when {@code missing} is NaN.
     *
     * @param missing the marker value to be treated as missing
     * @param schema the schema to transform
     * @return the transformed schema
     */
    public Schema toValueFilteredSchema(final Number missing, final Schema schema) {
        FeatureTransformer transformer = new FeatureTransformer() {
            // The split index of a feature is its position in the original schema.
            private List<? extends Feature> schemaFeatures = schema.getFeatures();

            @Override
            public int getSplitIndex(Feature feature) {
                return schemaFeatures.indexOf(feature);
            }

            @Override
            public Feature transformNumerical(Feature feature) {
                if (feature instanceof BinaryFeature || feature instanceof MissingValueFeature) {
                    return feature;
                }

                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                Field field = continuousFeature.getField();
                boolean markerIsNaN = ValueUtil.isNaN((Object)missing);

                if (field instanceof DataField) {
                    DataField dataField = (DataField)field;
                    if (markerIsNaN) {
                        DataType dataType = dataField.getDataType();
                        switch (dataType) {
                            case FLOAT:
                            case DOUBLE:
                                // Only floating-point fields get a NaN marker registered.
                                break;
                            default:
                                return continuousFeature;
                        }
                    }
                    PMMLUtil.addValues((Field)dataField, (Value.Property)Value.Property.MISSING, Collections.singletonList(missing));
                    return continuousFeature;
                }

                if (markerIsNaN) {
                    return continuousFeature;
                }

                // if(and(isNotMissing(x), notEqual(x, missing)), x) - there is no else branch.
                PMMLEncoder encoder = continuousFeature.getEncoder();
                Apply isPresent = PMMLUtil.createApply("isNotMissing", continuousFeature.ref());
                Apply differsFromMarker = PMMLUtil.createApply("notEqual", continuousFeature.ref(), PMMLUtil.createConstant(missing));
                Apply keepCondition = PMMLUtil.createApply("and", isPresent, differsFromMarker);
                Apply expression = PMMLUtil.createApply("if", keepCondition, continuousFeature.ref());

                DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("filter", continuousFeature, missing), OpType.CONTINUOUS, continuousFeature.getDataType(), expression);
                return new ContinuousFeature(encoder, (Field)derivedField);
            }

            @Override
            public Feature transformCategorical(Feature feature) {
                if (!(feature instanceof CategoricalFeature)) {
                    throw new IllegalArgumentException();
                }
                return feature;
            }
        };

        return schema.toTransformedSchema(transformer);
    }

    /**
     * Encodes this learner as a PMML document.
     *
     * <p>When the model embeds its own feature map, that map is updated with the
     * supplied {@code featureMap} and used in its place.
     *
     * @param options the encoding options, passed to {@link #encodeMiningModel(Map, Schema)}
     * @param targetName the target field name, may be {@code null}
     * @param targetCategories the target categories
     * @param featureMap the externally supplied feature map
     * @return the PMML document
     */
    public PMML encodePMML(Map<String, ?> options, String targetName, List<String> targetCategories, FeatureMap featureMap) {
        XGBoostEncoder encoder = new XGBoostEncoder();

        FeatureMap effectiveFeatureMap = featureMap;
        FeatureMap embeddedFeatureMap = encodeFeatureMap();
        if (embeddedFeatureMap != null) {
            embeddedFeatureMap.update(featureMap);
            effectiveFeatureMap = embeddedFeatureMap;
        }

        Schema schema = encodeSchema(targetName, targetCategories, effectiveFeatureMap, encoder);
        MiningModel miningModel = encodeMiningModel(options, schema);
        return encoder.encodePMML(miningModel);
    }

    /**
     * Encodes the booster as a mining model, then optionally compacts and prunes it.
     *
     * <p>Recognised option keys: {@code "missing"} (Number), {@code "compact"}
     * (Boolean), {@code "numeric"} (Boolean, treated as {@code true} when absent),
     * {@code "prune"} (Boolean) and {@code "ntree_limit"} (Integer).
     *
     * @param options the encoding options; values of an unexpected type cause a
     *        {@link ClassCastException}
     * @param schema the schema to encode against; filtered through
     *        {@link #toValueFilteredSchema(Number, Schema)} when {@code "missing"} is set
     * @return the encoded mining model
     * @throws IllegalArgumentException if {@code "compact"} is true while
     *         {@code "numeric"} is false
     */
    public MiningModel encodeMiningModel(Map<String, ?> options, Schema schema) {
        Number missing = (Number)options.get("missing");
        Boolean compact = (Boolean)options.get("compact");
        Boolean numeric = (Boolean)options.get("numeric");
        Boolean prune = (Boolean)options.get("prune");
        Integer ntreeLimit = (Integer)options.get("ntree_limit");

        if (numeric == null) {
            numeric = Boolean.TRUE;
        }
        if (missing != null) {
            schema = this.toValueFilteredSchema(missing, schema);
        }

        MiningModel miningModel = this.gbtree.encodeMiningModel(this.obj, this.base_score, ntreeLimit, numeric, schema).setAlgorithmName("XGBoost (" + this.gbtree.getAlgorithmName() + ")");

        if (Boolean.TRUE.equals(compact)) {
            if (Boolean.FALSE.equals(numeric)) {
                throw new IllegalArgumentException("Conflicting XGBoost options");
            }
            TreeModelCompactor compactor = new TreeModelCompactor();
            compactor.applyTo(miningModel);
        }

        if (Boolean.TRUE.equals(prune)) {
            // The pruner is a different visitor class, so it needs its own typed local;
            // it cannot be assigned to a TreeModelCompactor variable.
            TreeModelPruner pruner = new TreeModelPruner();
            pruner.applyTo(miningModel);
        }

        return miningModel;
    }

    /** Returns the feature count as read from the model. */
    public int num_feature() {
        return num_feature;
    }

    /** Returns the class count as read from the model. */
    public int num_class() {
        return num_class;
    }

    /** Returns the objective function, or {@code null} if no model has been loaded yet. */
    public ObjFunction obj() {
        return obj;
    }

    /** Returns the gradient booster, or {@code null} if no model has been loaded yet. */
    public GBTree gbtree() {
        return gbtree;
    }

    /**
     * Looks up a model attribute.
     *
     * @param key the attribute name
     * @return the attribute value, or {@code null} when the model has no attributes
     *         or none under {@code key}
     */
    public String getAttribute(String key) {
        if (attributes == null || !attributes.containsKey(key)) {
            return null;
        }
        return attributes.get(key);
    }

    /**
     * Returns the {@code "best_iteration"} attribute parsed as an integer.
     *
     * @return the parsed value, or {@code null} when the attribute is absent
     * @throws NumberFormatException if the attribute is not a valid integer
     */
    public Integer getBestIteration() {
        String rawValue = getAttribute("best_iteration");
        if (rawValue == null) {
            return null;
        }
        return Integer.valueOf(rawValue);
    }

    /**
     * Returns the {@code "best_score"} attribute parsed as a double.
     *
     * @return the parsed value, or {@code null} when the attribute is absent
     * @throws NumberFormatException if the attribute is not a valid number
     */
    public Double getBestScore() {
        String rawValue = getAttribute("best_score");
        if (rawValue == null) {
            return null;
        }
        return Double.valueOf(rawValue);
    }

    /**
     * Creates an empty booster for the given booster name.
     *
     * @param name_gbm the booster name, {@code "gbtree"} or {@code "dart"}
     * @return a new, not yet loaded booster
     * @throws IllegalArgumentException if the name is not recognised
     * @throws NullPointerException if the name is {@code null}
     */
    private GBTree parseGradientBooster(String name_gbm) {
        // The name is deliberately the receiver of equals(...), so that a null name
        // fails with a NullPointerException.
        if (name_gbm.equals("gbtree")) {
            return new GBTree();
        }
        if (name_gbm.equals("dart")) {
            return new Dart();
        }
        throw new IllegalArgumentException(name_gbm);
    }

    /**
     * Creates the objective function for the given objective name.
     *
     * <p>The current value of {@code num_class} is handed to the multi-class
     * objectives, so it must be set before this method is called.
     *
     * @param name_obj the objective name, e.g. {@code "binary:logistic"}
     * @return a new objective function
     * @throws IllegalArgumentException if the name is not recognised
     * @throws NullPointerException if the name is {@code null}
     */
    private ObjFunction parseObjective(String name_obj) {
        switch (name_obj) {
            // Regression
            case "reg:linear":
            case "reg:pseudohubererror":
            case "reg:squarederror":
            case "reg:squaredlogerror":
                return new LinearRegression(name_obj);
            case "reg:logistic":
                return new LogisticRegression(name_obj);
            case "reg:gamma":
            case "reg:tweedie":
                return new GeneralizedLinearRegression(name_obj);
            case "count:poisson":
                return new PoissonRegression(name_obj);

            // Binary classification
            case "binary:hinge":
                return new HingeClassification(name_obj);
            case "binary:logistic":
                return new BinomialLogisticRegression(name_obj);

            // Ranking
            case "rank:map":
            case "rank:ndcg":
            case "rank:pairwise":
                return new LambdaMART(name_obj);

            // Survival
            case "survival:aft":
                return new AFT(name_obj);

            // Multi-class classification
            case "multi:softmax":
            case "multi:softprob":
                return new MultinomialLogisticRegression(name_obj, num_class);

            default:
                throw new IllegalArgumentException(name_obj);
        }
    }

    /**
     * Checks whether the stream starts with the UTF-8 bytes of {@code header}.
     *
     * <p>On a match the header bytes stay consumed; otherwise the stream is reset to
     * where it was before the call.
     *
     * @param is the stream to probe; it is cast to {@link DataInput} for the read
     * @param header the expected leading text
     * @return {@code true} if the header was present and has been consumed
     * @throws IOException if reading or resetting the stream fails
     */
    private static <DIS extends InputStream> boolean consumeHeader(DIS is, String header) throws IOException {
        byte[] expected = header.getBytes(StandardCharsets.UTF_8);
        byte[] actual = new byte[expected.length];

        // Remember the current position so that a mismatch can be undone.
        is.mark(actual.length);
        DataInput dataInput = (DataInput)((Object)is);
        dataInput.readFully(actual);

        if (Arrays.equals(expected, actual)) {
            return true;
        }
        is.reset();
        return false;
    }

    /**
     * A feature mapping that dispatches on the split type which the enclosing
     * learner's booster reports for a feature: split type 0 is handled by
     * {@link #transformNumerical(Feature)}, split type 1 by
     * {@link #transformCategorical(Feature)}. Features without any reported split
     * type are passed through unchanged.
     */
    private abstract class FeatureTransformer
    implements Function<Feature, Feature> {
        private FeatureTransformer() {
        }

        /** Returns the split index under which the booster knows the feature. */
        public abstract int getSplitIndex(Feature var1);

        /** Transforms a feature that the booster splits numerically. */
        public abstract Feature transformNumerical(Feature var1);

        /** Transforms a feature that the booster splits categorically. */
        public abstract Feature transformCategorical(Feature var1);

        /**
         * Transforms the feature according to its split type.
         *
         * @throws IllegalArgumentException if the split type is neither 0 nor 1, or
         *         if the booster reports more than one split type for the feature
         */
        @Override
        public Feature apply(Feature feature) {
            Integer splitType = getSplitType(getSplitIndex(feature));
            if (splitType == null) {
                return feature;
            }

            int splitTypeCode = splitType.intValue();
            if (splitTypeCode == 0) {
                return transformNumerical(feature);
            }
            if (splitTypeCode == 1) {
                return transformCategorical(feature);
            }
            throw new IllegalArgumentException();
        }

        /**
         * Returns the single split type reported for the split index, or {@code null}
         * when none is reported.
         */
        private Integer getSplitType(int splitIndex) {
            Set<Integer> splitTypes = Learner.this.gbtree.getSplitType(splitIndex);
            switch (splitTypes.size()) {
                case 0:
                    return null;
                case 1:
                    return Iterables.getOnlyElement(splitTypes);
                default:
                    throw new IllegalArgumentException();
            }
        }
    }
}

