/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import com.devsmart.ubjson.GsonUtil;
import com.devsmart.ubjson.UBObject;
import com.devsmart.ubjson.UBValue;
import com.devsmart.ubjson.UBValueFactory;
import com.google.common.primitives.Ints;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.MissingValueFeature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ThresholdFeature;
import org.jpmml.converter.ThresholdFeatureUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.xgboost.BinaryLoadable;
import org.jpmml.xgboost.BinaryNode;
import org.jpmml.xgboost.BinaryNodeStat;
import org.jpmml.xgboost.JSONLoadable;
import org.jpmml.xgboost.JSONNode;
import org.jpmml.xgboost.Node;
import org.jpmml.xgboost.NodeStat;
import org.jpmml.xgboost.UBJSONLoadable;
import org.jpmml.xgboost.UBJSONUtil;
import org.jpmml.xgboost.XGBoostDataInput;

/**
 * A single XGBoost regression tree.
 *
 * <p>The tree can be loaded from the legacy binary model format, or from the JSON and
 * UBJSON model formats. It can be converted to a PMML {@link TreeModel}.
 *
 * <p>Nodes are stored in a flat array and reference their children by array index.
 * Node {@code 0} is the root.
 */
public class RegTree implements BinaryLoadable, JSONLoadable, UBJSONLoadable {

    /** Value of {@link Node#split_type()} for a numerical (threshold) split. */
    private static final int SPLIT_TYPE_NUMERICAL = 0;

    /** Value of {@link Node#split_type()} for a categorical (category set) split. */
    private static final int SPLIT_TYPE_CATEGORICAL = 1;

    // Tree header fields, named after their XGBoost counterparts.
    // num_roots and max_depth are only populated by loadBinary(); loadUBJSON() does not read them.
    private int num_roots;
    private int num_nodes;
    private int num_deleted;
    private int max_depth;
    private int num_feature;
    private int size_leaf_vector;

    private Node[] nodes;

    // Per-node statistics; only populated by loadBinary().
    private NodeStat[] stats;

    /**
     * Loads the tree from the legacy binary model format.
     *
     * <p>The layout is:
     * <ul>
     *   <li>six header ints;</li>
     *   <li>31 reserved fields, which are skipped;</li>
     *   <li>{@code num_nodes} node records;</li>
     *   <li>{@code num_nodes} node statistics records.</li>
     * </ul>
     *
     * @param input the model input stream
     * @throws IOException if reading from {@code input} fails
     */
    @Override
    public void loadBinary(XGBoostDataInput input) throws IOException {
        this.num_roots = input.readInt();
        this.num_nodes = input.readInt();
        this.num_deleted = input.readInt();
        this.max_depth = input.readInt();
        this.num_feature = input.readInt();
        this.size_leaf_vector = input.readInt();
        input.readReserved(31);
        this.nodes = (Node[])input.readObjectArray(BinaryNode.class, this.num_nodes);
        this.stats = (NodeStat[])input.readObjectArray(BinaryNodeStat.class, this.num_nodes);
    }

    /**
     * Loads the tree from the JSON model format.
     *
     * <p>The JSON object is converted to its UBJSON representation and handed to
     * {@link #loadUBJSON(UBObject)}.
     *
     * @param tree the JSON object describing one tree
     */
    @Override
    public void loadJSON(JsonObject tree) {
        UBValue value = GsonUtil.toUBValue((JsonElement)tree);
        this.loadUBJSON(value.asObject());
    }

