/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.translator.mining;

import com.google.common.collect.Iterables;
import com.sun.codemodel.JAssignmentTarget;
import com.sun.codemodel.JBlock;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JExpression;
import com.sun.codemodel.JMethod;
import com.sun.codemodel.JStatement;
import com.sun.codemodel.JVar;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.False;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.dmg.pmml.True;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.ComplexNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.ValueUtil;
import org.jpmml.evaluator.Value;
import org.jpmml.model.PMMLObjectKey;
import org.jpmml.model.UnsupportedAttributeException;
import org.jpmml.model.UnsupportedElementException;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.model.visitors.NodeFilterer;
import org.jpmml.translator.FieldInfoMap;
import org.jpmml.translator.JCompoundAssignment;
import org.jpmml.translator.MethodScope;
import org.jpmml.translator.ModelTranslator;
import org.jpmml.translator.TranslationContext;
import org.jpmml.translator.ValueBuilder;
import org.jpmml.translator.mining.MiningModelTranslator;
import org.jpmml.translator.tree.NodeGroup;
import org.jpmml.translator.tree.NodeGroupUtil;
import org.jpmml.translator.tree.Scorer;
import org.jpmml.translator.tree.TreeModelTranslator;

/**
 * Translator for a regression {@link MiningModel} whose segmentation sums the predictions of many
 * {@link TreeModel} segments (a tree "booster").
 *
 * <p>Translation first rewrites the segmentation into one synthetic {@link TreeModel}:
 * <ul>
 *   <li>the top-level child nodes of every segment tree are re-parented under a single root;</li>
 *   <li>each segment's root score is accumulated into the rescale constant of that tree's only {@link Target}.</li>
 * </ul>
 * The synthetic tree is then translated into one evaluator method. That method starts from the accumulated
 * constant and adds every node score that the tree translation yields.
 *
 * <p>NOTE(review): {@code translateRegressor} rewrites the underlying PMML {@link MiningModel} in place
 * (its segments and local transformations are replaced).
 */
public class TreeModelBoosterTranslator extends MiningModelTranslator {

    /**
     * Upper bound on the total number of {@link Node} elements across all segments.
     *
     * <p>Read from the system property
     * {@code org.jpmml.translator.mining.TreeModelBoosterTranslator#NODE_COUNT_LIMIT}; defaults to 1000.
     * Presumably this keeps the single generated method within reasonable size limits (TODO confirm).
     */
    public static final int NODE_COUNT_LIMIT = Integer.getInteger(TreeModelBoosterTranslator.class.getName() + "#NODE_COUNT_LIMIT", 1000);

    /**
     * Creates a translator after checking that the mining model has the supported booster shape.
     *
     * <p>Supported models:
     * <ul>
     *   <li>use the {@code regression} mining function;</li>
     *   <li>combine segments with multiple model method {@code sum};</li>
     *   <li>contain only {@code True}-predicated {@link TreeModel} segments. Each segment shares the parent's
     *       math context, uses {@code returnLastPrediction} and missing value strategy {@code none}, and passes
     *       the inherited mining schema, targets and output checks;</li>
     *   <li>contain no more than {@link #NODE_COUNT_LIMIT} nodes in total.</li>
     * </ul>
     *
     * @param pmml the enclosing PMML document
     * @param miningModel the booster model to translate
     * @throws UnsupportedAttributeException if an attribute value is outside the supported set
     * @throws UnsupportedElementException if the segmentation contains more than {@link #NODE_COUNT_LIMIT} nodes
     */
    public TreeModelBoosterTranslator(PMML pmml, MiningModel miningModel) {
        super(pmml, miningModel);

        MiningFunction miningFunction = miningModel.requireMiningFunction();
        if (miningFunction != MiningFunction.REGRESSION) {
            throw new UnsupportedAttributeException(miningModel, miningFunction);
        }

        MathContext mathContext = miningModel.getMathContext();
        Segmentation segmentation = miningModel.requireSegmentation();

        Segmentation.MultipleModelMethod multipleModelMethod = segmentation.requireMultipleModelMethod();
        if (multipleModelMethod != Segmentation.MultipleModelMethod.SUM) {
            throw new UnsupportedAttributeException(segmentation, multipleModelMethod);
        }

        List<Segment> segments = segmentation.requireSegments();
        for (Segment segment : segments) {
            checkTreeSegment(segment, mathContext);
        }

        if (countNodes(segmentation) > NODE_COUNT_LIMIT) {
            throw new UnsupportedElementException(segmentation);
        }
    }

