/*
 * Decompiled with CFR 0.152.
 */
package sklearn.tree;

import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import numpy.core.ScalarUtil;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.HasExtensions;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.CountingBranchNode;
import org.dmg.pmml.tree.CountingLeafNode;
import org.dmg.pmml.tree.DefaultNodeTransformer;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.NodeTransformer;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.visitors.AbstractExtender;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.visitors.TreeModelCompactor;
import org.jpmml.sklearn.visitors.TreeModelFlattener;
import sklearn.Estimator;
import sklearn.HasEstimatorEnsemble;
import sklearn.tree.HasTree;
import sklearn.tree.ScoreDistributionManager;
import sklearn.tree.Tree;

public class TreeModelUtil {
    // Not instantiable: this class only exposes static helper methods.
    private TreeModelUtil() {
    }

    /**
     * Post-processes an encoded model according to the tree-related options of the estimator.
     *
     * <p>The options are read through {@code estimator.getOption(...)}:
     * <ul>
     *   <li>{@code winner_id} (default {@code false}) - when true, an integer {@code nodeId}
     *       entity id output field is added to the model;</li>
     *   <li>{@code node_extensions} (default {@code null}) - a map from extension name to a map
     *       of per-node values, keyed by the integer node id;</li>
     *   <li>{@code node_id} (defaults to the value of {@code winner_id}) - when false, node ids
     *       are cleared after all other visitors have run;</li>
     *   <li>{@code compact} (defaults to true unless node ids or node extensions are requested)
     *       and {@code flat} (default {@code false}) - enable the corresponding tree visitors.</li>
     * </ul>
     *
     * <p>A {@code null} boolean option value is treated as "not enabled", consistently with the
     * {@code Boolean.TRUE.equals(...)} checks used throughout this method.
     *
     * @param estimator the estimator whose options are consulted
     * @param model the model to transform in place
     * @return the same {@code model} instance
     * @throws IllegalArgumentException if node ids or node extensions are requested together
     *         with the {@code compact} or {@code flat} option
     */
    public static <E extends Estimator, M extends Model> M transform(E estimator, M model) {
        Boolean winnerId = (Boolean)estimator.getOption("winner_id", Boolean.FALSE);
        Map<?, ?> nodeExtensions = (Map<?, ?>)estimator.getOption("node_extensions", null);
        Boolean nodeId = (Boolean)estimator.getOption("node_id", winnerId);
        // Node ids must stay stable whenever they are exposed or used as extension lookup keys.
        boolean fixed = nodeExtensions != null || Boolean.TRUE.equals(nodeId);
        Boolean compact = (Boolean)estimator.getOption("compact", fixed ? Boolean.FALSE : Boolean.TRUE);
        Boolean flat = (Boolean)estimator.getOption("flat", Boolean.FALSE);
        if (fixed && (Boolean.TRUE.equals(compact) || Boolean.TRUE.equals(flat))) {
            throw new IllegalArgumentException("Conflicting tree model options");
        }
        if (Boolean.TRUE.equals(winnerId)) {
            Output output = ModelUtil.ensureOutput(model);
            OutputField nodeIdField = ModelUtil.createEntityIdField(FieldName.create("nodeId")).setDataType(DataType.INTEGER);
            output.addOutputFields(nodeIdField);
        }
        List<Visitor> visitors = new ArrayList<>();
        if (Boolean.TRUE.equals(compact)) {
            visitors.add(new TreeModelCompactor());
        }
        if (Boolean.TRUE.equals(flat)) {
            visitors.add(new TreeModelFlattener());
        }
        if (nodeExtensions != null) {
            for (Map.Entry<?, ?> entry : nodeExtensions.entrySet()) {
                String name = (String)entry.getKey();
                Map<?, ?> values = (Map<?, ?>)entry.getValue();
                visitors.add(TreeModelUtil.createNodeExtender(name, values));
            }
        }
        if (Boolean.FALSE.equals(nodeId)) {
            // Registered last, so that the extenders above can still look nodes up by id.
            Visitor nodeIdCleaner = new AbstractVisitor(){

                @Override
                public VisitorAction visit(Node node) {
                    node.setId(null);
                    return super.visit(node);
                }
            };
            visitors.add(nodeIdCleaner);
        }
        for (Visitor visitor : visitors) {
            visitor.applyTo(model);
        }
        return model;
    }

