/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml.tree;

import chaid.Column;
import chaid.Split;
import com.google.common.math.DoubleMath;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.dmg.pmml.CompoundPredicate;
import org.dmg.pmml.DataType;
import org.dmg.pmml.False;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreFrequency;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.CountingBranchNode;
import org.dmg.pmml.tree.CountingLeafNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.python.ClassDictUtil;
import treelib.Node;
import treelib.Tree;

public class CHAIDUtil {

    /**
     * Orders treelib nodes by the number of entries in their CHAID tag's {@code getIndices()} list.
     */
    private static final Comparator<Node> INDEX_COUNT_COMPARATOR = Comparator.comparingInt(CHAIDUtil::countIndices);

    private CHAIDUtil() {
    }

    /**
     * Converts a CHAID tree into a PMML {@link TreeModel}.
     *
     * <p>The tree is encoded starting from {@code tree.selectRoot()} under an always-true predicate.
     * The model declares a single integer entity-id output field named {@code nodeId}.
     *
     * @param miningFunction the mining function of the resulting model
     * @param tree the treelib tree; every node is expected to be tagged with a {@code chaid.Node}
     * @param schema the schema supplying the label and the split features
     * @return the encoded tree model
     */
    public static TreeModel encodeModel(MiningFunction miningFunction, Tree tree, Schema schema) {
        org.dmg.pmml.tree.Node root = encodeNode(True.INSTANCE, tree.selectRoot(), tree, new PredicateManager(), schema);

        Output output = new Output()
            .addOutputFields(ModelUtil.createEntityIdField("nodeId", DataType.INTEGER));

        return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
            .setOutput(output);
    }

    /**
     * Recursively encodes one treelib node, together with all of its successors, as a PMML tree node.
     *
     * <p>Node classes are chosen as follows.
     * <ul>
     *   <li>A node with successors becomes a {@link ClassifierNode} for a categorical label, and a
     *       {@link CountingBranchNode} otherwise.</li>
     *   <li>A node without successors becomes a {@link ClassifierNode} or a {@link CountingLeafNode},
     *       respectively.</li>
     * </ul>
     *
     * <p>The split column of a branch node must be a {@link CategoricalFeature}. A split index of
     * {@code -1} is encoded as an "is missing" test rather than as a category value.
     *
     * <p>Categories that no successor's split references are appended to the value set of the
     * successor with the largest index count. On ties, the last such successor is chosen.
     *
     * @param predicate the predicate that selects this node within its parent
     * @param node the treelib node; its tag must be a {@code chaid.Node}
     * @param tree the tree that {@code node} belongs to
     * @param predicateManager creates all predicates; the same instance is passed to every recursive call
     * @param schema the schema supplying the label and the features
     * @return the encoded node, with id, record count and score set
     * @throws IllegalArgumentException if a split value is not a category of the split feature, or
     *         the label is neither continuous nor categorical
     */
    private static org.dmg.pmml.tree.Node encodeNode(Predicate predicate, Node node, Tree tree, PredicateManager predicateManager, Schema schema) {
        Label label = schema.getLabel();

        chaid.Node tag = (chaid.Node) node.getTag(chaid.Node.class);
        // Kept as List<?> and cast per element, since the element type is not visible here
        List<?> successors = node.selectSuccessors(tree);

        Column depV = tag.getDepV();
        List<Integer> indices = tag.getIndices();
        Split split = tag.getSplit();

        List<? extends Number> depVArr = depV.getArr();
        ClassDictUtil.checkSize(depVArr, indices);

        Integer columnId = split.getColumnId();
        List<List<Integer>> splits = split.getSplits();
        List<List<?>> splitMap = split.getSplitMap();
        ClassDictUtil.checkSize(successors, splits, splitMap);

        org.dmg.pmml.tree.Node result;

        if (!successors.isEmpty()) {
            CategoricalFeature categoricalFeature = (CategoricalFeature) schema.getFeature(columnId.intValue());
            List<?> categories = categoricalFeature.getValues();

            result = (label instanceof CategoricalLabel) ? new ClassifierNode(null, predicate) : new CountingBranchNode(null, predicate);

            // Collect the categories that no (non-missing) split entry refers to
            Set<Object> unusedValues = new LinkedHashSet<>(categories);
            for (int i = 0; i < successors.size(); i++) {
                List<Integer> splitIndices = splits.get(i);
                List<?> splitValues = splitMap.get(i);
                ClassDictUtil.checkSize(splitIndices, splitValues);

                for (int j = 0; j < splitIndices.size(); j++) {
                    if (splitIndices.get(j) == -1) {
                        continue;
                    }
                    removeCategory(unusedValues, splitValues.get(j));
                }
            }

            // Leftover categories (if any) are routed to the successor with the largest index count
            Node maxSuccessor = unusedValues.isEmpty() ? null : selectMaxSuccessor(successors);

            for (int i = 0; i < successors.size(); i++) {
                Node successor = (Node) successors.get(i);
                List<Integer> splitIndices = splits.get(i);
                List<?> splitValues = splitMap.get(i);

                List<Object> values = new ArrayList<>();
                boolean withMissing = false;
                for (int j = 0; j < splitIndices.size(); j++) {
                    if (splitIndices.get(j) == -1) {
                        withMissing = true;
                        continue;
                    }
                    // Use the feature's own category object so the predicate matches the feature's value type
                    values.add(selectCategory(categories, splitValues.get(j)));
                }

                if (Objects.equals(successor, maxSuccessor)) {
                    values.addAll(unusedValues);
                }

                Predicate successorPredicate = createSuccessorPredicate(categoricalFeature, values, withMissing, predicateManager);
                result.addNodes(encodeNode(successorPredicate, successor, tree, predicateManager, schema));
            }
        } else {
            result = (label instanceof CategoricalLabel) ? new ClassifierNode(null, predicate) : new CountingLeafNode(null, predicate);
        }

        result.setId(node.getIdentifier()).setRecordCount(depVArr.size());

        encodeScore(result, label, depVArr);

        return result;
    }