    /**
     * Loads the tree from the UBJSON model format.
     *
     * <p>The format stores node attributes column-wise, as parallel arrays indexed by
     * node id. This method reassembles them into one {@link JSONNode} per node. When any
     * node has a categorical split type, category bitsets are attached afterwards.
     *
     * @param tree the UBJSON object describing one tree
     */
    @Override
    public void loadUBJSON(UBObject tree) {
        UBObject treeParam = tree.get((Object)"tree_param").asObject();
        this.num_nodes = treeParam.get((Object)"num_nodes").asInt();
        this.num_deleted = treeParam.get((Object)"num_deleted").asInt();
        this.num_feature = treeParam.get((Object)"num_feature").asInt();
        this.size_leaf_vector = treeParam.get((Object)"size_leaf_vector").asInt();

        int[] parents = UBJSONUtil.toIntArray(tree.get((Object)"parents"));
        int[] left_children = UBJSONUtil.toIntArray(tree.get((Object)"left_children"));
        int[] right_children = UBJSONUtil.toIntArray(tree.get((Object)"right_children"));
        boolean[] default_left = UBJSONUtil.toBooleanArray(tree.get((Object)"default_left"));
        int[] split_indices = UBJSONUtil.toIntArray(tree.get((Object)"split_indices"));
        int[] split_type = UBJSONUtil.toIntArray(tree.get((Object)"split_type"));
        float[] split_conditions = UBJSONUtil.toFloatArray(tree.get((Object)"split_conditions"));

        this.nodes = new Node[this.num_nodes];
        // Typed view of the same node objects, used for the categorical pass below
        JSONNode[] jsonNodes = new JSONNode[this.num_nodes];

        for (int i = 0; i < this.num_nodes; ++i) {
            // Gather the i-th element of every column into a per-node object
            UBObject node = UBValueFactory.createObject();
            node.put("parent", UBValueFactory.createInt((long)parents[i]));
            node.put("left_child", UBValueFactory.createInt((long)left_children[i]));
            node.put("right_child", UBValueFactory.createInt((long)right_children[i]));
            node.put("default_left", (UBValue)UBValueFactory.createBool(default_left[i]));
            node.put("split_index", UBValueFactory.createInt((long)split_indices[i]));
            node.put("split_type", UBValueFactory.createInt((long)split_type[i]));
            node.put("split_condition", (UBValue)UBValueFactory.createFloat32(split_conditions[i]));

            JSONNode jsonNode = new JSONNode();
            jsonNode.loadUBJSON(node);

            this.nodes[i] = jsonNode;
            jsonNodes[i] = jsonNode;
        }

        if (Ints.contains(split_type, SPLIT_TYPE_CATEGORICAL)) {
            loadCategoricalSplits(tree, jsonNodes);
        }
    }

    /**
     * Attaches category bitsets to the categorical split nodes.
     *
     * <p>{@code categories_nodes} lists the ids of the categorical split nodes. It is
     * consumed in step with the node loop, so it is expected to be in ascending order.
     * For the k-th listed node, its categories are the elements of {@code categories}
     * in the half-open range
     * {@code [categories_segments[k], categories_segments[k] + categories_sizes[k])}.
     * Every other node gets a {@code null} bitset.
     *
     * @param tree the UBJSON object describing one tree
     * @param jsonNodes the already-loaded nodes of that tree, indexed by node id
     * @throws IllegalArgumentException if a categorical split node has no non-negative category
     */
    private static void loadCategoricalSplits(UBObject tree, JSONNode[] jsonNodes) {
        int[] categories_segments = UBJSONUtil.toIntArray(tree.get((Object)"categories_segments"));
        int[] categories_sizes = UBJSONUtil.toIntArray(tree.get((Object)"categories_sizes"));
        int[] categories_nodes = UBJSONUtil.toIntArray(tree.get((Object)"categories_nodes"));
        int[] categories = UBJSONUtil.toIntArray(tree.get((Object)"categories"));

        // Position in categories_nodes of the next categorical split node to match
        int cnt = 0;
        // -1 never matches a node id. This also guards against an empty categories_nodes array.
        int next_cat_node = (categories_nodes.length > 0) ? categories_nodes[0] : -1;

        for (int i = 0; i < jsonNodes.length; ++i) {
            JSONNode node = jsonNodes[i];

            if (i != next_cat_node) {
                node.set_split_categories(null);
                continue;
            }

            int j_begin = categories_segments[cnt];
            int j_end = j_begin + categories_sizes[cnt];

            int max_cat = -1;
            for (int j = j_begin; j < j_end; ++j) {
                max_cat = Math.max(max_cat, categories[j]);
            }
            if (max_cat == -1) {
                throw new IllegalArgumentException("Categorical split node " + i + " has no non-negative categories");
            }

            BitSet cat_bits = new BitSet(max_cat + 1);
            for (int j = j_begin; j < j_end; ++j) {
                cat_bits.set(categories[j]);
            }
            node.set_split_categories(cat_bits);

            cnt++;
            next_cat_node = (cnt < categories_nodes.length) ? categories_nodes[cnt] : -1;
        }
    }

    /**
     * Returns the root value of a single-leaf tree.
     *
     * @return the root node's leaf value if the root is a leaf, or {@code null} otherwise
     */
    public Float getLeafValue() {
        Node root = this.nodes[0];
        if (!root.is_leaf()) {
            return null;
        }
        return Float.valueOf(root.leaf_value());
    }

