/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.tree.CategoricalSplit;
import org.apache.spark.ml.tree.ContinuousSplit;
import org.apache.spark.ml.tree.DecisionTreeModel;
import org.apache.spark.ml.tree.InternalNode;
import org.apache.spark.ml.tree.Node;
import org.apache.spark.ml.tree.Split;
import org.apache.spark.ml.tree.TreeEnsembleModel;
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.ComplexNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.sparkml.visitors.TreeModelCompactor;

/**
 * Utility methods that convert Spark ML decision tree models (single trees and the member trees
 * of a {@link TreeEnsembleModel}) into PMML {@link TreeModel} elements.
 *
 * <p>Regression trees get a plain score on each leaf. Classification trees get a score plus a
 * record count and one {@link ScoreDistribution} per entry of the leaf's impurity statistics.
 */
public class TreeModelUtil {
    /** Category array that marks the split side containing only category index 1 of a binary feature. */
    private static final double[] TRUE = new double[]{1.0};
    /** Category array that marks the split side containing only category index 0 of a binary feature. */
    private static final double[] FALSE = new double[]{0.0};

    private TreeModelUtil() {
    }

    /**
     * Encodes the converter's single decision tree using a fresh {@link PredicateManager}.
     *
     * @param converter converter whose transformer is the Spark decision tree model
     * @param schema the feature/label schema of the model
     * @return the encoded PMML tree model
     * @throws IllegalArgumentException if the model or one of its nodes, splits or features is of an unsupported kind
     */
    public static <C extends ModelConverter<? extends M>, M extends Model<M>> TreeModel encodeDecisionTree(C converter, Schema schema) {
        PredicateManager predicateManager = new PredicateManager();
        return TreeModelUtil.encodeDecisionTree(converter, predicateManager, schema);
    }

    /**
     * Encodes the converter's single decision tree, sharing predicates through the given manager.
     *
     * @param converter converter whose transformer is the Spark decision tree model
     * @param predicateManager factory used to create split predicates
     * @param schema the feature/label schema of the model
     * @return the encoded PMML tree model
     * @throws IllegalArgumentException if the model or one of its nodes, splits or features is of an unsupported kind
     */
    public static <C extends ModelConverter<? extends M>, M extends Model<M>> TreeModel encodeDecisionTree(C converter, PredicateManager predicateManager, Schema schema) {
        return TreeModelUtil.encodeDecisionTree(converter, (Model)converter.getTransformer(), predicateManager, schema);
    }

    /**
     * Encodes every member tree of the converter's tree ensemble using a fresh {@link PredicateManager}.
     *
     * @param converter converter whose transformer is a Spark {@link TreeEnsembleModel}
     * @param schema the feature/label schema of the ensemble
     * @return one PMML tree model per member tree, in ensemble order
     */
    public static <C extends ModelConverter<? extends M>, M extends Model<M>, T extends Model<T>> List<TreeModel> encodeDecisionTreeEnsemble(C converter, Schema schema) {
        PredicateManager predicateManager = new PredicateManager();
        return TreeModelUtil.encodeDecisionTreeEnsemble(converter, predicateManager, schema);
    }

    /**
     * Encodes every member tree of the converter's tree ensemble.
     *
     * <p>All member trees share {@code predicateManager}. Each one is encoded against
     * {@code schema.toAnonymousSchema()} rather than {@code schema} itself. The type parameter
     * {@code T} is unused and is kept only for signature compatibility.
     *
     * @param converter converter whose transformer is a Spark {@link TreeEnsembleModel}
     * @param predicateManager factory used to create split predicates
     * @param schema the feature/label schema of the ensemble
     * @return one PMML tree model per member tree, in ensemble order
     * @throws ClassCastException if the transformer is not a {@link TreeEnsembleModel}
     */
    public static <C extends ModelConverter<? extends M>, M extends Model<M>, T extends Model<T>> List<TreeModel> encodeDecisionTreeEnsemble(C converter, PredicateManager predicateManager, Schema schema) {
        Model model = (Model)converter.getTransformer();
        Schema segmentSchema = schema.toAnonymousSchema();
        Model[] trees = (Model[])((TreeEnsembleModel)model).trees();
        List<TreeModel> treeModels = new ArrayList<TreeModel>(trees.length);
        for (Model tree : trees) {
            treeModels.add(TreeModelUtil.encodeDecisionTree(converter, tree, predicateManager, segmentSchema));
        }
        return treeModels;
    }

