/*
 * Decompiled with CFR 0.152.
 */
package sklearn.tree;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import sklearn.Estimator;
import sklearn.tree.HasTree;
import sklearn.tree.Tree;

public class TreeModelUtil {

    private TreeModelUtil() {
    }

    /**
     * Encodes every estimator as a {@link TreeModel}, sharing one freshly created
     * {@link PredicateManager} across all of the resulting models.
     *
     * @param estimators the estimators to encode; each is expected to implement {@link HasTree}
     * @param miningFunction the mining function of every resulting model
     * @param schema the schema of the enclosing (ensemble) model
     * @return one tree model per estimator, in the same order
     */
    public static <E extends Estimator> List<TreeModel> encodeTreeModelSegmentation(List<E> estimators, MiningFunction miningFunction, Schema schema) {
        PredicateManager predicateManager = new PredicateManager();

        return TreeModelUtil.encodeTreeModelSegmentation(estimators, predicateManager, miningFunction, schema);
    }

    /**
     * Encodes every estimator as a {@link TreeModel} against an anonymized copy of {@code schema}.
     *
     * <p>Each estimator gets its own tree-model schema, derived with {@link #toTreeModelSchema}
     * from the estimator's data type.
     *
     * @param estimators the estimators to encode; each is expected to implement {@link HasTree}
     * @param predicateManager the predicate manager shared by all resulting models
     * @param miningFunction the mining function of every resulting model
     * @param schema the schema of the enclosing (ensemble) model
     * @return a new mutable list with one tree model per estimator, in the same order
     */
    public static <E extends Estimator> List<TreeModel> encodeTreeModelSegmentation(List<E> estimators, final PredicateManager predicateManager, final MiningFunction miningFunction, Schema schema) {
        final Schema segmentSchema = schema.toAnonymousSchema();

        Function<E, TreeModel> function = new Function<E, TreeModel>() {

            @Override
            public TreeModel apply(E estimator) {
                Schema treeModelSchema = TreeModelUtil.toTreeModelSchema(estimator.getDataType(), segmentSchema);

                return TreeModelUtil.encodeTreeModel(estimator, predicateManager, miningFunction, treeModelSchema);
            }
        };

        // Lists.transform returns a lazy view; copying it encodes every tree exactly once, right here
        return new ArrayList<TreeModel>(Lists.transform(estimators, function));
    }

    /**
     * Encodes a single tree estimator as a {@link TreeModel}, using a freshly created {@link PredicateManager}.
     *
     * @param estimator the estimator to encode; expected to implement {@link HasTree}
     * @param miningFunction the mining function of the resulting model
     * @param schema the schema whose features are referenced by the tree's feature indices
     * @return the encoded tree model
     */
    public static <E extends Estimator> TreeModel encodeTreeModel(E estimator, MiningFunction miningFunction, Schema schema) {
        PredicateManager predicateManager = new PredicateManager();

        return TreeModelUtil.encodeTreeModel(estimator, predicateManager, miningFunction, schema);
    }

    /**
     * Encodes a single tree estimator as a binary-split {@link TreeModel}.
     *
     * <p>The root node always has id {@code "1"} and an always-true predicate. Every other node
     * gets the id {@code arrayIndex + 1}.
     *
     * @param estimator the estimator to encode; must implement {@link HasTree}, otherwise a
     *     {@link ClassCastException} is thrown
     * @param predicateManager the predicate manager used to create split predicates
     * @param miningFunction {@code CLASSIFICATION} or {@code REGRESSION}
     * @param schema the schema whose features are referenced by the tree's feature indices
     * @return the encoded tree model
     * @throws IllegalArgumentException if the mining function is unsupported, or if the tree
     *     arrays are inconsistent with the schema
     */
    public static <E extends Estimator> TreeModel encodeTreeModel(E estimator, PredicateManager predicateManager, MiningFunction miningFunction, Schema schema) {
        Tree tree = ((HasTree) estimator).getTree();

        int[] leftChildren = tree.getChildrenLeft();
        int[] rightChildren = tree.getChildrenRight();
        int[] features = tree.getFeature();
        double[] thresholds = tree.getThreshold();
        double[] values = tree.getValues();

        Node root = new Node()
            .setId("1")
            .setPredicate(new True());

        TreeModelUtil.encodeNode(root, predicateManager, 0, leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);

        Label label = schema.getLabel();

        TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(label), root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

        return treeModel;
    }

    /**
     * Recursively populates {@code node} from the tree node stored at array position {@code index}.
     *
     * <p>A negative entry in {@code features} marks a leaf; leaves are scored by {@link #encodeLeaf}.
     * A split node gets two children:
     * <ul>
     *   <li>The left child holds {@code feature <= threshold}, or {@code feature != value} for binary features.</li>
     *   <li>The right child holds {@code feature > threshold}, or {@code feature == value} for binary features.</li>
     * </ul>
     */
    private static void encodeNode(Node node, PredicateManager predicateManager, int index, int[] leftChildren, int[] rightChildren, int[] features, double[] thresholds, double[] values, MiningFunction miningFunction, Schema schema) {
        int featureIndex = features[index];

        if (featureIndex < 0) {
            TreeModelUtil.encodeLeaf(node, index, leftChildren.length, values, miningFunction, schema);

            return;
        }

        Feature feature = schema.getFeature(featureIndex);

        // Narrowed to float to match the FLOAT-typed continuous feature used in the predicate below
        float threshold = (float) thresholds[index];

        Predicate leftPredicate;
        Predicate rightPredicate;

        if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature) feature;