    /**
     * Creates a visitor that attaches a named extension to every node whose integer id has a
     * non-null entry in {@code values}.
     *
     * @param name the extension name, passed to the {@code AbstractExtender} constructor
     * @param values per-node extension values, looked up by the node id converted to an integer
     * @return the extender visitor
     */
    private static Visitor createNodeExtender(String name, final Map<?, ?> values) {
        return new AbstractExtender(name){
            private final NodeTransformer nodeTransformer = DefaultNodeTransformer.INSTANCE;

            @Override
            public VisitorAction visit(TreeModel treeModel) {
                // The root has no parent visit, so it is made extensible here.
                treeModel.setNode(this.ensureExtensibility(treeModel.getNode()));
                return super.visit(treeModel);
            }

            @Override
            @SuppressWarnings({"rawtypes", "unchecked"})
            public VisitorAction visit(Node node) {
                if (node.hasNodes()) {
                    // Children are replaced in place before they are visited themselves.
                    ListIterator<Node> childIt = node.getNodes().listIterator();
                    while (childIt.hasNext()) {
                        childIt.set(this.ensureExtensibility(childIt.next()));
                    }
                }
                Object value = this.getValue(node);
                if (value != null) {
                    value = ScalarUtil.decode(value);
                    // A node with a value has already been passed through ensureExtensibility
                    // by its parent (or by visit(TreeModel) for the root).
                    this.addExtension((Node & HasExtensions)node, org.jpmml.model.ValueUtil.toString(value));
                }
                return super.visit(node);
            }

            /** Returns the node itself, or its complex-node form when it needs to carry an extension but cannot. */
            private Node ensureExtensibility(Node node) {
                if (node instanceof HasExtensions) {
                    return node;
                }
                if (this.getValue(node) != null) {
                    return this.nodeTransformer.toComplexNode(node);
                }
                return node;
            }

            /** Looks up the extension value registered for the id of the given node. */
            private Object getValue(Node node) {
                Integer id = ValueUtil.asInteger((Number)node.getId());
                return values.get(id);
            }
        };
    }

    /**
     * Encodes every member of an estimator ensemble as a tree model, using a fresh
     * {@code PredicateManager} and a fresh {@code ScoreDistributionManager}.
     *
     * @param estimator the ensemble estimator
     * @param miningFunction the mining function of the member tree models
     * @param schema the schema of the ensemble
     * @return the encoded tree models, as produced by the five-argument overload
     */
    public static <E extends Estimator, T extends Estimator> List<TreeModel> encodeTreeModelSegmentation(E estimator, MiningFunction miningFunction, Schema schema) {
        return TreeModelUtil.encodeTreeModelSegmentation(estimator, new PredicateManager(), new ScoreDistributionManager(), miningFunction, schema);
    }

    /**
     * Encodes every member of an estimator ensemble as a tree model.
     *
     * <p>Each member is encoded against the anonymous form of {@code schema}, further adapted
     * to the member's own data type via {@link #toTreeModelSchema(DataType, Schema)}. The
     * supplied managers are shared between all members.
     *
     * @param estimator the ensemble estimator; must implement {@code HasEstimatorEnsemble}
     * @param predicateManager the predicate factory shared by all members
     * @param scoreDistributionManager the score distribution factory shared by all members
     * @param miningFunction the mining function of the member tree models
     * @param schema the schema of the ensemble
     * @return one tree model per member estimator, in the order of {@code getEstimators()}
     * @throws ClassCastException if {@code estimator} does not implement {@code HasEstimatorEnsemble}
     */
    public static <E extends Estimator, T extends Estimator> List<TreeModel> encodeTreeModelSegmentation(E estimator, final PredicateManager predicateManager, final ScoreDistributionManager scoreDistributionManager, final MiningFunction miningFunction, Schema schema) {
        // The element type is not expressed by the method signature; the cast only restores
        // the static typing that the stream pipeline below needs.
        @SuppressWarnings({"rawtypes", "unchecked"})
        List<T> estimators = (List<T>)((HasEstimatorEnsemble)((Object)estimator)).getEstimators();
        final Schema segmentSchema = schema.toAnonymousSchema();
        Function<T, TreeModel> function = treeEstimator -> {
            Schema treeModelSchema = TreeModelUtil.toTreeModelSchema(treeEstimator.getDataType(), segmentSchema);
            return TreeModelUtil.encodeTreeModel(treeEstimator, predicateManager, scoreDistributionManager, miningFunction, treeModelSchema);
        };
        return estimators.stream().map(function).collect(Collectors.toList());
    }