    /**
     * Encodes one Spark decision tree, picking the leaf encoding from the model's concrete type.
     * The result is optionally compacted with {@link TreeModelCompactor}.
     *
     * <p>Compaction runs when the converter option {@code "compact"} is {@code Boolean.TRUE}.
     * That is also the value used when the option is not set.
     *
     * @throws IllegalArgumentException if the model is neither a regression nor a classification tree
     * @throws ClassCastException for a classification tree whose schema label is not a {@link CategoricalLabel}
     */
    private static <M extends Model<M>> TreeModel encodeDecisionTree(ModelConverter<?> converter, M model, PredicateManager predicateManager, final Schema schema) {
        TreeModel treeModel;
        if (model instanceof DecisionTreeRegressionModel) {
            // Regression leaves carry only the predicted value.
            ScoreEncoder scoreEncoder = (node, leafNode) -> {
                node.setScore((Object)leafNode.prediction());
                return node;
            };
            treeModel = TreeModelUtil.encodeTreeModel(model, predicateManager, MiningFunction.REGRESSION, scoreEncoder, schema);
        } else if (model instanceof DecisionTreeClassificationModel) {
            final CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
            // Classification leaves are replaced by a ComplexNode (same predicate) so that they can
            // carry a record count and score distributions in addition to the score.
            ScoreEncoder scoreEncoder = (node, leafNode) -> {
                org.dmg.pmml.tree.Node result = new ComplexNode().setPredicate(node.getPredicate());
                // The leaf prediction is a class index into the categorical label's values.
                int index = ValueUtil.asInt((Number)leafNode.prediction());
                result.setScore((Object)categoricalLabel.getValue(index));
                ImpurityCalculator impurityCalculator = leafNode.impurityStats();
                result.setRecordCount(Double.valueOf(impurityCalculator.count()));
                List<ScoreDistribution> scoreDistributions = result.getScoreDistributions();
                double[] stats = impurityCalculator.stats();
                // One distribution entry per stats slot. Slot i is paired with label value i.
                // NOTE(review): stats[i] is presumably the per-class (weighted) count; confirm
                // against the impurity type in use.
                for (int i = 0; i < stats.length; ++i) {
                    scoreDistributions.add(new ScoreDistribution(categoricalLabel.getValue(i), stats[i]));
                }
                return result;
            };
            treeModel = TreeModelUtil.encodeTreeModel(model, predicateManager, MiningFunction.CLASSIFICATION, scoreEncoder, schema);
        } else {
            throw new IllegalArgumentException("Expected a DecisionTreeRegressionModel or a DecisionTreeClassificationModel, got " + TreeModelUtil.describe(model));
        }
        Boolean compact = (Boolean)converter.getOption("compact", Boolean.TRUE);
        if (Boolean.TRUE.equals(compact)) {
            TreeModelCompactor visitor = new TreeModelCompactor();
            visitor.applyTo((Visitable)treeModel);
        }
        return treeModel;
    }

