/*
 * Decompiled with CFR 0.152.
 */
package sklearn.tree;

import com.google.common.primitives.Doubles;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import net.razorvine.pickle.objects.ClassDict;
import numpy.core.ScalarUtil;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.HasContinuousDomain;
import org.dmg.pmml.HasExtensions;
import org.dmg.pmml.Interval;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.NodeTransformer;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.SimplifyingNodeTransformer;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.MissingValueFeature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ScoreDistributionManager;
import org.jpmml.converter.ThresholdFeature;
import org.jpmml.converter.ThresholdFeatureUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.visitors.AbstractExtender;
import org.jpmml.model.UnsupportedElementException;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnException;
import sklearn.Estimator;
import sklearn.StepUtil;
import sklearn.VersionUtil;
import sklearn.tree.HasTree;
import sklearn.tree.Tree;
import sklearn.tree.visitors.TreeModelCompactor;
import sklearn.tree.visitors.TreeModelFlattener;
import sklearn.tree.visitors.TreeModelPruner;

public class TreeUtil {
    /**
     * Not instantiable: every member of this class visible in this file is static.
     */
    private TreeUtil() {
    }

    /**
     * Decides whether the estimator accepts missing (NaN) inputs.
     *
     * <p>The {@code allow_nan} input tag wins whenever it is present. Without the tag, the answer is
     * derived from the recorded scikit-learn version: {@code 1.3.0} and newer count as supporting
     * missing values, an unknown version counts as not supporting them.</p>
     *
     * @param estimator the estimator to inspect
     * @return {@code true} if missing values are supported, {@code false} otherwise
     */
    public static <E extends Estimator> boolean hasMissingValueSupport(E estimator) {
        Boolean allowNaN = (Boolean)StepUtil.getTag(estimator.getInputTags(), "allow_nan");
        if (allowNaN == null) {
            // No explicit tag; fall back to a version check.
            String version = estimator.getSkLearnVersion();
            if (version == null) {
                return false;
            }
            return VersionUtil.compareVersion(version, "1.3.0") >= 0;
        }
        return allowNaN.booleanValue();
    }

    /**
     * Builds a binary-split PMML {@link TreeModel} from the tree carried by the estimator.
     *
     * <p>The estimator is cast to {@link HasTree}. The tree arrays are read once and handed to the
     * recursive node encoder starting at node {@code 0} with an always-true predicate. After the
     * model has been assembled, the content of the source tree object is cleared.</p>
     *
     * @param estimator the estimator; must also implement {@link HasTree}
     * @param miningFunction the mining function of the resulting model
     * @param predicateManager factory for split predicates
     * @param scoreDistributionManager factory for leaf score distributions
     * @param schema the schema supplying the label and the features
     * @return the encoded tree model
     */
    public static <E extends Estimator> TreeModel encodeTreeModel(E estimator, MiningFunction miningFunction, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema) {
        HasTree hasTree = (HasTree)((Object)estimator);
        Tree tree = hasTree.getTree();
        boolean missingValueSupport = hasTree.hasMissingValueSupport();

        int[] leftChildren = tree.getChildrenLeft();
        int[] rightChildren = tree.getChildrenRight();
        int[] features = tree.getFeature();
        double[] thresholds = tree.getThreshold();
        double[] values = tree.getValues();
        // The "missing go to left" flags are only read when missing values are supported.
        int[] missingGoToLeft = null;
        if (missingValueSupport) {
            missingGoToLeft = tree.getMissingGoToLeft();
        }

        Node root = TreeUtil.encodeNode(0, (Predicate)True.INSTANCE, miningFunction, leftChildren, rightChildren, features, thresholds, values, missingGoToLeft, new CategoryManager(), predicateManager, scoreDistributionManager, schema);

        TreeModel.MissingValueStrategy missingValueStrategy;
        if (missingValueSupport) {
            missingValueStrategy = TreeModel.MissingValueStrategy.DEFAULT_CHILD;
        } else {
            missingValueStrategy = TreeModel.MissingValueStrategy.NULL_PREDICTION;
        }

        TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema((Label)schema.getLabel()), root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
            .setMissingValueStrategy(missingValueStrategy);

        ClassDictUtil.clearContent((ClassDict)tree);
        return treeModel;
    }