    /**
     * Validates a single booster segment.
     *
     * @param segment the segment to check
     * @param mathContext the math context of the enclosing mining model, which the segment's tree must match
     * @throws UnsupportedAttributeException if the tree uses an unsupported attribute value
     */
    private static void checkTreeSegment(Segment segment, MathContext mathContext) {
        // Called for its validation effect only; presumably throws unless the predicate is True (TODO confirm)
        segment.requirePredicate(True.class);

        TreeModel treeModel = (TreeModel)segment.requireModel(TreeModel.class);
        if (treeModel.getMathContext() != mathContext) {
            throw new UnsupportedAttributeException(treeModel, treeModel.getMathContext());
        }

        TreeModelBoosterTranslator.checkMiningSchema(treeModel);
        TreeModelBoosterTranslator.checkTargets(treeModel);
        TreeModelBoosterTranslator.checkOutput(treeModel);

        TreeModel.NoTrueChildStrategy noTrueChildStrategy = treeModel.getNoTrueChildStrategy();
        if (noTrueChildStrategy != TreeModel.NoTrueChildStrategy.RETURN_LAST_PREDICTION) {
            throw new UnsupportedAttributeException(treeModel, noTrueChildStrategy);
        }

        TreeModel.MissingValueStrategy missingValueStrategy = treeModel.getMissingValueStrategy();
        if (missingValueStrategy != TreeModel.MissingValueStrategy.NONE) {
            throw new UnsupportedAttributeException(treeModel, missingValueStrategy);
        }
    }

    /**
     * Counts the {@link Node} elements reachable from the given PMML element.
     *
     * @param visitable the element to traverse
     * @return the number of visited nodes
     */
    private static int countNodes(Visitable visitable) {
        // AtomicInteger is used only as a mutable holder that the anonymous visitor can capture
        AtomicInteger nodeCount = new AtomicInteger(0);

        AbstractVisitor nodeCounter = new AbstractVisitor() {

            @Override
            public VisitorAction visit(Node node) {
                nodeCount.incrementAndGet();

                return super.visit(node);
            }
        };
        nodeCounter.applyTo(visitable);

        return nodeCount.get();
    }

    /**
     * Generates an evaluator method for the whole booster.
     *
     * <p>The generated method does the following:
     * <ol>
     *   <li>declares a {@code result} local initialised to the intercept (or 0 when there is none);</li>
     *   <li>adds each non-zero node score yielded while translating the synthetic tree;</li>
     *   <li>wraps the sum into {@code resultValue} via the value factory;</li>
     *   <li>applies the target transformation and returns the value.</li>
     * </ol>
     *
     * <p>Side effect: the underlying {@link MiningModel} is rewritten in place by {@code transformModel}.
     *
     * @param context the translation context
     * @return the generated evaluator method
     * @throws UnsupportedAttributeException if the math context is neither {@code float} nor {@code double}
     */
    @Override
    public JMethod translateRegressor(TranslationContext context) {
        PMML pmml = getPMML();
        MiningModel miningModel = (MiningModel)getModel();

        TreeModel treeModel = TreeModelBoosterTranslator.transformModel(miningModel);

        MathContext mathContext = treeModel.getMathContext();

        Targets targets = treeModel.getTargets();
        Target target = Iterables.getOnlyElement(targets);

        Number intercept = TreeModelBoosterTranslator.extractIntercept(target);
        if (intercept == null) {
            intercept = 0;
        }

        TreeModelTranslator modelTranslator = new TreeModelTranslator(pmml, treeModel);

        Node root = treeModel.getNode();

        FieldInfoMap fieldInfos = modelTranslator.getFieldInfos(Collections.singleton(root));

        JMethod evaluateNodeMethod = TreeModelBoosterTranslator.createEvaluatorMethod(Value.class, root, true, context);

        // Push outside the try block, so that a failing push never triggers a pop of somebody else's scope
        context.pushScope(new MethodScope(evaluateNodeMethod));

        try {
            JVar resultVar = TreeModelBoosterTranslator.declareResultVar(context, miningModel, mathContext, intercept);

            Scorer<Number> scorer = new Scorer<Number>() {

                @Override
                public Number prepare(Node node) {
                    return (Number)node.requireScore();
                }

                @Override
                public void yield(Number score, TranslationContext context) {
                    JBlock block = context.block();

                    // Zero contributions would only add dead statements to the generated code
                    if (score != null && score.doubleValue() != 0.0) {
                        block.add(TreeModelBoosterTranslator.createScoreIncrement(resultVar, score, miningModel, mathContext));
                    }
                }

                @Override
                public void yieldIf(JExpression expression, Number score, TranslationContext context) {
                    JBlock block = context.block();

                    if (score != null && score.doubleValue() != 0.0) {
                        JBlock thenBlock = block._if(expression)._then();

                        thenBlock.add(TreeModelBoosterTranslator.createScoreIncrement(resultVar, score, miningModel, mathContext));
                    }
                }
            };

            TreeModelTranslator.translateNode(treeModel, root, scorer, fieldInfos, context);

            ValueBuilder valueBuilder = new ValueBuilder(context)
                .declare("resultValue", context.getValueFactoryVariable().newValue(resultVar));

            TreeModelBoosterTranslator.translateRegressorTarget(treeModel, target, valueBuilder);

            context._return(valueBuilder.getVariable());
        } finally {
            context.popScope();
        }

        return evaluateNodeMethod;
    }

