/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.Lists;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.dmg.pmml.EmbeddedModel;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MissingValueStrategyType;
import org.dmg.pmml.NoTrueChildStrategyType;
import org.dmg.pmml.Node;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.TreeModel;
import org.jpmml.evaluator.ClassificationMap;
import org.jpmml.evaluator.EntityUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.HasEntityRegistry;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.MissingResultException;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelManagerEvaluationContext;
import org.jpmml.evaluator.NodeClassificationMap;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.PredicateUtil;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;

/**
 * Evaluator for a PMML {@link TreeModel}.
 *
 * <p>Scoring walks the node tree starting at the root. From a node whose predicate is
 * {@code true}, children are tried in list order and the first child whose predicate is
 * {@code true} is descended into. A predicate with an unknown ({@code null}) result is handled
 * according to the model's {@link MissingValueStrategyType}. A non-leaf node with no true child
 * is handled according to the model's {@link NoTrueChildStrategyType}. Only the regression and
 * classification mining functions are supported.
 */
public class TreeModelEvaluator
extends ModelEvaluator<TreeModel>
implements HasEntityRegistry<Node> {
    /**
     * Per-model registry of all nodes of the tree, built lazily on first access. Keys are held
     * weakly (Guava {@code CacheBuilder.weakKeys()}), so the cache does not by itself keep a
     * {@link TreeModel} instance reachable.
     */
    private static final LoadingCache<TreeModel, BiMap<String, Node>> entityCache = CacheBuilder.newBuilder()
        .weakKeys()
        .build(new CacheLoader<TreeModel, BiMap<String, Node>>(){

            @Override
            public BiMap<String, Node> load(TreeModel treeModel) {
                BiMap<String, Node> result = HashBiMap.create();
                collectNodes(treeModel.getNode(), result);
                return result;
            }

            /**
             * Registers {@code node} via {@link EntityUtil#put}, then recurses into its children
             * (pre-order traversal).
             */
            private void collectNodes(Node node, BiMap<String, Node> result) {
                EntityUtil.put(node, result);
                for (Node child : node.getNodes()) {
                    collectNodes(child, result);
                }
            }
        });

    /**
     * Creates an evaluator for the {@link TreeModel} located among the models of {@code pmml}.
     */
    public TreeModelEvaluator(PMML pmml) {
        this(pmml, (TreeModel)TreeModelEvaluator.find((List)pmml.getModels(), TreeModel.class));
    }

    /**
     * Creates an evaluator for the given {@code treeModel} of {@code pmml}.
     */
    public TreeModelEvaluator(PMML pmml, TreeModel treeModel) {
        super(pmml, treeModel);
    }

    public String getSummary() {
        return "Tree model";
    }

    /**
     * Returns the node registry for this evaluator's model, obtained through {@code entityCache}
     * (presumably keyed by the current model — see {@code ModelEvaluator#getValue}).
     */
    @Override
    public BiMap<String, Node> getEntityRegistry() {
        return this.getValue(entityCache);
    }

    /**
     * Scores one record.
     *
     * @param arguments input field values; pushed as a frame onto a fresh evaluation context
     * @return the result of {@link OutputUtil#evaluate} applied to the target predictions
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException if the mining function is neither regression nor
     *     classification, or if a tree strategy is not supported
     */
    @Override
    public Map<FieldName, ?> evaluate(Map<FieldName, ?> arguments) {
        TreeModel treeModel = (TreeModel)this.getModel();
        if (!treeModel.isScorable()) {
            throw new InvalidResultException((PMMLObject)treeModel);
        }

        ModelManagerEvaluationContext context = new ModelManagerEvaluationContext(this);
        context.pushFrame(arguments);

        Node node;
        MiningFunctionType miningFunction = treeModel.getFunctionName();
        switch (miningFunction) {
            case REGRESSION:
            case CLASSIFICATION:
                node = this.evaluateTree(context);
                break;
            default:
                throw new UnsupportedFeatureException((PMMLObject)treeModel, (Enum<?>)miningFunction);
        }

        // A null node (null prediction) is forwarded to TargetUtil as a null map.
        NodeClassificationMap values = (node != null) ? createNodeClassificationMap(node) : null;

        Map<FieldName, ? extends ClassificationMap<?>> predictions = TargetUtil.evaluateClassification(values, context);
        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Walks the tree from the root and selects the node that provides the prediction.
     *
     * @return the selected node, or {@code null} for a null prediction
     * @throws InvalidFeatureException if the model has no root node
     * @throws MissingResultException if the root predicate is unknown and the missing value
     *     strategy yields no result, or if a last prediction is requested while no non-leaf
     *     node has been entered yet
     * @throws UnsupportedFeatureException if the no-true-child strategy is not supported
     */
    private Node evaluateTree(ModelManagerEvaluationContext context) {
        TreeModel treeModel = (TreeModel)this.getModel();

        Node root = treeModel.getNode();
        if (root == null) {
            throw new InvalidFeatureException((PMMLObject)treeModel);
        }

        // Non-leaf nodes entered on the way down; the last element is the "last prediction".
        LinkedList<Node> trail = Lists.newLinkedList();

        // Default (root predicate false): no node, not final -> falls through to the
        // no-true-child strategy below.
        NodeResult result = new NodeResult(null);

        Boolean status = this.evaluateNode(root, context);
        if (status == null) {
            result = this.handleMissingValue(root, trail, context);
        } else if (status.booleanValue()) {
            result = this.handleTrue(root, trail, context);
        }

        if (result == null) {
            throw new MissingResultException((PMMLObject)root);
        }

        Node node = result.getNode();
        if (node != null || result.isFinal()) {
            return node;
        }

        NoTrueChildStrategyType noTrueChildStrategy = treeModel.getNoTrueChildStrategy();
        switch (noTrueChildStrategy) {
            case RETURN_NULL_PREDICTION:
                return null;
            case RETURN_LAST_PREDICTION:
                return this.lastPrediction(root, trail);
            default:
                break;
        }
        throw new UnsupportedFeatureException((PMMLObject)treeModel, (Enum<?>)noTrueChildStrategy);
    }

    /**
     * Applies the model's missing value strategy to a node whose predicate result is unknown.
     *
     * @param node the node whose predicate could not be evaluated
     * @param trail non-leaf nodes entered so far
     * @param context evaluation context (currently unused)
     * @return a final result, or {@code null} for the "none" strategy so that the caller can
     *     move on to the next sibling
     * @throws UnsupportedFeatureException if the strategy is not supported
     */
    private NodeResult handleMissingValue(Node node, LinkedList<Node> trail, EvaluationContext context) {
        TreeModel treeModel = (TreeModel)this.getModel();

        MissingValueStrategyType missingValueStrategy = treeModel.getMissingValueStrategy();
        switch (missingValueStrategy) {
            case NULL_PREDICTION:
                return new FinalNodeResult(null);
            case LAST_PREDICTION:
                return new FinalNodeResult(this.lastPrediction(node, trail));
            case NONE:
                return null;
            default:
                break;
        }
        throw new UnsupportedFeatureException((PMMLObject)treeModel, (Enum<?>)missingValueStrategy);
    }

    /**
     * Descends from a node whose predicate evaluated to {@code true}.
     *
     * <p>A leaf is returned as the result. Otherwise the node is appended to {@code trail} and
     * its children are tried in list order. The first true child is descended into, and the
     * search never backtracks. A child with an unknown predicate is resolved by
     * {@link #handleMissingValue}; if that yields {@code null} the next sibling is tried.
     *
     * @return the result; a non-final result with a {@code null} node means no child matched
     */
    private NodeResult handleTrue(Node node, LinkedList<Node> trail, EvaluationContext context) {
        List<Node> children = node.getNodes();
        if (children.isEmpty()) {
            return new NodeResult(node);
        }

        trail.add(node);

        for (Node child : children) {
            Boolean status = this.evaluateNode(child, context);

            if (status == null) {
                NodeResult result = this.handleMissingValue(child, trail, context);
                if (result != null) {
                    return result;
                }
            } else if (status.booleanValue()) {
                return this.handleTrue(child, trail, context);
            }
        }

        return new NodeResult(null);
    }

    /**
     * Returns the most recently entered non-leaf node.
     *
     * @param node node reported in the exception when {@code trail} is empty
     * @throws MissingResultException if {@code trail} is empty
     */
    private Node lastPrediction(Node node, LinkedList<Node> trail) {
        if (trail.isEmpty()) {
            throw new MissingResultException((PMMLObject)node);
        }
        return trail.getLast();
    }

    /**
     * Evaluates the predicate of {@code node}.
     *
     * @return the predicate result as computed by {@link PredicateUtil#evaluate}; callers treat
     *     {@code null} as an unknown (missing value) outcome
     * @throws InvalidFeatureException if the node has no predicate
     * @throws UnsupportedFeatureException if the node carries an embedded model
     */
    private Boolean evaluateNode(Node node, EvaluationContext context) {
        Predicate predicate = node.getPredicate();
        if (predicate == null) {
            throw new InvalidFeatureException((PMMLObject)node);
        }

        EmbeddedModel embeddedModel = node.getEmbeddedModel();
        if (embeddedModel != null) {
            throw new UnsupportedFeatureException((PMMLObject)embeddedModel);
        }

        return PredicateUtil.evaluate(predicate, context);
    }

    /**
     * Builds the classification map of a node from its score distributions.
     *
     * <p>An explicit probability is used as-is. Otherwise the value is the distribution's record
     * count divided by the sum of the record counts of all of the node's distributions.
     */
    private static NodeClassificationMap createNodeClassificationMap(Node node) {
        NodeClassificationMap result = new NodeClassificationMap(node);

        List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();

        double sum = 0.0;
        for (ScoreDistribution scoreDistribution : scoreDistributions) {
            sum += scoreDistribution.getRecordCount();
        }

        for (ScoreDistribution scoreDistribution : scoreDistributions) {
            Double value = scoreDistribution.getProbability();
            if (value == null) {
                // NOTE(review): if all record counts are 0 this division yields NaN; confirm
                // whether such input should instead be rejected as invalid.
                value = scoreDistribution.getRecordCount() / sum;
            }
            result.put(scoreDistribution.getValue(), value);
        }

        return result;
    }

    /**
     * A result that must be returned as-is, bypassing the no-true-child strategy even when
     * its node is {@code null}.
     */
    private static final class FinalNodeResult
    extends NodeResult {
        public FinalNodeResult(Node node) {
            super(node);
        }

        @Override
        public boolean isFinal() {
            return true;
        }
    }

    /**
     * Outcome of a tree walk: the selected node (possibly {@code null}) and whether it is final.
     */
    private static class NodeResult {
        private final Node node;

        public NodeResult(Node node) {
            this.node = node;
        }

        public boolean isFinal() {
            return false;
        }

        public Node getNode() {
            return this.node;
        }
    }
}