    /**
     * Builds a binary-split PMML {@link TreeModel} rooted at a {@link True}-predicated node that
     * mirrors the Spark tree's root node.
     */
    private static <M extends Model<M>> TreeModel encodeTreeModel(M model, PredicateManager predicateManager, MiningFunction miningFunction, ScoreEncoder scoreEncoder, Schema schema) {
        org.dmg.pmml.tree.Node root = TreeModelUtil.encodeNode((Predicate)new True(), ((DecisionTreeModel)model).rootNode(), predicateManager, new CategoryManager(), scoreEncoder, schema);
        TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema((Label)schema.getLabel()), root).setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
        return treeModel;
    }

    /**
     * Recursively converts a Spark tree node into a PMML node guarded by {@code predicate}.
     *
     * <p>Leaves are handed to {@code scoreEncoder}. Internal nodes become a {@link BranchNode}
     * with exactly two children (left, then right), whose predicates are derived from the split:
     * <ul>
     *   <li>continuous split on a {@link BooleanFeature}: the threshold must be 0.5; left is
     *       "equal to value 0", right is "equal to value 1";</li>
     *   <li>other continuous split: left is {@code <= threshold}, right is {@code > threshold};</li>
     *   <li>categorical split on a {@link BinaryFeature}: the categories must be {1}/{0} or {0}/{1};
     *       they map to EQUAL / NOT_EQUAL against the feature's value;</li>
     *   <li>categorical split on a {@link CategoricalFeature}: each side becomes a set predicate
     *       over the feature values that are still admissible on that path. Admissibility is
     *       tracked per side through {@link CategoryManager#fork}.</li>
     * </ul>
     *
     * @throws IllegalArgumentException for unsupported node, split or feature kinds, or for
     *         split parameters that are inconsistent with the feature
     */
    private static org.dmg.pmml.tree.Node encodeNode(Predicate predicate, Node sparkNode, PredicateManager predicateManager, CategoryManager categoryManager, ScoreEncoder scoreEncoder, Schema schema) {
        if (sparkNode instanceof org.apache.spark.ml.tree.LeafNode) {
            org.apache.spark.ml.tree.LeafNode leafNode = (org.apache.spark.ml.tree.LeafNode)sparkNode;
            org.dmg.pmml.tree.Node result = new LeafNode().setPredicate(predicate);
            return scoreEncoder.encode(result, leafNode);
        }
        if (!(sparkNode instanceof InternalNode)) {
            throw new IllegalArgumentException("Expected a LeafNode or an InternalNode, got " + TreeModelUtil.describe(sparkNode));
        }
        InternalNode internalNode = (InternalNode)sparkNode;
        Predicate leftPredicate;
        Predicate rightPredicate;
        // Only categorical splits on CategoricalFeatures narrow the admissible categories.
        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;
        Split split = internalNode.split();
        Feature feature = schema.getFeature(split.featureIndex());
        if (split instanceof ContinuousSplit) {
            ContinuousSplit continuousSplit = (ContinuousSplit)split;
            double threshold = continuousSplit.threshold();
            if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature)feature;
                if (threshold != 0.5) {
                    throw new IllegalArgumentException("Invalid split threshold value " + threshold + " for a boolean feature");
                }
                leftPredicate = predicateManager.createSimplePredicate((Feature)booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
                rightPredicate = predicateManager.createSimplePredicate((Feature)booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                String value = ValueUtil.formatValue((Number)threshold);
                leftPredicate = predicateManager.createSimplePredicate((Feature)continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
                rightPredicate = predicateManager.createSimplePredicate((Feature)continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
            }
        } else if (split instanceof CategoricalSplit) {
            CategoricalSplit categoricalSplit = (CategoricalSplit)split;
            double[] leftCategories = categoricalSplit.leftCategories();
            double[] rightCategories = categoricalSplit.rightCategories();
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature)feature;
                SimplePredicate.Operator leftOperator;
                SimplePredicate.Operator rightOperator;
                if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.EQUAL;
                    rightOperator = SimplePredicate.Operator.NOT_EQUAL;
                } else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.NOT_EQUAL;
                    rightOperator = SimplePredicate.Operator.EQUAL;
                } else {
                    throw new IllegalArgumentException("Invalid split categories " + Arrays.toString(leftCategories) + " / " + Arrays.toString(rightCategories) + " for a binary feature");
                }
                String value = ValueUtil.formatValue((Object)binaryFeature.getValue());
                leftPredicate = predicateManager.createSimplePredicate((Feature)binaryFeature, leftOperator, value);
                rightPredicate = predicateManager.createSimplePredicate((Feature)binaryFeature, rightOperator, value);
            } else if (feature instanceof CategoricalFeature) {
                CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
                FieldName name = categoricalFeature.getName();
                // NOTE: kept raw because the element type of getValues()/getValueFilter() differs
                // between jpmml-converter versions.
                List values = categoricalFeature.getValues();
                int splitCategoryCount = leftCategories.length + rightCategories.length;
                if (values.size() != splitCategoryCount) {
                    throw new IllegalArgumentException("Categorical split covers " + splitCategoryCount + " categories, but feature " + name + " has " + values.size() + " values");
                }
                java.util.function.Predicate valueFilter = categoryManager.getValueFilter(name);
                List<String> leftValues = TreeModelUtil.selectValues(values, leftCategories, valueFilter);
                List<String> rightValues = TreeModelUtil.selectValues(values, rightCategories, valueFilter);
                leftCategoryManager = categoryManager.fork(name, leftValues);
                rightCategoryManager = categoryManager.fork(name, rightValues);
                leftPredicate = predicateManager.createSimpleSetPredicate((Feature)categoricalFeature, leftValues);
                rightPredicate = predicateManager.createSimpleSetPredicate((Feature)categoricalFeature, rightValues);
            } else {
                throw new IllegalArgumentException("Expected a BinaryFeature or a CategoricalFeature for a categorical split, got " + TreeModelUtil.describe(feature));
            }
        } else {
            throw new IllegalArgumentException("Expected a ContinuousSplit or a CategoricalSplit, got " + TreeModelUtil.describe(split));
        }
        org.dmg.pmml.tree.Node leftChild = TreeModelUtil.encodeNode(leftPredicate, internalNode.leftChild(), predicateManager, leftCategoryManager, scoreEncoder, schema);
        org.dmg.pmml.tree.Node rightChild = TreeModelUtil.encodeNode(rightPredicate, internalNode.rightChild(), predicateManager, rightCategoryManager, scoreEncoder, schema);
        return new BranchNode().setPredicate(predicate).addNodes(new org.dmg.pmml.tree.Node[]{leftChild, rightChild});
    }

    /**
     * Maps category indices to feature values and keeps only those accepted by {@code valueFilter}.
     *
     * @param values the feature's values, indexed by category index
     * @param categories category indices (whole numbers stored as doubles)
     * @param valueFilter admissibility test for the current tree path
     * @return the accepted values, in the order of {@code categories}
     */
    private static List<String> selectValues(List<String> values, double[] categories, java.util.function.Predicate<String> valueFilter) {
        // Fast path for the common single-category side; avoids allocating an ArrayList.
        if (categories.length == 1) {
            int index = ValueUtil.asInt((Number)categories[0]);
            String value = values.get(index);
            if (valueFilter.test(value)) {
                return Collections.singletonList(value);
            }
            return Collections.emptyList();
        }
        List<String> result = new ArrayList<String>(categories.length);
        for (double category : categories) {
            int index = ValueUtil.asInt((Number)category);
            String value = values.get(index);
            if (valueFilter.test(value)) {
                result.add(value);
            }
        }
        return result;
    }

    /** Returns the runtime class name of {@code object} for error messages, or {@code "null"}. */
    private static String describe(Object object) {
        return (object != null) ? object.getClass().getName() : "null";
    }

    /** Turns a PMML node that mirrors a Spark leaf into its final scored form. */
    @FunctionalInterface
    static interface ScoreEncoder {
        /**
         * @param node the PMML node created for the leaf, already carrying its predicate
         * @param leafNode the Spark leaf being encoded
         * @return the node to place in the tree (may be {@code node} itself or a replacement)
         */
        public org.dmg.pmml.tree.Node encode(org.dmg.pmml.tree.Node node, org.apache.spark.ml.tree.LeafNode leafNode);
    }
}

