/*
 * Decompiled with CFR 0.152.
 */
package sklearn.ensemble.hist_gradient_boosting;

import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import sklearn.ensemble.hist_gradient_boosting.TreePredictor;

public class TreePredictorUtil {
    private TreePredictorUtil() {
    }

    public static TreeModel encodeTreeModel(TreePredictor treePredictor, Schema schema) {
        PredicateManager predicateManager = new PredicateManager();
        return TreePredictorUtil.encodeTreeModel(treePredictor, predicateManager, schema);
    }

    public static TreeModel encodeTreeModel(TreePredictor treePredictor, PredicateManager predicateManager, Schema schema) {
        int[] leaf = treePredictor.isLeaf();
        int[] leftChildren = treePredictor.getLeft();
        int[] rightChildren = treePredictor.getRight();
        int[] featureIdx = treePredictor.getFeatureIdx();
        double[] thresholds = treePredictor.getThreshold();
        int[] missingGoToLeft = treePredictor.getMissingGoToLeft();
        double[] values = treePredictor.getValues();
        Node root = TreePredictorUtil.encodeNode(0, (Predicate)True.INSTANCE, leaf, leftChildren, rightChildren, featureIdx, thresholds, missingGoToLeft, values, predicateManager, schema);
        TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label)schema.getLabel()), root).setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT).setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
        return treeModel;
    }

    private static Node encodeNode(int index, Predicate predicate, int[] leaf, int[] leftChildren, int[] rightChildren, int[] featureIdx, double[] thresholds, int[] missingGoToLeft, double[] values, PredicateManager predicateManager, Schema schema) {
        Integer id = index;
        if (leaf[index] == 0) {
            boolean defaultLeft;
            Predicate rightPredicate;
            Predicate leftPredicate;
            Object value;
            Feature feature = schema.getFeature(featureIdx[index]);
            double threshold = thresholds[index];
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature)feature;
                if (threshold != 0.5) {
                    throw new IllegalArgumentException();
                }
                value = binaryFeature.getValue();
                leftPredicate = predicateManager.createSimplePredicate((Feature)binaryFeature, SimplePredicate.Operator.NOT_EQUAL, value);
                rightPredicate = predicateManager.createSimplePredicate((Feature)binaryFeature, SimplePredicate.Operator.EQUAL, value);
                defaultLeft = true;
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature(DataType.DOUBLE);
                value = threshold;
                leftPredicate = predicateManager.createSimplePredicate((Feature)continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
                rightPredicate = predicateManager.createSimplePredicate((Feature)continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
                defaultLeft = missingGoToLeft[index] == 1;
            }
            Node leftChild = TreePredictorUtil.encodeNode(leftChildren[index], leftPredicate, leaf, leftChildren, rightChildren, featureIdx, thresholds, missingGoToLeft, values, predicateManager, schema);
            Node rightChild = TreePredictorUtil.encodeNode(rightChildren[index], rightPredicate, leaf, leftChildren, rightChildren, featureIdx, thresholds, missingGoToLeft, values, predicateManager, schema);
            Node result = new BranchNode(null, predicate).setId((Object)id).setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId()).addNodes(leftChild, rightChild);
            return result;
        }
        if (leaf[index] == 1) {
            SimpleNode result = new LeafNode((Object)values[index], predicate).setId((Object)id);
            return result;
        }
        throw new IllegalArgumentException();
    }
}