    /**
     * Declares the primitive {@code result} accumulator, initialised to the intercept.
     *
     * @param context the translation context to declare the variable in
     * @param miningModel the model to blame when the math context is unsupported
     * @param mathContext selects {@code float} or {@code double} as the variable type
     * @param intercept the initial value
     * @return the declared variable
     * @throws UnsupportedAttributeException if the math context is neither {@code float} nor {@code double}
     */
    private static JVar declareResultVar(TranslationContext context, MiningModel miningModel, MathContext mathContext, Number intercept) {
        switch (mathContext) {
            case FLOAT:
                return context.declare(Float.TYPE, "result", JExpr.lit(intercept.floatValue()));
            case DOUBLE:
                return context.declare(Double.TYPE, "result", JExpr.lit(intercept.doubleValue()));
            default:
                throw new UnsupportedAttributeException(miningModel, mathContext);
        }
    }

    /**
     * Creates a statement that adds {@code value} to {@code resultVar}.
     *
     * <p>Negative values are emitted as {@code result -= |value|} instead of {@code result += -value}.
     * Presumably this keeps negative literals out of the generated source (TODO confirm).
     *
     * @param resultVar the accumulator variable
     * @param value the score to add
     * @param miningModel the model to blame when the math context is unsupported
     * @param mathContext selects {@code float} or {@code double} literals
     * @return the compound assignment statement
     * @throws UnsupportedAttributeException if the math context is neither {@code float} nor {@code double}
     */
    private static JStatement createScoreIncrement(JVar resultVar, Number value, MiningModel miningModel, MathContext mathContext) {
        switch (mathContext) {
            case FLOAT: {
                float floatValue = value.floatValue();

                return new JCompoundAssignment(resultVar, JExpr.lit(Math.abs(floatValue)), floatValue >= 0.0f ? "+=" : "-=");
            }
            case DOUBLE: {
                double doubleValue = value.doubleValue();

                return new JCompoundAssignment(resultVar, JExpr.lit(Math.abs(doubleValue)), doubleValue >= 0.0 ? "+=" : "-=");
            }
            default:
                throw new UnsupportedAttributeException(miningModel, mathContext);
        }
    }