    /**
     * Recursively converts the node at {@code index} of the parallel tree arrays into a PMML node.
     *
     * <p>A node whose entry in {@code features} is non-negative is encoded as a split node with
     * exactly two children (left child first, then right child); every other node is encoded as a
     * leaf. The node index doubles as the PMML node id.</p>
     *
     * <p>How the split predicates are built depends on the type of the split feature:</p>
     * <ul>
     *   <li>{@link BinaryFeature}: {@code !=} / {@code ==} against the feature value; the threshold
     *       must lie in {@code [0, 1]}.</li>
     *   <li>{@link MissingValueFeature}: is-not-missing / is-missing; the threshold must be
     *       {@code 0.5}.</li>
     *   <li>{@link ThresholdFeature}: the feature values are partitioned around the threshold and
     *       each child gets its own forked {@link CategoryManager}.</li>
     *   <li>anything else: {@code <=} / {@code >} against the threshold on a continuous feature.</li>
     * </ul>
     *
     * @param index position of the node in the tree arrays
     * @param predicate predicate attached to the node being created
     * @param miningFunction {@code CLASSIFICATION} or {@code REGRESSION}
     * @param leftChildren index of the left child, per node
     * @param rightChildren index of the right child, per node
     * @param features split feature index, per node
     * @param thresholds split threshold, per node
     * @param values node values; read one per node for regression, one row per node for classification
     * @param missingGoToLeft per-node flags ({@code 1} selects the left child as default child), or
     *        {@code null} if no default child should be set
     * @param categoryManager category values still reachable at this node
     * @param predicateManager factory for split predicates
     * @param scoreDistributionManager factory for leaf score distributions
     * @param schema the schema supplying the label and the features
     * @return the encoded node, including all of its descendants
     * @throws IllegalArgumentException if a threshold is not valid for its feature type, or if the
     *         mining function is neither classification nor regression
     */
    private static Node encodeNode(int index, Predicate predicate, MiningFunction miningFunction, int[] leftChildren, int[] rightChildren, int[] features, double[] thresholds, double[] values, int[] missingGoToLeft, CategoryManager categoryManager, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema) {
        Integer id = index;
        int featureIndex = features[index];
        if (featureIndex >= 0) {
            Predicate leftPredicate;
            Predicate rightPredicate;
            Feature feature = schema.getFeature(featureIndex);
            double threshold = thresholds[index];
            CategoryManager leftCategoryManager = categoryManager;
            CategoryManager rightCategoryManager = categoryManager;
            Boolean defaultLeft = null;
            if (missingGoToLeft != null) {
                defaultLeft = missingGoToLeft[index] == 1;
            }
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature)feature;
                if (threshold < 0.0 || threshold > 1.0) {
                    throw new IllegalArgumentException("Expected a split threshold in range [0.0, 1.0] for a binary feature, got " + threshold);
                }
                Object binaryValue = binaryFeature.getValue();
                leftPredicate = predicateManager.createSimplePredicate((Feature)binaryFeature, SimplePredicate.Operator.NOT_EQUAL, binaryValue);
                rightPredicate = predicateManager.createSimplePredicate((Feature)binaryFeature, SimplePredicate.Operator.EQUAL, binaryValue);
                // For binary features the left ("not equal") child is always the default child
                // whenever default children are being tracked at all.
                if (missingGoToLeft != null) {
                    defaultLeft = Boolean.TRUE;
                }
            } else if (feature instanceof MissingValueFeature) {
                MissingValueFeature missingValueFeature = (MissingValueFeature)feature;
                if (threshold != 0.5) {
                    throw new IllegalArgumentException("Expected split threshold 0.5 for a missing value feature, got " + threshold);
                }
                leftPredicate = predicateManager.createSimplePredicate((Feature)missingValueFeature, SimplePredicate.Operator.IS_NOT_MISSING, null);
                rightPredicate = predicateManager.createSimplePredicate((Feature)missingValueFeature, SimplePredicate.Operator.IS_MISSING, null);
            } else if (feature instanceof ThresholdFeature) {
                ThresholdFeature thresholdFeature = (ThresholdFeature)feature;
                String name = thresholdFeature.getName();
                Object missingValue = thresholdFeature.getMissingValue();
                java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
                // When the feature's missing value is not NaN, NaN candidates are filtered out.
                if (!ValueUtil.isNaN((Object)missingValue)) {
                    valueFilter = valueFilter.and(candidate -> !ValueUtil.isNaN((Object)candidate));
                }
                List leftValues = thresholdFeature.getValues(candidate -> TreeUtil.toSplitValue(candidate) <= threshold).stream().filter(valueFilter).collect(Collectors.toList());
                List rightValues = thresholdFeature.getValues(candidate -> TreeUtil.toSplitValue(candidate) > threshold).stream().filter(valueFilter).collect(Collectors.toList());
                leftCategoryManager = leftCategoryManager.fork(name, leftValues);
                rightCategoryManager = rightCategoryManager.fork(name, rightValues);
                leftPredicate = ThresholdFeatureUtil.createPredicate((ThresholdFeature)thresholdFeature, leftValues, (Object)missingValue, (PredicateManager)predicateManager);
                rightPredicate = ThresholdFeatureUtil.createPredicate((ThresholdFeature)thresholdFeature, rightValues, (Object)missingValue, (PredicateManager)predicateManager);
            } else {
                ContinuousFeature continuousFeature = TreeUtil.toContinuousFeature(feature);
                // Positive infinity is written as the literal "INF" instead of a Double.
                Object splitValue = threshold == Double.POSITIVE_INFINITY ? "INF" : Double.valueOf(threshold);
                leftPredicate = predicateManager.createSimplePredicate((Feature)continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, splitValue);
                rightPredicate = predicateManager.createSimplePredicate((Feature)continuousFeature, SimplePredicate.Operator.GREATER_THAN, splitValue);
            }
            int leftIndex = leftChildren[index];
            int rightIndex = rightChildren[index];
            Node leftChild = TreeUtil.encodeNode(leftIndex, leftPredicate, miningFunction, leftChildren, rightChildren, features, thresholds, values, missingGoToLeft, leftCategoryManager, predicateManager, scoreDistributionManager, schema);
            Node rightChild = TreeUtil.encodeNode(rightIndex, rightPredicate, miningFunction, leftChildren, rightChildren, features, thresholds, values, missingGoToLeft, rightCategoryManager, predicateManager, scoreDistributionManager, schema);
            // Declared as Node because either a ClassifierNode or a BranchNode is assigned below.
            Node branch;
            if (miningFunction == MiningFunction.CLASSIFICATION) {
                branch = new ClassifierNode(null, predicate);
            } else if (miningFunction == MiningFunction.REGRESSION) {
                double branchValue = values[index];
                branch = new BranchNode((Object)branchValue, predicate);
            } else {
                throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
            }
            branch.setId((Object)id).addNodes(leftChild, rightChild);
            if (defaultLeft != null) {
                branch.setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId());
            }
            return branch;
        }
        Node leaf;
        if (miningFunction == MiningFunction.CLASSIFICATION) {
            CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
            // One row per node, one column per label category.
            final double[] leafValues = TreeUtil.getRow(values, leftChildren.length, categoricalLabel.size(), index);
            // Read-only view over the leaf values, narrowed element by element on access.
            AbstractList<Number> recordCounts = new AbstractList<Number>(){

                @Override
                public int size() {
                    return leafValues.length;
                }

                @Override
                public Number get(int i) {
                    double leafValue = leafValues[i];
                    return ValueUtil.narrow((double)leafValue);
                }
            };
            double totalRecordCount = 0.0;
            for (Number recordCount : recordCounts) {
                totalRecordCount += recordCount.doubleValue();
            }
            // The leaf score is the label category with the largest value.
            int maxIndex = ScoreDistributionManager.indexOfMax((List)Doubles.asList((double[])leafValues));
            Object score = categoricalLabel.getValue(maxIndex);
            leaf = new ClassifierNode(score, predicate).setId((Object)id).setRecordCount(ValueUtil.narrow((double)totalRecordCount));
            scoreDistributionManager.addScoreDistributions((PMMLObject)leaf, categoricalLabel.getValues(), (List)recordCounts, null);
        } else if (miningFunction == MiningFunction.REGRESSION) {
            double leafValue = values[index];
            leaf = new LeafNode((Object)leafValue, predicate).setId((Object)id);
        } else {
            throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
        }
        return leaf;
    }

    /**
     * Adapts the schema to tree models according to the estimator's {@code numeric} and
     * {@code input_float} options.
     *
     * @param estimator the estimator whose options are read; {@code numeric} defaults to
     *        {@code true}, {@code input_float} defaults to {@code null}
     * @param schema the schema to transform
     * @return the transformed schema
     */
    public static <E extends Estimator> Schema configureSchema(E estimator, Schema schema) {
        Boolean numericOption = (Boolean)estimator.getOption("numeric", Boolean.TRUE);
        Boolean inputFloatOption = (Boolean)estimator.getOption("input_float", null);

        Schema treeModelSchema = toTreeModelSchema(numericOption, inputFloatOption, schema);
        return treeModelSchema;
    }

    /**
     * Post-processes an encoded model according to the estimator's tree model options.
     *
     * <p>Options read from the estimator: {@code allow_missing}, {@code winner_id},
     * {@code node_extensions}, {@code node_id}, {@code node_score}, {@code compact}, {@code flat}
     * and {@code prune}. Requesting node extensions, node ids or node scores pins the tree
     * structure, so it cannot be combined with compaction, flattening or pruning.</p>
     *
     * <p>If {@code winner_id} is set, a node id output is added first. The selected visitors are
     * then applied to the model in this order: default child cleaner, pruner, compactor, flattener,
     * node extenders, node id cleaner, node score cleaner.</p>
     *
     * @param estimator the estimator whose options are read
     * @param model the model to modify in place
     * @return the same {@code model} instance
     * @throws SkLearnException if a structure-pinning option is combined with {@code compact},
     *         {@code flat} or {@code prune}
     */
    public static <E extends Estimator, M extends Model> M configureModel(E estimator, M model) {
        Boolean allowMissing = (Boolean)estimator.getOption("allow_missing", Boolean.FALSE);
        Boolean winnerId = (Boolean)estimator.getOption("winner_id", Boolean.FALSE);
        Map<?, ?> nodeExtensions = (Map<?, ?>)estimator.getOption("node_extensions", null);
        Boolean nodeId = (Boolean)estimator.getOption("node_id", winnerId);
        Boolean nodeScore = (Boolean)estimator.getOption("node_score", winnerId ? Boolean.TRUE : null);
        // Any of these options requires the node structure to stay exactly as encoded.
        boolean fixed = nodeExtensions != null || Boolean.TRUE.equals(nodeId) || Boolean.TRUE.equals(nodeScore);
        Boolean compact = (Boolean)estimator.getOption("compact", fixed ? Boolean.FALSE : Boolean.TRUE);
        Boolean flat = (Boolean)estimator.getOption("flat", Boolean.FALSE);
        Boolean prune = (Boolean)estimator.getOption("prune", fixed ? Boolean.FALSE : Boolean.TRUE);
        if (compact.booleanValue() || flat.booleanValue() || prune.booleanValue()) {
            if (fixed) {
                throw new SkLearnException("Conflicting tree model options");
            }
            nodeExtensions = null;
            nodeId = winnerId ? Boolean.TRUE : allowMissing;
            nodeScore = winnerId ? Boolean.TRUE : null;
        }
        if (Boolean.TRUE.equals(winnerId)) {
            TreeUtil.encodeNodeId(estimator, model);
        }
        List<Visitor> visitors = new ArrayList<>();
        if (Boolean.FALSE.equals(allowMissing)) {
            // Removes the missing value strategy and every default child reference.
            AbstractVisitor defaultChildCleaner = new AbstractVisitor(){

                public VisitorAction visit(TreeModel treeModel) {
                    treeModel.setMissingValueStrategy(null);
                    return super.visit(treeModel);
                }

                public VisitorAction visit(Node node) {
                    Object defaultChild = node.getDefaultChild();
                    if (defaultChild != null) {
                        node.setDefaultChild(null);
                    }
                    return super.visit(node);
                }
            };
            visitors.add(defaultChildCleaner);
        }
        if (Boolean.TRUE.equals(prune)) {
            visitors.add(new TreeModelPruner());
        }
        if (Boolean.TRUE.equals(compact)) {
            visitors.add(new TreeModelCompactor());
        }
        if (Boolean.TRUE.equals(flat)) {
            visitors.add(new TreeModelFlattener());
        }
        if (nodeExtensions != null) {
            // One extender per entry: the key names the extension, the value maps node ids to
            // extension values.
            for (Map.Entry<?, ?> entry : nodeExtensions.entrySet()) {
                String name = (String)entry.getKey();
                final Map<?, ?> values = (Map<?, ?>)entry.getValue();
                AbstractExtender nodeExtender = new AbstractExtender(name){
                    private final NodeTransformer nodeTransformer = SimplifyingNodeTransformer.INSTANCE;

                    public VisitorAction visit(TreeModel treeModel) {
                        treeModel.setNode(this.ensureExtensibility(treeModel.getNode()));
                        return super.visit(treeModel);
                    }

                    public VisitorAction visit(Node node) {
                        // Children are swapped for extensible nodes before they get visited.
                        if (node.hasNodes()) {
                            List<Node> children = node.getNodes();
                            ListIterator<Node> childIt = children.listIterator();
                            while (childIt.hasNext()) {
                                childIt.set(this.ensureExtensibility(childIt.next()));
                            }
                        }
                        Object value = this.getValue(node);
                        if (value != null) {
                            value = ScalarUtil.decode((Object)value);
                            this.addExtension((PMMLObject)((Node)((HasExtensions)node)), ValueUtil.asString((Object)value));
                        }
                        return super.visit(node);
                    }

                    /** Converts the node to a complex node if it has a value but cannot hold extensions. */
                    private Node ensureExtensibility(Node node) {
                        if (node instanceof HasExtensions) {
                            return node;
                        }
                        Object value = this.getValue(node);
                        if (value != null) {
                            return this.nodeTransformer.toComplexNode(node);
                        }
                        return node;
                    }

                    /** Looks up the extension value registered under the node's integer id. */
                    private Object getValue(Node node) {
                        Integer id = ValueUtil.asInteger((Number)((Number)node.getId()));
                        return values.get(id);
                    }
                };
                visitors.add(nodeExtender);
            }
        }
        if (Boolean.FALSE.equals(nodeId)) {
            AbstractVisitor nodeIdCleaner = new AbstractVisitor(){

                public VisitorAction visit(Node node) {
                    node.setId(null);
                    return super.visit(node);
                }
            };
            visitors.add(nodeIdCleaner);
        }
        if (Boolean.FALSE.equals(nodeScore)) {
            // Clears scores and score distributions of nodes that have children only.
            AbstractVisitor nodeScoreCleaner = new AbstractVisitor(){

                public VisitorAction visit(Node node) {
                    if (node.hasNodes()) {
                        node.setScore(null);
                        if (node.hasScoreDistributions()) {
                            node.getScoreDistributions().clear();
                        }
                    }
                    return super.visit(node);
                }
            };
            visitors.add(nodeScoreCleaner);
        }
        for (Visitor visitor : visitors) {
            visitor.applyTo(model);
        }
        return model;
    }

    /**
     * Transforms every feature of the schema into the form used by tree models.
     *
     * <p>Binary and missing value features are kept as they are. Threshold features are kept only
     * when {@code numeric} is explicitly {@code false}. Every remaining feature becomes continuous:
     * when {@code inputFloat} is {@code true} the underlying field itself is switched to the float
     * data type (including the margins of its intervals, if any); otherwise the feature is
     * converted to a float continuous feature.</p>
     *
     * @param numeric {@code false} to keep threshold features; {@code null} or {@code true} to
     *        convert them like any other feature
     * @param inputFloat {@code true} to retype the underlying field to float
     * @param schema the schema to transform
     * @return the transformed schema
     */
    static Schema toTreeModelSchema(final Boolean numeric, final Boolean inputFloat, Schema schema) {
        Function<Feature, Feature> converter = feature -> {
            if (feature instanceof BinaryFeature || feature instanceof MissingValueFeature) {
                return feature;
            }
            boolean keepThresholdFeatures = numeric != null && !numeric.booleanValue();
            if (keepThresholdFeatures && feature instanceof ThresholdFeature) {
                return feature;
            }
            boolean retypeField = inputFloat != null && inputFloat.booleanValue();
            if (!retypeField) {
                return feature.toContinuousFeature(DataType.FLOAT);
            }
            ContinuousFeature continuousFeature = feature.toContinuousFeature();
            if (continuousFeature.getDataType() == DataType.FLOAT) {
                return continuousFeature;
            }
            // Switch the field to float, then bring its interval margins to float precision too.
            Field field = continuousFeature.getField();
            field.setDataType(DataType.FLOAT);
            if (field instanceof HasContinuousDomain) {
                HasContinuousDomain domain = (HasContinuousDomain)field;
                if (domain.hasIntervals()) {
                    List<Interval> intervals = domain.getIntervals();
                    for (Interval interval : intervals) {
                        Number lower = interval.getLeftMargin();
                        Number upper = interval.getRightMargin();
                        if (lower != null) {
                            interval.setLeftMargin((Number)lower.floatValue());
                        }
                        if (upper != null) {
                            interval.setRightMargin((Number)upper.floatValue());
                        }
                    }
                }
            }
            return new ContinuousFeature(continuousFeature.getEncoder(), field);
        };
        return schema.toTransformedSchema((Function)converter);
    }

    /**
     * Transforms the schema for feature importance purposes.
     *
     * <p>Binary, missing value and threshold features are kept as they are; every other feature is
     * converted by {@code toContinuousFeature(Feature)}.</p>
     *
     * @param schema the schema to transform
     * @return the transformed schema
     */
    static Schema toTreeModelFeatureImportanceSchema(Schema schema) {
        Function<Feature, Feature> converter = feature -> {
            boolean keepAsIs = feature instanceof BinaryFeature
                || feature instanceof MissingValueFeature
                || feature instanceof ThresholdFeature;
            if (keepAsIs) {
                return feature;
            }
            return TreeUtil.toContinuousFeature(feature);
        };
        return schema.toTransformedSchema((Function)converter);
    }

    /**
     * Converts the feature to a float continuous feature first and then to a double one.
     *
     * @param feature the feature to convert
     * @return the resulting double continuous feature
     */
    private static ContinuousFeature toContinuousFeature(Feature feature) {
        ContinuousFeature floatFeature = feature.toContinuousFeature(DataType.FLOAT);
        return floatFeature.toContinuousFeature(DataType.DOUBLE);
    }

    /**
     * Returns the value reduced to float precision and widened back to double.
     *
     * @param value the value to convert
     * @return {@code value.floatValue()} as a double
     */
    private static double toSplitValue(Number value) {
        float narrowed = value.floatValue();
        return (double)narrowed;
    }

    /**
     * Copies one row out of a row-major matrix stored in a flat array.
     *
     * @param values the flat array; must hold exactly {@code rows * columns} elements
     * @param rows the number of rows
     * @param columns the number of columns
     * @param row the index of the row to copy
     * @return a new array of length {@code columns} holding the requested row
     * @throws IllegalArgumentException if the length of {@code values} is not {@code rows * columns}
     */
    private static double[] getRow(double[] values, int rows, int columns, int row) {
        int expectedLength = rows * columns;
        if (values.length != expectedLength) {
            throw new IllegalArgumentException("Expected " + expectedLength + " element(s), got " + values.length + " element(s)");
        }
        int offset = row * columns;
        double[] rowValues = new double[columns];
        System.arraycopy(values, offset, rowValues, 0, columns);
        return rowValues;
    }

    /**
     * Adds integer "apply" output(s) to the model through the estimator.
     *
     * <p>A {@link TreeModel} gets a single output. A {@link MiningModel} gets one output per
     * segment, identified by the segment ids; every segment must hold a tree model and carry an
     * id.</p>
     *
     * @param estimator the estimator that encodes the output
     * @param model a tree model or a mining model
     * @throws UnsupportedElementException if a segment has no id
     * @throws IllegalArgumentException if the model is neither a tree model nor a mining model
     */
    private static void encodeNodeId(Estimator estimator, Model model) {
        if (model instanceof TreeModel) {
            estimator.encodeApplyOutput(model, DataType.INTEGER);
            return;
        }
        if (!(model instanceof MiningModel)) {
            throw new IllegalArgumentException();
        }
        MiningModel miningModel = (MiningModel)model;
        Segmentation segmentation = miningModel.requireSegmentation();
        List<Segment> segments = segmentation.requireSegments();
        List<String> segmentIds = new ArrayList<>();
        for (Segment segment : segments) {
            // The result is unused; the call is kept because it can itself throw.
            TreeModel segmentModel = (TreeModel)segment.requireModel(TreeModel.class);
            String segmentId = segment.getId();
            if (segmentId == null) {
                throw new UnsupportedElementException((PMMLObject)segment);
            }
            segmentIds.add(segmentId);
        }
        estimator.encodeMultiApplyOutput((Model)miningModel, DataType.INTEGER, segmentIds);
    }
}