            // Written in this form so that a NaN threshold is rejected as well
            if (!(threshold >= 0.0f && threshold <= 1.0f)) {
                throw new IllegalArgumentException("Expected a split threshold in the range [0, 1] for a binary feature, got " + threshold);
            }

            String value = binaryFeature.getValue();

            leftPredicate = predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.NOT_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.EQUAL, value);
        } else {
            Feature continuousFeature = feature.toContinuousFeature(DataType.FLOAT);

            Number thresholdValue = Float.valueOf(threshold);
            String value = ValueUtil.formatValue(thresholdValue);

            leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
        }

        int leftIndex = leftChildren[index];
        int rightIndex = rightChildren[index];

        Node leftChild = new Node()
            .setId(String.valueOf(leftIndex + 1))
            .setPredicate(leftPredicate);

        TreeModelUtil.encodeNode(leftChild, predicateManager, leftIndex, leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);

        Node rightChild = new Node()
            .setId(String.valueOf(rightIndex + 1))
            .setPredicate(rightPredicate);

        TreeModelUtil.encodeNode(rightChild, predicateManager, rightIndex, leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);

        node.addNodes(leftChild, rightChild);
    }

    /**
     * Sets the score (and, for classification, the record count and score distributions) of a leaf node.
     *
     * @throws IllegalArgumentException if the mining function is neither classification nor regression
     */
    private static void encodeLeaf(Node node, int index, int nodeCount, double[] values, MiningFunction miningFunction, Schema schema) {

        if (miningFunction == MiningFunction.CLASSIFICATION) {
            TreeModelUtil.encodeClassificationLeaf(node, index, nodeCount, values, schema);
        } else if (miningFunction == MiningFunction.REGRESSION) {
            TreeModelUtil.encodeRegressionLeaf(node, index, nodeCount, values);
        } else {
            throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
        }
    }

    /**
     * Encodes a classification leaf.
     *
     * <p>The leaf's row of {@code values} holds one entry per class. The entries become score
     * distributions, their sum becomes the record count, and the class with the largest share
     * becomes the score. Ties are resolved in favour of the class with the lowest index.
     */
    private static void encodeClassificationLeaf(Node node, int index, int nodeCount, double[] values, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        int classCount = categoricalLabel.size();

        double[] classRecordCounts = TreeModelUtil.getRow(values, nodeCount, classCount, index);

        double recordCount = 0.0;

        for (double classRecordCount : classRecordCounts) {
            recordCount += classRecordCount;
        }

        node.setRecordCount(Double.valueOf(recordCount));

        String score = null;
        double bestProbability = 0.0;

        for (int i = 0; i < classCount; i++) {
            ScoreDistribution scoreDistribution = new ScoreDistribution(categoricalLabel.getValue(i), classRecordCounts[i]);

            node.addScoreDistributions(scoreDistribution);

            double probability = classRecordCounts[i] / recordCount;

            // Strictly-greater (Double.compare ordering) keeps the earliest class on ties
            if (i == 0 || Double.compare(bestProbability, probability) < 0) {
                score = scoreDistribution.getValue();
                bestProbability = probability;
            }
        }

        node.setScore(score);
    }

    /**
     * Encodes a regression leaf, whose score is the single value stored for it in {@code values}.
     *
     * @throws IllegalArgumentException if {@code values} does not hold exactly one value per node
     */
    private static void encodeRegressionLeaf(Node node, int index, int nodeCount, double[] values) {
        double[] leafValues = TreeModelUtil.getRow(values, nodeCount, 1, index);

        Number leafValue = Double.valueOf(leafValues[0]);

        node.setScore(ValueUtil.formatValue(leafValue));
    }

    /**
     * Derives the schema a single tree model is encoded against.
     *
     * <p>Binary features are passed through unchanged, because {@link #encodeNode} splits them on
     * (in)equality with their value. Every other feature is converted to a continuous feature of
     * {@code dataType}.
     *
     * @param dataType the data type of the non-binary features
     * @param schema the source schema
     * @return the transformed schema
     */
    public static Schema toTreeModelSchema(final DataType dataType, Schema schema) {
        Function<Feature, Feature> function = new Function<Feature, Feature>() {

            @Override
            public Feature apply(Feature feature) {

                if (feature instanceof BinaryFeature) {
                    return feature;
                }

                return feature.toContinuousFeature(dataType);
            }
        };

        return schema.toTransformedSchema(function);
    }

    /**
     * Copies row {@code row} out of a row-major {@code rows x columns} matrix stored in {@code values}.
     *
     * @throws IllegalArgumentException if {@code values} does not hold exactly {@code rows * columns} elements
     */
    private static double[] getRow(double[] values, int rows, int columns, int row) {

        if (values.length != rows * columns) {
            throw new IllegalArgumentException("Expected " + rows * columns + " element(s), got " + values.length + " element(s)");
        }

        double[] result = new double[columns];

        System.arraycopy(values, row * columns, result, 0, columns);

        return result;
    }
}