    /**
     * Replaces the segments of {@code miningModel} with a single {@code True}-predicated segment that holds the
     * synthetic tree built by {@code transformSegmentation}.
     *
     * <p>If the synthetic tree received any local transformations, they are moved up to {@code miningModel},
     * replacing its existing ones. That set already includes the mining model's own derived fields, because
     * {@code copySegmentContent} copies them into the tree first.
     *
     * @param miningModel the booster model; mutated in place
     * @return the synthetic tree model
     */
    private static TreeModel transformModel(MiningModel miningModel) {
        TreeModel treeModel = TreeModelBoosterTranslator.transformSegmentation(miningModel);

        Segmentation segmentation = miningModel.requireSegmentation();

        List<Segment> segments = segmentation.requireSegments();
        segments.clear();

        Segment segment = new Segment()
            .setPredicate(True.INSTANCE)
            .setModel(treeModel);

        LocalTransformations localTransformations = treeModel.getLocalTransformations();
        if (localTransformations != null) {
            miningModel.setLocalTransformations(localTransformations);

            treeModel.setLocalTransformations(null);
        }

        segments.add(segment);

        return treeModel;
    }

    /**
     * Builds the synthetic regression tree that stands in for the whole segmentation.
     *
     * <p>The steps run in a fixed order, because each step relies on the mutations of the previous ones.
     *
     * @param miningModel the booster model; its segment trees are mutated in place
     * @return a new tree model with a zero-score {@code True} root and a single zero-initialised target
     * @throws UnsupportedAttributeException if the math context is neither {@code float} nor {@code double}
     */
    private static TreeModel transformSegmentation(MiningModel miningModel) {
        MathContext mathContext = miningModel.getMathContext();
        Segmentation segmentation = miningModel.requireSegmentation();

        Number zero = TreeModelBoosterTranslator.zeroOf(miningModel, mathContext);

        ComplexNode root = new ComplexNode()
            .setScore(zero)
            .setPredicate(True.INSTANCE);

        Target target = new Target()
            .setRescaleConstant(zero);

        Targets targets = new Targets()
            .addTargets(target);

        TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, new MiningSchema(), root)
            .setMathContext(mathContext)
            .setNoTrueChildStrategy(TreeModel.NoTrueChildStrategy.RETURN_LAST_PREDICTION)
            .setMissingValueStrategy(TreeModel.MissingValueStrategy.NONE)
            .setTargets(targets);

        TreeModelBoosterTranslator.convertToComplexNodes(segmentation);
        TreeModelBoosterTranslator.assignParentIds(segmentation);
        TreeModelBoosterTranslator.foldRootScores(segmentation, target, mathContext);
        TreeModelBoosterTranslator.copySegmentContent(miningModel, segmentation, treeModel);
        TreeModelBoosterTranslator.mergeEquivalentNodeGroups(treeModel, mathContext);