    /**
     * Tells whether any non-leaf node of this tree uses a categorical split.
     *
     * @return {@code true} if at least one split node has a categorical split type
     */
    public boolean hasCategoricalSplits() {
        for (int i = 0; i < this.num_nodes; ++i) {
            Node node = this.nodes[i];
            if (!node.is_leaf() && node.split_type() == SPLIT_TYPE_CATEGORICAL) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collects the split types used by all split nodes on the given feature.
     *
     * @param splitIndex the feature index
     * @return the distinct split types; empty if no node splits on {@code splitIndex}
     */
    public Set<Integer> getSplitType(int splitIndex) {
        Set<Integer> result = new HashSet<>();
        for (int i = 0; i < this.num_nodes; ++i) {
            Node node = this.nodes[i];
            if (!node.is_leaf() && node.split_index() == splitIndex) {
                result.add(node.split_type());
            }
        }
        return result;
    }

    /**
     * Computes the union of the category bitsets of all split nodes on the given feature.
     *
     * @param splitIndex the feature index
     * @return the union bitset, or {@code null} if no split node on {@code splitIndex}
     *     has a category bitset
     */
    public BitSet getSplitCategories(int splitIndex) {
        BitSet result = null;
        for (int i = 0; i < this.num_nodes; ++i) {
            Node node = this.nodes[i];
            if (node.is_leaf() || node.split_index() != splitIndex) {
                continue;
            }
            BitSet splitCategories = node.get_split_categories();
            if (splitCategories == null) {
                continue;
            }
            if (result == null) {
                result = new BitSet();
            }
            result.or(splitCategories);
        }
        return result;
    }

    /**
     * Encodes this tree as a PMML regression {@link TreeModel}.
     *
     * <p>The model uses binary splits, default-child missing value handling and
     * single-precision arithmetic.
     *
     * @param predicateManager the shared predicate factory
     * @param schema the feature schema; features are looked up by split index
     * @return the tree model
     */
    public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schema) {
        org.dmg.pmml.tree.Node root = this.encodeNode(0, True.INSTANCE, new CategoryManager(), predicateManager, schema);

        TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label)schema.getLabel()), root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
            .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD)
            .setMathContext(MathContext.FLOAT);

        return treeModel;
    }

    /**
     * Recursively encodes the node at {@code index} and its subtree as a PMML node.
     *
     * <p>For each split, the left child receives the "less than" / "not in category set" /
     * "not equal" / "not missing" side. The right child receives the complementary side.
     *
     * @param index the node id
     * @param predicate the predicate guarding this node, as chosen by its parent
     * @param categoryManager tracks which values of each feature are still reachable on this path
     * @param predicateManager the shared predicate factory
     * @param schema the feature schema
     * @return the encoded node
     * @throws IllegalArgumentException if the node's split type does not match its feature,
     *     if a categorical split is malformed, or if a continuous feature has an unsupported
     *     data type
     */
    private org.dmg.pmml.tree.Node encodeNode(int index, Predicate predicate, CategoryManager categoryManager, PredicateManager predicateManager, Schema schema) {
        Integer id = index;
        Node node = this.nodes[index];

        if (node.is_leaf()) {
            // Adding 0.0f turns a negative-zero leaf value into positive zero
            Float score = Float.valueOf(node.leaf_value() + 0.0f);
            return new LeafNode((Object)score, predicate).setId((Object)id);
        }

        int splitIndex = node.split_index();
        Feature feature = schema.getFeature(splitIndex);
        boolean defaultLeft = node.default_left();

        int expectedSplitType = (feature instanceof CategoricalFeature) ? SPLIT_TYPE_CATEGORICAL : SPLIT_TYPE_NUMERICAL;
        if (node.split_type() != expectedSplitType) {
            throw new IllegalArgumentException("Node " + index + " has split type " + node.split_type() + ", but feature " + splitIndex + " requires split type " + expectedSplitType);
        }

        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;
        Predicate leftPredicate;
        Predicate rightPredicate;
        boolean swapChildren = false;

        if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature)feature;

            String name = categoricalFeature.getName();
            List<?> values = categoricalFeature.getValues();

            // The condition of a categorical split lives in the category bitset.
            // The scalar split condition is required to be NaN.
            float splitValue = Float.intBitsToFloat(node.split_cond());
            if (!Float.isNaN(splitValue)) {
                throw new IllegalArgumentException("Categorical split node " + index + " has a non-NaN split condition " + splitValue);
            }

            BitSet splitCategories = node.get_split_categories();
            if (splitCategories == null) {
                throw new IllegalArgumentException("Categorical split node " + index + " has no split categories");
            }

            java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

            List<Object> leftValues = new ArrayList<>();
            List<Object> rightValues = new ArrayList<>();
            for (int i = 0; i < values.size(); ++i) {
                Object value = values.get(i);
                // Skip values already excluded by an ancestor split on the same feature
                if (!valueFilter.test(value)) {
                    continue;
                }
                // Category code i corresponds to values.get(i).
                // Codes present in the bitset are sent to the right child.
                if (splitCategories.get(i)) {
                    rightValues.add(value);
                } else {
                    leftValues.add(value);
                }
            }

            leftCategoryManager = leftCategoryManager.fork(name, leftValues);
            rightCategoryManager = rightCategoryManager.fork(name, rightValues);

            leftPredicate = predicateManager.createPredicate(feature, leftValues);
            rightPredicate = predicateManager.createPredicate(feature, rightValues);
        } else if (feature instanceof BinaryFeature) {
            Object binaryValue = ((BinaryFeature)feature).getValue();

            leftPredicate = predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.NOT_EQUAL, binaryValue);
            rightPredicate = predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.EQUAL, binaryValue);
        } else if (feature instanceof MissingValueFeature) {
            leftPredicate = predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.IS_NOT_MISSING, null);
            rightPredicate = predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);
        } else if (feature instanceof ThresholdFeature) {
            ThresholdFeature thresholdFeature = (ThresholdFeature)feature;

            String name = thresholdFeature.getName();
            Object missingValue = thresholdFeature.getMissingValue();
            float splitValue = Float.intBitsToFloat(node.split_cond());

            java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
            // When the missing value is not NaN, NaN values are excluded from both sides
            if (!ValueUtil.isNaN(missingValue)) {
                valueFilter = valueFilter.and(value -> !ValueUtil.isNaN(value));
            }

            List<Object> leftValues = thresholdFeature.getValues(value -> value.floatValue() < splitValue).stream()
                .filter(valueFilter)
                .collect(Collectors.toList());
            List<Object> rightValues = thresholdFeature.getValues(value -> value.floatValue() >= splitValue).stream()
                .filter(valueFilter)
                .collect(Collectors.toList());

            leftCategoryManager = leftCategoryManager.fork(name, leftValues);
            rightCategoryManager = rightCategoryManager.fork(name, rightValues);

            leftPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, leftValues, missingValue, predicateManager);
            rightPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, rightValues, missingValue, predicateManager);

            // Only the right predicate is missing-value-safe. Presumably the children are
            // reordered so that it is listed first. NOTE(review): confirm intent.
            if (!ThresholdFeatureUtil.isMissingValueSafe(leftPredicate) && ThresholdFeatureUtil.isMissingValueSafe(rightPredicate)) {
                swapChildren = true;
            }
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();

            float splitValue = Float.intBitsToFloat(node.split_cond());
            Number threshold;

            DataType dataType = continuousFeature.getDataType();
            switch (dataType) {
                case INTEGER: {
                    // For integer x, "x < splitValue" is equivalent to "x < ceil(splitValue)"
                    float flooredSplitValue = (float)Math.floor(splitValue);
                    threshold = (int)((splitValue == flooredSplitValue) ? flooredSplitValue : (flooredSplitValue + 1.0f));
                    break;
                }
                case FLOAT: {
                    threshold = splitValue;
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Expected integer or float data type for continuous feature " + continuousFeature.getName() + ", got " + dataType.value() + " data type");
                }
            }

            leftPredicate = predicateManager.createSimplePredicate((Feature)continuousFeature, SimplePredicate.Operator.LESS_THAN, (Object)threshold);
            rightPredicate = predicateManager.createSimplePredicate((Feature)continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, (Object)threshold);
        }

        org.dmg.pmml.tree.Node leftChild = this.encodeNode(node.left_child(), leftPredicate, leftCategoryManager, predicateManager, schema);
        org.dmg.pmml.tree.Node rightChild = this.encodeNode(node.right_child(), rightPredicate, rightCategoryManager, predicateManager, schema);

        // The default child is referenced by id, so it is unaffected by the optional swap below
        org.dmg.pmml.tree.Node result = new BranchNode(null, predicate)
            .setId((Object)id)
            .setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId())
            .addNodes(leftChild, rightChild);

        if (swapChildren) {
            Collections.swap(result.getNodes(), 0, 1);
        }

        return result;
    }
}

