/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.tree.CategoricalSplit;
import org.apache.spark.ml.tree.ContinuousSplit;
import org.apache.spark.ml.tree.DecisionTreeModel;
import org.apache.spark.ml.tree.InternalNode;
import org.apache.spark.ml.tree.LeafNode;
import org.apache.spark.ml.tree.Node;
import org.apache.spark.ml.tree.Split;
import org.apache.spark.ml.tree.TreeEnsembleModel;
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator;
import org.dmg.pmml.Array;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.SimpleSetPredicate;
import org.dmg.pmml.TreeModel;
import org.dmg.pmml.True;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.VisitorAction;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ListFeature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.sparkml.BooleanFeature;

/**
 * Static helpers that convert Spark ML decision trees (single trees and tree ensembles)
 * into PMML {@link TreeModel} elements.
 */
public class TreeModelUtil {
    /** Category vector {@code [1.0]}; recognizes the "value present" side of a binary categorical split. */
    private static final double[] TRUE = new double[]{1.0};
    /** Category vector {@code [0.0]}; recognizes the "value absent" side of a binary categorical split. */
    private static final double[] FALSE = new double[]{0.0};

    private TreeModelUtil() {
    }

    /**
     * Encodes a single Spark decision tree as a PMML tree model.
     *
     * @param model a {@link DecisionTreeRegressionModel} or a {@link DecisionTreeClassificationModel}
     * @param schema schema used to resolve split feature indices and (for classification) target categories
     * @return the encoded tree model
     * @throws IllegalArgumentException if {@code model} is neither a regression nor a classification tree
     */
    public static TreeModel encodeDecisionTree(DecisionTreeModel model, Schema schema) {
        Node node = model.rootNode();
        if (model instanceof DecisionTreeRegressionModel) {
            return encodeTreeModel(MiningFunctionType.REGRESSION, node, schema);
        }
        if (model instanceof DecisionTreeClassificationModel) {
            return encodeTreeModel(MiningFunctionType.CLASSIFICATION, node, schema);
        }
        throw new IllegalArgumentException("Unsupported decision tree model type: " + typeName(model));
    }

    /**
     * Encodes every member tree of a Spark tree ensemble, in the order returned by {@code trees()}.
     *
     * @param model the tree ensemble
     * @param schema the ensemble-level schema; member trees are encoded against
     *        {@code schema.toAnonymousSchema()}, which is computed once and shared
     * @return a new mutable list with one tree model per member tree
     */
    public static List<TreeModel> encodeDecisionTreeEnsemble(TreeEnsembleModel model, final Schema schema) {
        Schema segmentSchema = schema.toAnonymousSchema();

        List<TreeModel> treeModels = new ArrayList<TreeModel>();
        // NOTE(review): elements are iterated as Object and cast individually so this loop does not
        // depend on the exact array element type trees() erases to for the raw TreeEnsembleModel.
        for (Object tree : model.trees()) {
            treeModels.add(encodeDecisionTree((DecisionTreeModel) tree, segmentSchema));
        }
        return treeModels;
    }