        return treeModel;
    }

    /**
     * Returns a boxed zero of the type that matches the math context.
     *
     * @throws UnsupportedAttributeException if the math context is neither {@code float} nor {@code double}
     */
    private static Number zeroOf(MiningModel miningModel, MathContext mathContext) {
        switch (mathContext) {
            case FLOAT:
                return 0.0f;
            case DOUBLE:
                return 0.0;
            default:
                throw new UnsupportedAttributeException(miningModel, mathContext);
        }
    }

    /**
     * Replaces every node in the segmentation that is not already a {@link ComplexNode} with a
     * {@code ComplexNode} copy.
     *
     * <p>Presumably the later steps need the general-purpose node class, for example to set scores and attach
     * children freely (TODO confirm).
     */
    private static void convertToComplexNodes(Segmentation segmentation) {
        NodeFilterer nodeFilterer = new NodeFilterer() {

            @Override
            public ComplexNode filter(Node node) {
                if (node instanceof ComplexNode) {
                    return (ComplexNode)node;
                }

                return new ComplexNode(node);
            }
        };
        nodeFilterer.applyTo(segmentation);
    }

    /**
     * Tags every node with the identity hash code of its parent element.
     *
     * <p>Siblings from the same source tree can then still be told apart after all top-level nodes are moved
     * under one root. The tags are presumably consumed by {@link NodeGroupUtil#group}.
     *
     * <p>NOTE(review): {@code System.identityHashCode} is not guaranteed to be unique, so two distinct parents
     * could in principle receive the same id.
     */
    private static void assignParentIds(Segmentation segmentation) {
        AbstractVisitor nodeExtender = new AbstractVisitor() {

            @Override
            public VisitorAction visit(Node node) {
                PMMLObject parent = this.getParent();

                NodeGroupUtil.setParentId(node, System.identityHashCode(parent));

                return super.visit(node);
            }
        };
        nodeExtender.applyTo(segmentation);
    }

    /**
     * For every segment tree, moves the root score into the target's rescale constant.
     *
     * <p>The root score is added to {@code target}'s rescale constant and subtracted from every node of that
     * tree. As a result, each root ends up with a score of (normally) zero.
     *
     * @param segmentation the segmentation whose trees are rewritten in place
     * @param target the shared target accumulating the constant part
     * @param mathContext the arithmetic precision used for the additions and subtractions
     */
    private static void foldRootScores(Segmentation segmentation, Target target, MathContext mathContext) {
        AbstractVisitor nodeScoreUpdater = new AbstractVisitor() {

            @Override
            public VisitorAction visit(TreeModel treeModel) {
                Node root = treeModel.getNode();

                Number rootScore = (Number)root.requireScore();

                target.setRescaleConstant(ValueUtil.add(mathContext, target.getRescaleConstant(), rootScore));

                subtractFromScores(root, rootScore);

                return super.visit(treeModel);
            }

            /** Recursively subtracts {@code adjustment} from the score of {@code node} and all its descendants. */
            private void subtractFromScores(Node node, Number adjustment) {
                Number score = (Number)node.requireScore();

                node.setScore(ValueUtil.subtract(mathContext, score, adjustment));

                if (node.hasNodes()) {
                    List<Node> children = node.getNodes();

                    for (Node child : children) {
                        subtractFromScores(child, adjustment);
                    }
                }
            }
        };
        nodeScoreUpdater.applyTo(segmentation);
    }

    /**
     * Populates the synthetic {@code treeModel} from the mining model and its segments.
     *
     * <p>The following are copied into the synthetic tree:
     * <ul>
     *   <li>the mining model's {@code active} mining fields;</li>
     *   <li>the mining model's derived fields, then every segment's derived fields, in traversal order;</li>
     *   <li>the top-level child nodes of every segment root, in traversal order.</li>
     * </ul>
     * The copied nodes are added to the synthetic root. Element order is preserved because node order is
     * significant for tree evaluation.
     */
    private static void copySegmentContent(MiningModel miningModel, Segmentation segmentation, TreeModel treeModel) {
        AbstractVisitor treeModelInitializer = new AbstractVisitor() {

            {
                MiningSchema miningSchema = miningModel.requireMiningSchema();
                LocalTransformations localTransformations = miningModel.getLocalTransformations();

                if (miningSchema.hasMiningFields()) {
                    addMiningFields(miningSchema.getMiningFields());
                }

                if (localTransformations != null && localTransformations.hasDerivedFields()) {
                    addDerivedFields(localTransformations.getDerivedFields());
                }
            }

            @Override
            public VisitorAction visit(TreeModel segmentTreeModel) {
                Node segmentRoot = segmentTreeModel.getNode();

                // Called for its validation effect only; segment roots are expected to be unconditional
                segmentRoot.requirePredicate(True.class);

                if (segmentRoot.hasNodes()) {
                    List<Node> children = segmentRoot.getNodes();

                    for (Node child : children) {
                        addNode(child);
                    }

                    // After foldRootScores this is normally zero. A non-zero remainder is kept as a trailing
                    // catch-all True child in the same sibling group as the root's real children.
                    Number score = (Number)segmentRoot.requireScore();
                    if (score.doubleValue() != 0.0) {
                        ComplexNode elseChild = new ComplexNode()
                            .setScore(score)
                            .setPredicate(True.INSTANCE);

                        NodeGroupUtil.setParentId(elseChild, System.identityHashCode(segmentRoot));

                        addNode(elseChild);
                    }
                }

                return super.visit(segmentTreeModel);
            }

            @Override
            public VisitorAction visit(LocalTransformations localTransformations) {
                if (localTransformations.hasDerivedFields()) {
                    addDerivedFields(localTransformations.getDerivedFields());
                }

                return super.visit(localTransformations);
            }

            /** Appends the {@code active} fields to the synthetic tree's mining schema; other usage types are skipped. */
            private void addMiningFields(List<MiningField> miningFields) {
                MiningSchema miningSchema = treeModel.requireMiningSchema();

                for (MiningField miningField : miningFields) {
                    if (miningField.getUsageType() == MiningField.UsageType.ACTIVE) {
                        miningSchema.addMiningFields(miningField);
                    }
                }
            }

            /** Appends derived fields to the synthetic tree, creating its local transformations on first use. */
            private void addDerivedFields(List<DerivedField> derivedFields) {
                LocalTransformations localTransformations = treeModel.getLocalTransformations();
                if (localTransformations == null) {
                    localTransformations = new LocalTransformations();

                    treeModel.setLocalTransformations(localTransformations);
                }

                for (DerivedField derivedField : derivedFields) {
                    localTransformations.addDerivedFields(derivedField);
                }
            }

            /** Appends a node as a child of the synthetic tree's root. */
            private void addNode(Node node) {
                Node root = treeModel.getNode();

                root.addNodes(node);
            }
        };
        treeModelInitializer.applyTo(segmentation);
    }

    /**
     * Within every node of the synthetic tree, merges "shallow" sibling groups whose predicates are structurally
     * identical.
     *
     * <p>When two groups match, the scores of the later group are added position by position into the earlier
     * group, and the later group's nodes are removed.
     */
    private static void mergeEquivalentNodeGroups(TreeModel treeModel, MathContext mathContext) {
        AbstractVisitor nodeGroupMerger = new AbstractVisitor() {

            @Override
            public VisitorAction visit(Node node) {
                if (node.hasNodes()) {
                    List<Node> children = node.getNodes();

                    if (children.size() != 1) {
                        List<NodeGroup> nodeGroups = NodeGroupUtil.group(children);

                        if (nodeGroups.size() > 1) {
                            mergeShallowGroups(children, nodeGroups);
                        }
                    }
                }

                return super.visit(node);
            }

            /** Merges each shallow group into the first earlier group with an identical predicate key. */
            private void mergeShallowGroups(List<Node> children, List<NodeGroup> nodeGroups) {
                HashMap<List<PMMLObjectKey>, NodeGroup> uniqueNodeGroups = new HashMap<>();

                for (NodeGroup nodeGroup : nodeGroups) {
                    if (!nodeGroup.isShallow()) {
                        continue;
                    }

                    List<PMMLObjectKey> key = createKey(nodeGroup);

                    NodeGroup prevNodeGroup = uniqueNodeGroups.get(key);
                    if (prevNodeGroup != null) {
                        merge(prevNodeGroup, nodeGroup);

                        children.removeAll(nodeGroup);
                    } else {
                        uniqueNodeGroups.put(key, nodeGroup);
                    }
                }
            }

            /** Builds a structural key from the (canonicalised) predicates of the given nodes, in order. */
            private List<PMMLObjectKey> createKey(List<Node> nodes) {
                return nodes.stream()
                    .map(node -> new PMMLObjectKey(this.filterPredicate(node.requirePredicate())))
                    .collect(Collectors.toList());
            }

            /**
             * Maps any {@link True}/{@link False} predicate to the shared singleton instance, presumably so that
             * all such predicates produce equal keys (TODO confirm).
             */
            private Predicate filterPredicate(Predicate predicate) {
                if (predicate instanceof True) {
                    return True.INSTANCE;
                }

                if (predicate instanceof False) {
                    return False.INSTANCE;
                }

                return predicate;
            }

            /**
             * Adds the scores of {@code rightNodes} into {@code leftNodes}, position by position.
             *
             * @throws IllegalArgumentException if the groups differ in size or contain non-leaf nodes
             */
            private void merge(List<Node> leftNodes, List<Node> rightNodes) {
                if (leftNodes.size() != rightNodes.size()) {
                    throw new IllegalArgumentException("Node groups differ in size: " + leftNodes.size() + " vs " + rightNodes.size());
                }

                for (int i = 0; i < leftNodes.size(); i++) {
                    Node leftNode = leftNodes.get(i);
                    Node rightNode = rightNodes.get(i);

                    if (leftNode.hasNodes() || rightNode.hasNodes()) {
                        throw new IllegalArgumentException("Expected leaf nodes at position " + i);
                    }

                    Number leftScore = (Number)leftNode.requireScore();
                    Number rightScore = (Number)rightNode.requireScore();

                    leftNode.setScore(ValueUtil.add(mathContext, leftScore, rightScore));
                }
            }
        };
        nodeGroupMerger.applyTo(treeModel);
    }
}