    /**
     * Returns the number of entries in the CHAID tag's {@code getIndices()} list of a treelib node.
     */
    private static int countIndices(Node node) {
        chaid.Node tag = (chaid.Node) node.getTag(chaid.Node.class);

        return tag.getIndices().size();
    }

    /**
     * Returns the successor with the largest index count.
     *
     * <p>On ties the LAST such successor wins. {@code Stream.max} would keep the first one, which is
     * why this is an explicit loop.
     */
    private static Node selectMaxSuccessor(List<?> successors) {
        Node maxSuccessor = null;

        for (Object object : successors) {
            Node successor = (Node) object;

            if (maxSuccessor == null || INDEX_COUNT_COMPARATOR.compare(successor, maxSuccessor) >= 0) {
                maxSuccessor = successor;
            }
        }

        return maxSuccessor;
    }

    /**
     * Builds the predicate that routes records into one successor.
     *
     * <ul>
     *   <li>With no values and no missing bucket, the result is {@link False#INSTANCE}.</li>
     *   <li>With no values but a missing bucket, the result is only the "is missing" test.</li>
     *   <li>Otherwise the result is a value-set predicate. When the split also covers missing values,
     *       it is wrapped in a SURROGATE compound predicate together with the "is missing" test.</li>
     * </ul>
     */
    private static Predicate createSuccessorPredicate(CategoricalFeature categoricalFeature, List<Object> values, boolean withMissing, PredicateManager predicateManager) {
        if (values.isEmpty()) {
            if (withMissing) {
                return predicateManager.createSimplePredicate(categoricalFeature, SimplePredicate.Operator.IS_MISSING, null);
            }

            return False.INSTANCE;
        }

        Predicate valuePredicate = predicateManager.createPredicate(categoricalFeature, values);
        if (!withMissing) {
            return valuePredicate;
        }

        Predicate missingPredicate = predicateManager.createSimplePredicate(categoricalFeature, SimplePredicate.Operator.IS_MISSING, null);

        return predicateManager.createCompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, valuePredicate, missingPredicate);
    }

    /**
     * Sets the score, and for a categorical label also the score distribution, of an encoded node.
     *
     * <p>For a continuous label the score is the arithmetic mean of the dependent values.
     *
     * <p>For a categorical label, each distinct dependent value is used as an index into the label's
     * values.
     * <ul>
     *   <li>One {@link ScoreFrequency} is added per distinct index.</li>
     *   <li>The score is the most frequent value. On ties, the first one met in the count map's
     *       iteration order is used.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the label is neither continuous nor categorical
     */
    private static void encodeScore(org.dmg.pmml.tree.Node result, Label label, List<? extends Number> depVArr) {
        if (label instanceof ContinuousLabel) {
            Double score = DoubleMath.mean(depVArr);

            result.setScore(score);
        } else if (label instanceof CategoricalLabel) {
            CategoricalLabel categoricalLabel = (CategoricalLabel) label;

            Map<Number, Long> countMap = depVArr.stream()
                .collect(Collectors.groupingBy(Function.<Number>identity(), Collectors.counting()));

            List<? super ScoreFrequency> scoreDistributions = result.getScoreDistributions();

            Long maxCount = null;
            for (Map.Entry<Number, Long> entry : countMap.entrySet()) {
                // Dependent values may arrive boxed as Integer, Long, ...; go through Number rather than casting to Integer
                Object value = categoricalLabel.getValue(entry.getKey().intValue());
                Long count = entry.getValue();

                if (maxCount == null || maxCount.compareTo(count) < 0) {
                    maxCount = count;

                    result.setScore(value);
                }

                scoreDistributions.add(new ScoreFrequency(value, count));
            }
        } else {
            throw new IllegalArgumentException("Expected a continuous or categorical label, got " + label);
        }
    }

    /**
     * Removes the first element of {@code values} that matches {@code splitValue}.
     * Matching uses {@link #equals(Object, Object)}.
     *
     * @throws IllegalArgumentException if no element matches
     */
    private static void removeCategory(Collection<?> values, Object splitValue) {
        for (Iterator<?> it = values.iterator(); it.hasNext(); ) {
            if (CHAIDUtil.equals(it.next(), splitValue)) {
                it.remove();

                return;
            }
        }

        throw new IllegalArgumentException("Split value " + splitValue + " is not among the remaining category values " + values);
    }

    /**
     * Returns the first element of {@code values} that matches {@code splitValue}.
     * Matching uses {@link #equals(Object, Object)}.
     *
     * @throws IllegalArgumentException if no element matches
     */
    private static Object selectCategory(Collection<?> values, Object splitValue) {
        for (Object value : values) {
            if (CHAIDUtil.equals(value, splitValue)) {
                return value;
            }
        }

        throw new IllegalArgumentException("Split value " + splitValue + " is not among the category values " + values);
    }

    /**
     * Compares two values for equality.
     *
     * <ul>
     *   <li>Two {@link Number}s are compared numerically, via {@link Double#compare} of their double
     *       values. So {@code 1} equals {@code 1.0}, NaN equals NaN, and {@code 0.0} differs from
     *       {@code -0.0}.</li>
     *   <li>Anything else falls back to {@link Objects#equals}.</li>
     * </ul>
     */
    private static boolean equals(Object left, Object right) {
        if (left instanceof Number && right instanceof Number) {
            return Double.compare(((Number) left).doubleValue(), ((Number) right).doubleValue()) == 0;
        }

        return Objects.equals(left, right);
    }
}