    /**
     * Encodes a Spark tree (given by its root node) as a PMML tree model with binary splits.
     *
     * @param miningFunction {@code REGRESSION} or {@code CLASSIFICATION}
     * @param node the Spark root node
     * @param schema schema used to build the mining schema and to resolve features
     * @return the encoded tree model
     */
    public static TreeModel encodeTreeModel(MiningFunctionType miningFunction, Node node, Schema schema) {
        // The root always matches; the split predicates live on the child nodes.
        org.dmg.pmml.Node root = encodeNode(miningFunction, node, schema)
            .setPredicate(new True());

        return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema), root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
    }

    /**
     * Multiplies the score of every node in {@code treeModel} by {@code weight}, in place.
     * Does nothing when {@code weight} is one.
     *
     * <p>Every visited node score is parsed with {@link Double#parseDouble(String)}, so this is only
     * applicable to trees whose node scores are all numeric (as produced for regression trees here).
     *
     * @param treeModel the tree model to rescale
     * @param weight the multiplier
     */
    public static void scalePredictions(TreeModel treeModel, final double weight) {
        if (ValueUtil.isOne((Number) weight)) {
            return;
        }
        AbstractVisitor visitor = new AbstractVisitor() {

            @Override
            public VisitorAction visit(org.dmg.pmml.Node node) {
                double score = Double.parseDouble(node.getScore());
                node.setScore(ValueUtil.formatValue((Number) (score * weight)));
                return super.visit(node);
            }
        };
        visitor.applyTo(treeModel);
    }

    /**
     * Recursively encodes a Spark tree node (and its subtree) as a PMML node.
     * The returned node has no predicate; the caller is responsible for setting it.
     *
     * @throws IllegalArgumentException if {@code node} is neither an internal nor a leaf node
     */
    public static org.dmg.pmml.Node encodeNode(MiningFunctionType miningFunction, Node node, Schema schema) {
        if (node instanceof InternalNode) {
            return encodeInternalNode(miningFunction, (InternalNode) node, schema);
        }
        if (node instanceof LeafNode) {
            return encodeLeafNode(miningFunction, (LeafNode) node, schema);
        }
        throw new IllegalArgumentException("Unsupported tree node type: " + typeName(node));
    }

    /** Encodes an internal node plus both children; the split yields the left and right child predicates. */
    private static org.dmg.pmml.Node encodeInternalNode(MiningFunctionType miningFunction, InternalNode internalNode, Schema schema) {
        org.dmg.pmml.Node result = createNode(miningFunction, internalNode, schema);

        Predicate[] predicates = encodeSplit(internalNode.split(), schema);

        org.dmg.pmml.Node leftChild = encodeNode(miningFunction, internalNode.leftChild(), schema)
            .setPredicate(predicates[0]);
        org.dmg.pmml.Node rightChild = encodeNode(miningFunction, internalNode.rightChild(), schema)
            .setPredicate(predicates[1]);

        result.addNodes(leftChild, rightChild);
        return result;
    }

    /** Encodes a leaf node; it carries only the score information produced by {@link #createNode}. */
    private static org.dmg.pmml.Node encodeLeafNode(MiningFunctionType miningFunction, LeafNode leafNode, Schema schema) {
        return createNode(miningFunction, leafNode, schema);
    }

    /**
     * Creates a PMML node holding the prediction of {@code node}.
     *
     * <p>For regression the score is the formatted prediction. For classification the score is the
     * target category at the predicted index, the record count is the impurity calculator's count,
     * and one score distribution is added for each non-zero entry of the impurity stats.
     *
     * @throws IllegalArgumentException for classification when the schema has no target categories
     * @throws UnsupportedOperationException for any other mining function
     */
    private static org.dmg.pmml.Node createNode(MiningFunctionType miningFunction, Node node, Schema schema) {
        org.dmg.pmml.Node result = new org.dmg.pmml.Node();
        switch (miningFunction) {
            case REGRESSION: {
                result.setScore(ValueUtil.formatValue((Number) node.prediction()));
                break;
            }
            case CLASSIFICATION: {
                List<?> targetCategories = schema.getTargetCategories();
                if (targetCategories == null) {
                    throw new IllegalArgumentException("Classification tree requires a schema with target categories");
                }
                // The Spark prediction is the index of the predicted target category.
                int index = ValueUtil.asInt((Number) node.prediction());
                result.setScore((String) targetCategories.get(index));

                ImpurityCalculator impurityCalculator = node.impurityStats();
                result.setRecordCount(Double.valueOf(impurityCalculator.count()));

                // stats[i] is paired with target category i; zero entries are omitted.
                double[] stats = impurityCalculator.stats();
                for (int i = 0; i < stats.length; ++i) {
                    if (stats[i] == 0.0) {
                        continue;
                    }
                    result.addScoreDistributions(new ScoreDistribution((String) targetCategories.get(i), stats[i]));
                }
                break;
            }
            default: {
                throw new UnsupportedOperationException("Unsupported mining function: " + miningFunction);
            }
        }
        return result;
    }

    /**
     * Encodes a Spark split as a pair of predicates: element 0 for the left child, element 1 for the right.
     *
     * @throws IllegalArgumentException if the split is neither continuous nor categorical
     */
    private static Predicate[] encodeSplit(Split split, Schema schema) {
        if (split instanceof ContinuousSplit) {
            return encodeContinuousSplit((ContinuousSplit) split, schema);
        }
        if (split instanceof CategoricalSplit) {
            return encodeCategoricalSplit((CategoricalSplit) split, schema);
        }
        throw new IllegalArgumentException("Unsupported split type: " + typeName(split));
    }

    /**
     * Encodes a threshold split as {@code feature <= threshold} (left) and {@code feature > threshold} (right).
     * For a {@link BooleanFeature} the threshold must be {@code 0.0}, and the predicates instead test
     * equality with the feature's value at index 0 (left) and index 1 (right).
     *
     * @throws IllegalArgumentException for a boolean feature with a non-zero threshold
     */
    private static Predicate[] encodeContinuousSplit(ContinuousSplit continuousSplit, Schema schema) {
        ContinuousFeature feature = (ContinuousFeature) schema.getFeature(continuousSplit.featureIndex());
        double threshold = continuousSplit.threshold();

        if (feature instanceof BooleanFeature) {
            BooleanFeature booleanFeature = (BooleanFeature) feature;
            if (threshold != 0.0) {
                throw new IllegalArgumentException("Expected threshold 0.0 for boolean feature " + feature.getName() + ", got " + threshold);
            }
            SimplePredicate leftPredicate = new SimplePredicate(feature.getName(), SimplePredicate.Operator.EQUAL)
                .setValue(booleanFeature.getValue(0));
            SimplePredicate rightPredicate = new SimplePredicate(feature.getName(), SimplePredicate.Operator.EQUAL)
                .setValue(booleanFeature.getValue(1));
            return new Predicate[]{leftPredicate, rightPredicate};
        }

        String value = ValueUtil.formatValue((Number) threshold);
        SimplePredicate leftPredicate = new SimplePredicate(feature.getName(), SimplePredicate.Operator.LESS_OR_EQUAL)
            .setValue(value);
        SimplePredicate rightPredicate = new SimplePredicate(feature.getName(), SimplePredicate.Operator.GREATER_THAN)
            .setValue(value);
        return new Predicate[]{leftPredicate, rightPredicate};
    }

    /**
     * Encodes a category-membership split.
     *
     * <p>For a {@link ListFeature} the left/right category indices are mapped to the feature's values;
     * together they must cover every value. For a {@link BinaryFeature} one side must be {@code [1.0]}
     * and the other {@code [0.0]}; the {@code [1.0]} side tests equality with the binary value and
     * the other side tests inequality.
     *
     * @throws IllegalArgumentException for any other feature type or an unexpected category layout
     */
    private static Predicate[] encodeCategoricalSplit(CategoricalSplit categoricalSplit, Schema schema) {
        Feature feature = schema.getFeature(categoricalSplit.featureIndex());
        double[] leftCategories = categoricalSplit.leftCategories();
        double[] rightCategories = categoricalSplit.rightCategories();

        if (feature instanceof ListFeature) {
            ListFeature listFeature = (ListFeature) feature;
            List<?> values = listFeature.getValues();
            int splitCategoryCount = leftCategories.length + rightCategories.length;
            if (values.size() != splitCategoryCount) {
                throw new IllegalArgumentException("Expected " + values.size() + " categories for feature " + listFeature.getName() + ", got " + splitCategoryCount);
            }
            Predicate leftPredicate = createCategoricalPredicate(listFeature, leftCategories);
            Predicate rightPredicate = createCategoricalPredicate(listFeature, rightCategories);
            return new Predicate[]{leftPredicate, rightPredicate};
        }

        if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature) feature;
            SimplePredicate.Operator leftOperator;
            SimplePredicate.Operator rightOperator;
            if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
                leftOperator = SimplePredicate.Operator.EQUAL;
                rightOperator = SimplePredicate.Operator.NOT_EQUAL;
            } else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
                leftOperator = SimplePredicate.Operator.NOT_EQUAL;
                rightOperator = SimplePredicate.Operator.EQUAL;
            } else {
                throw new IllegalArgumentException("Unexpected binary split categories: left=" + Arrays.toString(leftCategories) + ", right=" + Arrays.toString(rightCategories));
            }
            String value = ValueUtil.formatValue((Object) binaryFeature.getValue());
            SimplePredicate leftPredicate = new SimplePredicate(binaryFeature.getName(), leftOperator).setValue(value);
            SimplePredicate rightPredicate = new SimplePredicate(binaryFeature.getName(), rightOperator).setValue(value);
            return new Predicate[]{leftPredicate, rightPredicate};
        }

        throw new IllegalArgumentException("Unsupported feature type for categorical split: " + typeName(feature));
    }

    /**
     * Builds a predicate matching the given category indices of {@code listFeature}.
     * A single category yields an {@code equal} simple predicate; otherwise an {@code isIn}
     * simple set predicate over an array of the category values.
     */
    private static Predicate createCategoricalPredicate(ListFeature listFeature, double[] categories) {
        List<String> values = new ArrayList<String>(categories.length);
        for (double category : categories) {
            int index = ValueUtil.asInt((Number) category);
            values.add(listFeature.getValue(index));
        }

        if (values.size() == 1) {
            return new SimplePredicate()
                .setField(listFeature.getName())
                .setOperator(SimplePredicate.Operator.EQUAL)
                .setValue(values.get(0));
        }

        // The values are Strings and need not be integer literals; declaring them as an INT
        // array is only valid when every value actually is one, so fall back to STRING otherwise.
        Array.Type arrayType = allIntegerLiterals(values) ? Array.Type.INT : Array.Type.STRING;
        Array array = new Array(arrayType, ValueUtil.formatArrayValue(values));

        return new SimpleSetPredicate()
            .setField(listFeature.getName())
            .setBooleanOperator(SimpleSetPredicate.BooleanOperator.IS_IN)
            .setArray(array);
    }

    /** Returns true when every value is an integer literal (vacuously true for an empty list). */
    private static boolean allIntegerLiterals(List<String> values) {
        for (String value : values) {
            if (!isIntegerLiteral(value)) {
                return false;
            }
        }
        return true;
    }

    /** Returns true when {@code value} is an optional sign followed by one or more ASCII digits. */
    private static boolean isIntegerLiteral(String value) {
        if (value == null || value.isEmpty()) {
            return false;
        }
        char first = value.charAt(0);
        int start = (first == '-' || first == '+') ? 1 : 0;
        if (start == value.length()) {
            return false;
        }
        for (int i = start; i < value.length(); i++) {
            char c = value.charAt(i);
            if (c < '0' || c > '9') {
                return false;
            }
        }
        return true;
    }

    /** Null-safe class name for error messages. */
    private static String typeName(Object object) {
        return (object != null) ? object.getClass().getName() : "null";
    }
}