    /**
     * Encodes a single tree estimator as a tree model, using a fresh {@code PredicateManager}
     * and a fresh {@code ScoreDistributionManager}.
     *
     * @param estimator the tree estimator
     * @param miningFunction the mining function of the tree model
     * @param schema the schema of the tree model
     * @return the encoded tree model, as produced by the five-argument overload
     */
    public static <E extends Estimator> TreeModel encodeTreeModel(E estimator, MiningFunction miningFunction, Schema schema) {
        return TreeModelUtil.encodeTreeModel(estimator, new PredicateManager(), new ScoreDistributionManager(), miningFunction, schema);
    }

    /**
     * Encodes a single tree estimator as a binary-split tree model.
     *
     * <p>The node hierarchy is built recursively from the parallel arrays of the estimator's
     * {@code Tree}, starting at index 0. The tree is passed to
     * {@code ClassDictUtil.clearContent} once the model has been built.
     *
     * @param estimator the tree estimator; must implement {@code HasTree}
     * @param predicateManager the predicate factory
     * @param scoreDistributionManager the score distribution factory
     * @param miningFunction the mining function of the tree model
     * @param schema the schema that supplies the label and the features
     * @return the encoded tree model
     */
    public static <E extends Estimator> TreeModel encodeTreeModel(E estimator, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, MiningFunction miningFunction, Schema schema) {
        HasTree treeHolder = (HasTree)((Object)estimator);
        Tree tree = treeHolder.getTree();

        int[] childrenLeft = tree.getChildrenLeft();
        int[] childrenRight = tree.getChildrenRight();
        int[] featureIndices = tree.getFeature();
        double[] splitThresholds = tree.getThreshold();
        double[] nodeValues = tree.getValues();

        Node rootNode = TreeModelUtil.encodeNode(True.INSTANCE, predicateManager, scoreDistributionManager, 0, childrenLeft, childrenRight, featureIndices, splitThresholds, nodeValues, miningFunction, schema);

        TreeModel result = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

        // NOTE(review): presumably releases the tree's data now that it has been encoded - confirm.
        ClassDictUtil.clearContent(tree);

        return result;
    }

