/*
 * Decompiled with CFR 0.152.
 */
package sklearn.tree;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.Node;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.TreeModel;
import org.dmg.pmml.True;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import sklearn.Estimator;
import sklearn.tree.HasTree;
import sklearn.tree.Tree;

public class TreeModelUtil {
    private TreeModelUtil() {
    }

    public static <E extends Estimator> List<TreeModel> encodeTreeModelSegmentation(List<E> estimators, final MiningFunctionType miningFunction, final Schema schema) {
        Function function = new Function<E, TreeModel>(){
            private Schema segmentSchema;
            {
                this.segmentSchema = schema.toAnonymousSchema();
            }

            public TreeModel apply(E estimator) {
                return TreeModelUtil.encodeTreeModel(estimator, miningFunction, this.segmentSchema);
            }
        };
        return new ArrayList<TreeModel>(Lists.transform(estimators, (Function)function));
    }

    public static <E extends Estimator> TreeModel encodeTreeModel(E estimator, MiningFunctionType miningFunction, Schema schema) {
        Tree tree = ((HasTree)((Object)estimator)).getTree();
        int[] leftChildren = tree.getChildrenLeft();
        int[] rightChildren = tree.getChildrenRight();
        int[] features = tree.getFeature();
        double[] thresholds = tree.getThreshold();
        double[] values = tree.getValues();
        Node root = new Node().setId("1").setPredicate((Predicate)new True());
        TreeModelUtil.encodeNode(root, 0, leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);
        TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema((Schema)schema), root).setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
        return treeModel;
    }

    private static void encodeNode(Node node, int index, int[] leftChildren, int[] rightChildren, int[] features, double[] thresholds, double[] values, MiningFunctionType miningFunction, Schema schema) {
        int featureIndex = features[index];
        if (featureIndex >= 0) {
            SimplePredicate rightPredicate;
            SimplePredicate leftPredicate;
            Feature feature = schema.getFeature(featureIndex);
            float threshold = (float)thresholds[index];
            if (feature instanceof ContinuousFeature) {
                ContinuousFeature continuousFeature = (ContinuousFeature)feature;
                String value = ValueUtil.formatValue((Number)Float.valueOf(threshold));
                leftPredicate = new SimplePredicate(continuousFeature.getName(), SimplePredicate.Operator.LESS_OR_EQUAL).setValue(value);
                rightPredicate = new SimplePredicate(continuousFeature.getName(), SimplePredicate.Operator.GREATER_THAN).setValue(value);
            } else if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature)feature;
                if (threshold < 0.0f || threshold > 1.0f) {
                    throw new IllegalArgumentException();
                }
                leftPredicate = new SimplePredicate(binaryFeature.getName(), SimplePredicate.Operator.NOT_EQUAL).setValue(binaryFeature.getValue());
                rightPredicate = new SimplePredicate(binaryFeature.getName(), SimplePredicate.Operator.EQUAL).setValue(binaryFeature.getValue());
            } else {
                throw new IllegalArgumentException();
            }
            int leftIndex = leftChildren[index];
            int rightIndex = rightChildren[index];
            Node leftChild = new Node().setId(String.valueOf(leftIndex + 1)).setPredicate((Predicate)leftPredicate);
            TreeModelUtil.encodeNode(leftChild, leftIndex, leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);
            Node rightChild = new Node().setId(String.valueOf(rightIndex + 1)).setPredicate((Predicate)rightPredicate);
            TreeModelUtil.encodeNode(rightChild, rightIndex, leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);
            node.addNodes(new Node[]{leftChild, rightChild});
        } else if (MiningFunctionType.CLASSIFICATION.equals((Object)miningFunction)) {
            List targetCategories = schema.getTargetCategories();
            double[] scoreRecordCounts = TreeModelUtil.getRow(values, leftChildren.length, targetCategories.size(), index);
            double recordCount = 0.0;
            for (double scoreRecordCount : scoreRecordCounts) {
                recordCount += scoreRecordCount;
            }
            node.setRecordCount(Double.valueOf(recordCount));
            String score = null;
            Double probability = null;
            for (int i = 0; i < targetCategories.size(); ++i) {
                String targetCategory = (String)targetCategories.get(i);
                ScoreDistribution scoreDistribution = new ScoreDistribution(targetCategory, scoreRecordCounts[i]);
                node.addScoreDistributions(new ScoreDistribution[]{scoreDistribution});
                double scoreProbability = scoreRecordCounts[i] / recordCount;
                if (probability != null && probability.compareTo(scoreProbability) >= 0) continue;
                score = scoreDistribution.getValue();
                probability = scoreProbability;
            }
            node.setScore(score);
        } else if (MiningFunctionType.REGRESSION.equals((Object)miningFunction)) {
            String score = ValueUtil.formatValue((Number)values[index]);
            node.setScore(score);
        } else {
            throw new IllegalArgumentException();
        }
    }

    private static double[] getRow(double[] values, int rows, int columns, int row) {
        if (values.length != rows * columns) {
            throw new IllegalArgumentException();
        }
        double[] result = new double[columns];
        System.arraycopy(values, row * columns, result, 0, columns);
        return result;
    }
}

