/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.dmg.pmml.CompoundPredicate;
import org.dmg.pmml.DataType;
import org.dmg.pmml.EmbeddedModel;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MissingValueStrategyType;
import org.dmg.pmml.NoTrueChildStrategyType;
import org.dmg.pmml.Node;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.TreeModel;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EntityUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.HasEntityRegistry;
import org.jpmml.evaluator.InvalidFeatureException;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.NodeScore;
import org.jpmml.evaluator.NodeScoreDistribution;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.PredicateUtil;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.TypeUtil;
import org.jpmml.evaluator.UnsupportedFeatureException;

/**
 * Evaluator for PMML {@link TreeModel} elements.
 *
 * <p>Scoring starts at the root {@link Node}. It repeatedly descends into the first child whose
 * predicate evaluates to {@code true}.
 * <ul>
 *   <li>A predicate that evaluates to {@code null} (unknown) is resolved according to the
 *       model's {@link MissingValueStrategyType}.</li>
 *   <li>A level where no child matches is resolved according to the model's
 *       {@link NoTrueChildStrategyType}.</li>
 *   <li>Only the {@code REGRESSION} and {@code CLASSIFICATION} mining functions are
 *       supported.</li>
 * </ul>
 */
public class TreeModelEvaluator
extends ModelEvaluator<TreeModel>
implements HasEntityRegistry<Node> {
    /** Lazily initialised copy of this model's entry in {@link #entityCache}. */
    private transient BiMap<String, Node> entityRegistry = null;

    /**
     * Shared cache that maps a tree model to an immutable id-to-node registry.
     *
     * <p>Nodes are registered in depth-first pre-order (parent before children). A single counter,
     * initialised to 1, is threaded through every {@code EntityUtil.put} call. How that helper
     * turns the counter into a key is defined outside this file.
     */
    private static final LoadingCache<TreeModel, BiMap<String, Node>> entityCache = CacheUtil.buildLoadingCache(new CacheLoader<TreeModel, BiMap<String, Node>>(){

        @Override
        public BiMap<String, Node> load(TreeModel treeModel) {
            ImmutableBiMap.Builder<String, Node> builder = new ImmutableBiMap.Builder<String, Node>();
            builder = collectNodes(treeModel.getNode(), new AtomicInteger(1), builder);
            return builder.build();
        }

        /** Registers {@code node}, then recursively registers all of its descendants. */
        private ImmutableBiMap.Builder<String, Node> collectNodes(Node node, AtomicInteger index, ImmutableBiMap.Builder<String, Node> builder) {
            builder = EntityUtil.put(node, index, builder);
            if (!node.hasNodes()) {
                return builder;
            }
            // Typed list: iterating a raw List with a Node loop variable does not compile.
            List<Node> children = node.getNodes();
            for (Node child : children) {
                builder = collectNodes(child, index, builder);
            }
            return builder;
        }
    });

    public TreeModelEvaluator(PMML pmml) {
        super(pmml, TreeModel.class);
    }

    public TreeModelEvaluator(PMML pmml, TreeModel treeModel) {
        super(pmml, treeModel);
    }

    @Override
    public String getSummary() {
        return "Tree model";
    }

    /**
     * Returns the id-to-node registry for this evaluator's tree.
     *
     * <p>The value is fetched from the shared cache on first use and then kept in a field. The
     * field is not synchronised, so two threads may both perform the lookup. NOTE(review): this
     * presumably yields the same cached map each time; confirm {@code getValue} semantics.
     */
    @Override
    public BiMap<String, Node> getEntityRegistry() {
        if (this.entityRegistry == null) {
            this.entityRegistry = this.getValue(entityCache);
        }
        return this.entityRegistry;
    }

    /**
     * Scores the tree and post-processes the prediction through {@code OutputUtil}.
     *
     * @throws InvalidResultException if the model is flagged as not scorable
     * @throws UnsupportedFeatureException for mining functions other than regression and
     *     classification
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        TreeModel treeModel = (TreeModel)this.getModel();
        if (!treeModel.isScorable()) {
            throw new InvalidResultException((PMMLObject)treeModel);
        }
        // Wildcard-typed: the helpers return Map<FieldName, ?> and
        // Map<FieldName, ? extends Classification>. Neither is assignable to
        // Map<FieldName, Object>.
        Map<FieldName, ?> predictions;
        MiningFunctionType miningFunction = treeModel.getFunctionName();
        switch (miningFunction) {
            case REGRESSION:
                predictions = this.evaluateRegression(context);
                break;
            case CLASSIFICATION:
                predictions = this.evaluateClassification(context);
                break;
            default:
                throw new UnsupportedFeatureException((PMMLObject)treeModel, (Enum<?>)miningFunction);
        }
        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Regression scoring.
     *
     * <p>The selected node's score is converted to a double and wrapped in a {@link NodeScore}
     * for the target field. If no node is selected, the target's default result is returned.
     */
    private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext context) {
        Trail trail = new Trail();
        Node node = this.evaluateTree(trail, context);
        if (node == null) {
            return TargetUtil.evaluateRegressionDefault(context);
        }
        Double score = (Double)TypeUtil.parseOrCast(DataType.DOUBLE, node.getScore());
        FieldName targetField = this.getTargetField();
        NodeScore nodeScore = this.createNodeScore(node, TargetUtil.evaluateRegressionInternal(targetField, score, context));
        return Collections.singletonMap(targetField, nodeScore);
    }

    /**
     * Classification scoring.
     *
     * <p>Builds a {@link NodeScoreDistribution} from the selected node's score distributions. Each
     * confidence is multiplied by {@code missingValuePenalty ^ missingLevels}, where
     * {@code missingLevels} counts the missing-value detours taken during traversal. If no node
     * is selected, the target's default result is returned.
     */
    private Map<FieldName, ? extends Classification> evaluateClassification(ModelEvaluationContext context) {
        TreeModel treeModel = (TreeModel)this.getModel();
        Trail trail = new Trail();
        Node node = this.evaluateTree(trail, context);
        if (node == null) {
            return TargetUtil.evaluateClassificationDefault(context);
        }
        double missingValuePenalty = 1.0;
        int missingLevels = trail.getMissingLevels();
        if (missingLevels > 0) {
            missingValuePenalty = Math.pow(treeModel.getMissingValuePenalty(), missingLevels);
        }
        NodeScoreDistribution result = this.createNodeScoreDistribution(node, missingValuePenalty);
        return TargetUtil.evaluateClassification(result, context);
    }

    /**
     * Walks the tree from the root and returns the selected node.
     *
     * @return the selected node, or {@code null} in either of these cases: the root predicate is
     *     not {@code true}, or the traversal strategies chose a null prediction
     * @throws InvalidFeatureException in either of these cases: the model has no root node, or
     *     the selected node has no score
     */
    private Node evaluateTree(Trail trail, EvaluationContext context) {
        TreeModel treeModel = (TreeModel)this.getModel();
        Node root = treeModel.getNode();
        if (root == null) {
            throw new InvalidFeatureException((PMMLObject)treeModel);
        }
        Boolean status = this.evaluateNode(trail, root, context);
        if (status == null || !status.booleanValue()) {
            return null;
        }
        Node node = this.handleTrue(trail, root, context).getResult();
        if (node != null && node.getScore() == null) {
            throw new InvalidFeatureException((PMMLObject)node);
        }
        return node;
    }

    /**
     * Evaluates a single node's predicate.
     *
     * <p>For a compound predicate whose result is flagged {@code isAlternative()}, one missing
     * level is recorded on {@code trail}. NOTE(review): this flag presumably means a surrogate
     * alternative was used.
     *
     * @return {@code TRUE}, {@code FALSE}, or {@code null} when the outcome is unknown
     * @throws UnsupportedFeatureException if the node has an embedded model
     * @throws InvalidFeatureException if the node has no predicate
     */
    private Boolean evaluateNode(Trail trail, Node node, EvaluationContext context) {
        EmbeddedModel embeddedModel = node.getEmbeddedModel();
        if (embeddedModel != null) {
            throw new UnsupportedFeatureException((PMMLObject)embeddedModel);
        }
        Predicate predicate = node.getPredicate();
        if (predicate == null) {
            throw new InvalidFeatureException((PMMLObject)node);
        }
        if (predicate instanceof CompoundPredicate) {
            CompoundPredicate compoundPredicate = (CompoundPredicate)predicate;
            PredicateUtil.CompoundPredicateResult result = PredicateUtil.evaluateCompoundPredicateInternal(compoundPredicate, context);
            if (result.isAlternative()) {
                trail.addMissingLevel();
            }
            return result.getResult();
        }
        return PredicateUtil.evaluate(predicate, context);
    }

    /**
     * Continues traversal below {@code node}, whose own predicate has evaluated to {@code true}.
     *
     * <ul>
     *   <li>A leaf node is selected directly.</li>
     *   <li>Otherwise, {@code node} is recorded as the last prediction and its children are
     *       tried in document order.</li>
     *   <li>The first {@code true} child is descended into.</li>
     *   <li>An unknown child is passed to the missing-value strategy. If that strategy makes no
     *       decision, the loop moves on to the next sibling.</li>
     * </ul>
     */
    private Trail handleTrue(Trail trail, Node node, EvaluationContext context) {
        if (!node.hasNodes()) {
            return trail.selectNode(node);
        }
        trail.push(node);
        List<Node> children = node.getNodes();
        for (Node child : children) {
            Boolean status = this.evaluateNode(trail, child, context);
            if (status == null) {
                Trail destination = this.handleMissingValue(trail, node, child, context);
                if (destination != null) {
                    return destination;
                }
            } else if (status.booleanValue()) {
                return this.handleTrue(trail, child, context);
            }
        }
        return this.handleNoTrueChild(trail);
    }

    /**
     * Implements the {@code DEFAULT_CHILD} missing-value strategy.
     *
     * <p>Records one missing level, then continues traversal into the child of {@code node} whose
     * id equals {@code node}'s {@code defaultChild} attribute.
     *
     * @throws InvalidFeatureException in either of these cases: {@code defaultChild} is absent,
     *     or no child has that id
     */
    private Trail handleDefaultChild(Trail trail, Node node, EvaluationContext context) {
        String defaultChild = node.getDefaultChild();
        if (defaultChild == null) {
            throw new InvalidFeatureException((PMMLObject)node);
        }
        trail.addMissingLevel();
        List<Node> children = node.getNodes();
        for (Node child : children) {
            // defaultChild is known non-null here, so a child without an id simply never matches.
            if (defaultChild.equals(child.getId())) {
                return this.handleTrue(trail, child, context);
            }
        }
        throw new InvalidFeatureException((PMMLObject)node);
    }

    /**
     * Resolves a level where no child predicate was {@code true}.
     *
     * <p>The resolution follows the model's {@code noTrueChildStrategy}:
     * <ul>
     *   <li>{@code RETURN_NULL_PREDICTION} selects nothing.</li>
     *   <li>{@code RETURN_LAST_PREDICTION} selects the last node that was pushed, but only if
     *       that node has a score. Otherwise it selects nothing.</li>
     * </ul>
     */
    private Trail handleNoTrueChild(Trail trail) {
        TreeModel treeModel = (TreeModel)this.getModel();
        NoTrueChildStrategyType noTrueChildStrategy = treeModel.getNoTrueChildStrategy();
        switch (noTrueChildStrategy) {
            case RETURN_NULL_PREDICTION:
                return trail.selectNull();
            case RETURN_LAST_PREDICTION: {
                Node lastPrediction = trail.getLastPrediction();
                if (lastPrediction.getScore() != null) {
                    return trail.selectLastPrediction();
                }
                return trail.selectNull();
            }
            default:
                break;
        }
        throw new UnsupportedFeatureException((PMMLObject)treeModel, (Enum<?>)noTrueChildStrategy);
    }

    /**
     * Applies the model's {@code missingValueStrategy} when {@code node}'s predicate is unknown.
     *
     * @return the resolved trail, or {@code null} for {@code NONE}. A {@code null} result tells
     *     the caller to move on to the next sibling.
     */
    private Trail handleMissingValue(Trail trail, Node parent, Node node, EvaluationContext context) {
        TreeModel treeModel = (TreeModel)this.getModel();
        MissingValueStrategyType missingValueStrategy = treeModel.getMissingValueStrategy();
        switch (missingValueStrategy) {
            case NULL_PREDICTION:
                return trail.selectNull();
            case LAST_PREDICTION:
                return trail.selectLastPrediction();
            case DEFAULT_CHILD:
                return this.handleDefaultChild(trail, parent, context);
            case NONE:
                return null;
            default:
                break;
        }
        throw new UnsupportedFeatureException((PMMLObject)treeModel, (Enum<?>)missingValueStrategy);
    }

    /** Wraps a regression value together with the node that produced it. */
    private NodeScore createNodeScore(Node node, Object value) {
        BiMap<String, Node> entityRegistry = this.getEntityRegistry();
        return new NodeScore(entityRegistry, node, value);
    }

    /**
     * Builds the class distribution for a selected node.
     *
     * <p>For each {@code ScoreDistribution}:
     * <ul>
     *   <li>An explicit probability is used when present. Otherwise the probability is
     *       {@code recordCount / sum(recordCount)}.</li>
     *   <li>Explicit confidences are scaled by {@code missingValuePenalty}.</li>
     * </ul>
     *
     * @throws InvalidFeatureException if any distribution entry lacks a record count
     */
    private NodeScoreDistribution createNodeScoreDistribution(Node node, double missingValuePenalty) {
        BiMap<String, Node> entityRegistry = this.getEntityRegistry();
        NodeScoreDistribution result = new NodeScoreDistribution(entityRegistry, node);
        if (!node.hasScoreDistributions()) {
            return result;
        }
        List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();

        // Every entry must carry a record count, even when it also has an explicit probability.
        double sum = 0.0;
        for (ScoreDistribution scoreDistribution : scoreDistributions) {
            Double recordCount = scoreDistribution.getRecordCount();
            if (recordCount == null) {
                throw new InvalidFeatureException((PMMLObject)scoreDistribution);
            }
            sum += recordCount.doubleValue();
        }

        for (ScoreDistribution scoreDistribution : scoreDistributions) {
            Double probability = scoreDistribution.getProbability();
            if (probability == null) {
                // NOTE(review): if every record count is 0, this produces NaN.
                Double recordCount = scoreDistribution.getRecordCount();
                probability = recordCount / sum;
            }
            result.put(scoreDistribution.getValue(), probability);
            Double confidence = scoreDistribution.getConfidence();
            if (confidence != null) {
                result.putConfidence(scoreDistribution.getValue(), confidence * missingValuePenalty);
            }
        }
        return result;
    }

    /**
     * Mutable traversal state for a single evaluation.
     *
     * <p>It holds three pieces of state:
     * <ul>
     *   <li>the deepest node whose predicate was {@code true} and which had children;</li>
     *   <li>the finally selected node (may be {@code null});</li>
     *   <li>the number of missing-value detours taken.</li>
     * </ul>
     * The {@code select*} methods return {@code this} so that strategies can return the trail
     * directly.
     */
    private static final class Trail {
        private Node lastPrediction = null;
        private Node result = null;
        private int missingLevels = 0;

        public void push(Node node) {
            this.setLastPrediction(node);
        }

        public Trail selectNull() {
            this.setResult(null);
            return this;
        }

        public Trail selectNode(Node node) {
            this.setResult(node);
            return this;
        }

        public Trail selectLastPrediction() {
            this.setResult(this.getLastPrediction());
            return this;
        }

        public Node getResult() {
            return this.result;
        }

        private void setResult(Node result) {
            this.result = result;
        }

        /** @throws EvaluationException if no node has been pushed yet */
        public Node getLastPrediction() {
            if (this.lastPrediction == null) {
                throw new EvaluationException();
            }
            return this.lastPrediction;
        }

        private void setLastPrediction(Node lastPrediction) {
            this.lastPrediction = lastPrediction;
        }

        public void addMissingLevel() {
            this.setMissingLevels(this.getMissingLevels() + 1);
        }

        public int getMissingLevels() {
            return this.missingLevels;
        }

        private void setMissingLevels(int missingLevels) {
            this.missingLevels = missingLevels;
        }
    }
}