    /**
     * Recursively encodes the tree node at {@code index}, together with its descendants.
     *
     * <p>A node whose entry in {@code features} is non-negative is a split node: its left and
     * right children are encoded with complementary predicates derived from the feature and
     * the threshold. Any other node is a leaf: for classification it carries the class with
     * the highest record count plus the full score distribution; for regression it carries
     * the entry of {@code values} at {@code index}.
     *
     * @param predicate the predicate that selects this node
     * @param predicateManager the predicate factory
     * @param scoreDistributionManager the score distribution factory
     * @param index the index of the node in the parallel arrays; also used as the node id
     * @param leftChildren the index of the left child, per node
     * @param rightChildren the index of the right child, per node
     * @param features the feature index of the split, per node; negative for leaves
     * @param thresholds the split threshold, per node
     * @param values the flattened node values (one row per node for classification)
     * @param miningFunction {@code CLASSIFICATION} or {@code REGRESSION}
     * @param schema the schema that supplies the label and the features
     * @return the encoded node
     * @throws IllegalArgumentException if the mining function is not supported, or if a binary
     *         feature is split at a threshold outside of the [0, 1] range
     */
    private static Node encodeNode(Predicate predicate, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, int index, int[] leftChildren, int[] rightChildren, int[] features, double[] thresholds, double[] values, MiningFunction miningFunction, Schema schema) {
        Integer id = index;
        int featureIndex = features[index];
        if (featureIndex >= 0) {
            Feature feature = schema.getFeature(featureIndex);
            double threshold = thresholds[index];
            Predicate leftPredicate;
            Predicate rightPredicate;
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature)feature;
                if (threshold < 0.0 || threshold > 1.0) {
                    throw new IllegalArgumentException("Expected a threshold in range [0, 1] for a binary feature, got " + threshold);
                }
                // The indicator is "off" on the left branch and "on" on the right branch.
                Object value = binaryFeature.getValue();
                leftPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.NOT_EQUAL, value);
                rightPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.EQUAL, value);
            } else {
                // NOTE(review): the FLOAT-then-DOUBLE conversion presumably mirrors how the
                // training library narrows inputs before comparing them against a
                // double-precision threshold - confirm before simplifying.
                ContinuousFeature continuousFeature = feature.toContinuousFeature(DataType.FLOAT).toContinuousFeature(DataType.DOUBLE);
                Object value = threshold;
                leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
                rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
            }
            int leftIndex = leftChildren[index];
            int rightIndex = rightChildren[index];
            Node leftChild = TreeModelUtil.encodeNode(leftPredicate, predicateManager, scoreDistributionManager, leftIndex, leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);
            Node rightChild = TreeModelUtil.encodeNode(rightPredicate, predicateManager, scoreDistributionManager, rightIndex, leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);
            // Declared as Node, because the two branch node classes share no narrower usable type.
            Node branch;
            if (MiningFunction.CLASSIFICATION.equals(miningFunction)) {
                branch = new ClassifierNode(null, predicate);
            } else if (MiningFunction.REGRESSION.equals(miningFunction)) {
                branch = new CountingBranchNode(null, predicate);
            } else {
                throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
            }
            branch.setId(id).addNodes(leftChild, rightChild);
            return branch;
        }
        if (MiningFunction.CLASSIFICATION.equals(miningFunction)) {
            CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
            double[] recordCounts = TreeModelUtil.getRow(values, leftChildren.length, categoricalLabel.size(), index);
            double totalRecordCount = 0.0;
            Object score = null;
            double scoreRecordCount = -Double.MAX_VALUE;
            for (int i = 0; i < recordCounts.length; ++i) {
                double recordCount = recordCounts[i];
                totalRecordCount += recordCount;
                // Strict comparison: on ties the first class in label order wins.
                if (recordCount > scoreRecordCount) {
                    score = categoricalLabel.getValue(i);
                    scoreRecordCount = recordCount;
                }
            }
            Node leaf = new ClassifierNode(score, predicate).setId(id).setRecordCount(totalRecordCount);
            List<ScoreDistribution> scoreDistributions = scoreDistributionManager.createScoreDistribution(categoricalLabel, recordCounts);
            leaf.getScoreDistributions().addAll(scoreDistributions);
            return leaf;
        }
        if (MiningFunction.REGRESSION.equals(miningFunction)) {
            double value = values[index];
            return new CountingLeafNode(value, predicate).setId(id);
        }
        throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
    }

    /**
     * Derives a schema in which every feature is either a binary feature (kept as is) or a
     * continuous feature of the given data type.
     *
     * @param dataType the data type for the continuous features
     * @param schema the schema to transform
     * @return the result of {@code schema.toTransformedSchema} with the conversion applied
     */
    public static Schema toTreeModelSchema(final DataType dataType, Schema schema) {
        Function<Feature, Feature> converter = feature -> {
            if (!(feature instanceof BinaryFeature)) {
                return feature.toContinuousFeature(dataType);
            }
            return feature;
        };
        return schema.toTransformedSchema(converter);
    }

    /**
     * Copies one row out of a row-major matrix that is stored as a flat array.
     *
     * @param values the flattened matrix
     * @param rows the number of rows
     * @param columns the number of columns
     * @param row the zero-based index of the row to copy
     * @return a new array of length {@code columns} holding the requested row
     * @throws IllegalArgumentException if {@code values.length} is not {@code rows * columns}
     */
    private static double[] getRow(double[] values, int rows, int columns, int row) {
        int expectedLength = rows * columns;
        if (values.length == expectedLength) {
            double[] rowValues = new double[columns];
            int offset = row * columns;
            System.arraycopy(values, offset, rowValues, 0, columns);
            return rowValues;
        }
        throw new IllegalArgumentException("Expected " + expectedLength + " element(s), got " + values.length + " element(s)");
    }
}

