/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.translator.tree;

import com.google.common.collect.ArrayListMultimap;
import com.sun.codemodel.JBlock;
import com.sun.codemodel.JDefinedClass;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JExpression;
import com.sun.codemodel.JFieldVar;
import com.sun.codemodel.JInvocation;
import com.sun.codemodel.JMethod;
import com.sun.codemodel.JMod;
import com.sun.codemodel.JType;
import com.sun.codemodel.JVar;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import org.dmg.pmml.ComplexArray;
import org.dmg.pmml.DataType;
import org.dmg.pmml.False;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.SimpleSetPredicate;
import org.dmg.pmml.TextIndex;
import org.dmg.pmml.True;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.PMMLAttributes;
import org.dmg.pmml.tree.PMMLElements;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.UnsupportedElementException;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.translator.ArrayManager;
import org.jpmml.translator.Encoder;
import org.jpmml.translator.FieldInfo;
import org.jpmml.translator.FpPrimitiveEncoder;
import org.jpmml.translator.FunctionInvocation;
import org.jpmml.translator.IdentifierUtil;
import org.jpmml.translator.JBinaryFileInitializer;
import org.jpmml.translator.MethodScope;
import org.jpmml.translator.ModelTranslator;
import org.jpmml.translator.OperableRef;
import org.jpmml.translator.OrdinalEncoder;
import org.jpmml.translator.PMMLObjectUtil;
import org.jpmml.translator.Scope;
import org.jpmml.translator.TermFrequencyEncoder;
import org.jpmml.translator.TextIndexUtil;
import org.jpmml.translator.TranslationContext;
import org.jpmml.translator.ValueFactoryRef;
import org.jpmml.translator.ValueMapBuilder;
import org.jpmml.translator.tree.CountingActiveFieldFinder;
import org.jpmml.translator.tree.DiscreteValueFinder;
import org.jpmml.translator.tree.NodeScoreDistributionManager;
import org.jpmml.translator.tree.NodeScoreManager;

/**
 * Translates a PMML {@link TreeModel} into generated Java source code.
 *
 * <p>Two methods are generated per model:
 * <ul>
 *   <li>a node evaluator, which walks the tree and returns the index of the selected node's score in a
 *   shared score array, or {@link #NULL_RESULT} when there is no prediction;</li>
 *   <li>a model evaluator, which calls the node evaluator and turns that index into the final
 *   regression value or classification result.</li>
 * </ul>
 */
public class TreeModelTranslator
extends ModelTranslator<TreeModel> {

    /**
     * Score index (literal {@code -1}) that the generated node evaluator returns when the tree yields no
     * prediction. The generated model evaluator maps it to {@code null}.
     */
    public static final JExpression NULL_RESULT = JExpr.lit(-1);

    /**
     * Creates a translator and validates that the model only uses features this translator supports.
     *
     * @throws UnsupportedAttributeException if the missing value strategy is not {@code NONE} or
     *     {@code NULL_PREDICTION}, or if the no-true-child strategy is not {@code RETURN_LAST_PREDICTION}
     *     or {@code RETURN_NULL_PREDICTION}
     * @throws MissingElementException if the model has no root {@code Node}
     * @throws UnsupportedElementException if the root node's predicate is not {@code True}
     */
    public TreeModelTranslator(PMML pmml, TreeModel treeModel) {
        super(pmml, treeModel);

        TreeModel.MissingValueStrategy missingValueStrategy = treeModel.getMissingValueStrategy();
        switch (missingValueStrategy) {
            case NONE:
            case NULL_PREDICTION:
                break;
            default:
                throw new UnsupportedAttributeException(treeModel, missingValueStrategy);
        }

        TreeModel.NoTrueChildStrategy noTrueChildStrategy = treeModel.getNoTrueChildStrategy();
        switch (noTrueChildStrategy) {
            case RETURN_LAST_PREDICTION:
            case RETURN_NULL_PREDICTION:
                break;
            default:
                throw new UnsupportedAttributeException(treeModel, noTrueChildStrategy);
        }

        Node root = treeModel.getNode();
        if (root == null) {
            throw new MissingElementException(treeModel, PMMLElements.TREEMODEL_NODE);
        }

        Predicate predicate = root.getPredicate();
        if (!(predicate instanceof True)) {
            throw new UnsupportedElementException(predicate);
        }
    }

    /**
     * Generates the evaluator methods for a regression tree.
     *
     * <p>Node scores are collected into a {@link NodeScoreManager} array field on the owner class. The
     * returned method yields the {@link Number} at the selected index, or {@code null} when the node
     * evaluator returns {@link #NULL_RESULT}.
     *
     * @param context the translation context whose owner class receives the generated members
     * @return the generated model evaluator method
     */
    @Override
    public JMethod translateRegressor(TranslationContext context) {
        TreeModel treeModel = (TreeModel) getModel();

        Node node = treeModel.getNode();

        final JDefinedClass owner = context.getOwner();

        NodeScoreManager scoreManager = new NodeScoreManager(context.ref(Number.class), IdentifierUtil.create("scores", node)) {

            // Instance initializer: runs after the superclass constructor. It must NOT contain an explicit
            // super(...) call (illegal outside a constructor).
            {
                initArrayVar(owner);
                initArray();
            }
        };

        Map<FieldName, FieldInfo> fieldInfos = getFieldInfos(Collections.singleton(node));

        JMethod evaluateNodeMethod = createEvaluatorMethod(Integer.TYPE, node, false, context);

        // Push before entering try, so that a failing push never triggers an unmatched pop
        context.pushScope(new MethodScope(evaluateNodeMethod));
        try {
            translateNode(treeModel, node, scoreManager, fieldInfos, context);
        } finally {
            context.popScope();
        }

        JMethod evaluateTreeModelMethod = createEvaluatorMethod(Number.class, treeModel, false, context);

        context.pushScope(new MethodScope(evaluateTreeModelMethod));
        try {
            JVar indexVar = context.declare(Integer.TYPE, "index", createEvaluatorMethodInvocation(evaluateNodeMethod, context));

            context._returnIf(indexVar.eq(NULL_RESULT), JExpr._null());

            context._return(scoreManager.getComponent(indexVar));
        } finally {
            context.popScope();
        }

        return evaluateTreeModelMethod;
    }

    /**
     * Generates the evaluator methods for a classification tree.
     *
     * <p>Each node score is a {@code Number[]} of per-category values, managed by a
     * {@link NodeScoreDistributionManager}. The returned method wraps the selected score array in a
     * {@link ProbabilityDistribution}, or returns {@code null} when the node evaluator returns
     * {@link #NULL_RESULT}.
     *
     * @param context the translation context whose owner class receives the generated members
     * @return the generated model evaluator method
     */
    @Override
    public JMethod translateClassifier(TranslationContext context) {
        final TreeModel treeModel = (TreeModel) getModel();

        Node node = treeModel.getNode();

        Object[] categories = getTargetCategories();

        final JDefinedClass owner = context.getOwner();

        NodeScoreDistributionManager<Number> scoreManager = new NodeScoreDistributionManager<Number>(context.ref(Number[].class), IdentifierUtil.create("scores", node), categories) {

            private ValueFactory<Number> valueFactory;

            // The value factory is assigned before initArray(), which presumably consults getValueFactory()
            // (TODO confirm). No explicit super(...) call is allowed here.
            {
                this.valueFactory = ModelTranslator.getValueFactory(treeModel);

                initArrayVar(owner);
                initArray();
            }

            @Override
            public ValueFactory<Number> getValueFactory() {
                return this.valueFactory;
            }
        };

        Map<FieldName, FieldInfo> fieldInfos = getFieldInfos(Collections.singleton(node));

        JMethod evaluateNodeMethod = createEvaluatorMethod(Integer.TYPE, node, false, context);

        context.pushScope(new MethodScope(evaluateNodeMethod));
        try {
            translateNode(treeModel, node, scoreManager, fieldInfos, context);
        } finally {
            context.popScope();
        }

        JMethod evaluateTreeModelMethod = createEvaluatorMethod(Classification.class, treeModel, true, context);

        context.pushScope(new MethodScope(evaluateTreeModelMethod));
        try {
            JVar indexVar = context.declare(Integer.TYPE, "index", createEvaluatorMethodInvocation(evaluateNodeMethod, context));

            context._returnIf(indexVar.eq(NULL_RESULT), JExpr._null());

            JVar scoreVar = context.declare(Number[].class, "score", scoreManager.getComponent(indexVar));

            ValueMapBuilder valueMapBuilder = createScoreDistribution(categories, scoreVar, context);

            context._return(context._new(ProbabilityDistribution.class, valueMapBuilder));
        } finally {
            context.popScope();
        }

        return evaluateTreeModelMethod;
    }

    /**
     * Returns the superclass field infos, enriched by {@link #enhanceFieldInfos(Set, Map)}.
     */
    @Override
    public Map<FieldName, FieldInfo> getFieldInfos(Set<? extends PMMLObject> bodyObjects) {
        Map<FieldName, FieldInfo> fieldInfos = super.getFieldInfos(bodyObjects);

        return enhanceFieldInfos(bodyObjects, fieldInfos);
    }

    /**
     * Translates a tree rooted at {@code root} into the current scope of {@code context}.
     *
     * @param scoreManager collects node scores; it is also applied as a {@code Function<Node, S>}
     *     (see {@link #computeScore(ArrayManager, Node)})
     * @throws UnsupportedElementException if the root's predicate is not {@code True}
     */
    public static <S, ScoreManager extends ArrayManager<S>> void translateNode(TreeModel treeModel, Node root, ScoreManager scoreManager, Map<FieldName, FieldInfo> fieldInfos, TranslationContext context) {
        Predicate predicate = root.getPredicate();
        if (!(predicate instanceof True)) {
            throw new UnsupportedElementException(predicate);
        }

        translateNode(treeModel, null, root, scoreManager, fieldInfos, context);
    }

    /**
     * Recursively translates {@code node} and its children.
     *
     * <p>The node's predicate becomes an {@code if} branch. Children with a {@code False} predicate are
     * skipped. Translation of siblings stops after the first child with a {@code True} predicate, because
     * that child always returns. If no child returns, the node's block ends with a return chosen by the
     * model's no-true-child strategy. A leaf returns the index of its own score.
     *
     * @param parentNode the parent node; not used by this implementation
     * @throws MissingAttributeException if a leaf node has no score
     * @throws UnsupportedAttributeException if the no-true-child strategy is unsupported
     */
    public static <S, ScoreManager extends ArrayManager<S>> void translateNode(TreeModel treeModel, Node parentNode, Node node, ScoreManager scoreManager, Map<FieldName, FieldInfo> fieldInfos, TranslationContext context) {
        S score = computeScore(scoreManager, node);

        Predicate predicate = node.getPredicate();

        Scope nodeScope = translatePredicate(treeModel, predicate, fieldInfos, context);

        JExpression scoreExpr;

        if (node.hasNodes()) {
            List<Node> children = node.getNodes();

            context.pushScope(nodeScope);
            try {
                for (Node child : children) {
                    Predicate childPredicate = child.getPredicate();

                    if (childPredicate instanceof False) {
                        continue;
                    }

                    translateNode(treeModel, node, child, scoreManager, fieldInfos, context);

                    // An unconditional child always returns, so later siblings (and a fallback return) are unreachable
                    if (childPredicate instanceof True) {
                        return;
                    }
                }
            } finally {
                context.popScope();
            }

            TreeModel.NoTrueChildStrategy noTrueChildStrategy = treeModel.getNoTrueChildStrategy();
            switch (noTrueChildStrategy) {
                case RETURN_NULL_PREDICTION:
                    scoreExpr = NULL_RESULT;
                    break;
                case RETURN_LAST_PREDICTION:
                    // A non-leaf node without its own score has no "last prediction" to fall back to
                    scoreExpr = (score != null) ? JExpr.lit(scoreManager.getOrInsert(score)) : NULL_RESULT;
                    break;
                default:
                    throw new UnsupportedAttributeException(treeModel, noTrueChildStrategy);
            }
        } else {
            if (score == null) {
                throw new MissingAttributeException(node, PMMLAttributes.COMPLEXNODE_SCORE);
            }

            scoreExpr = JExpr.lit(scoreManager.getOrInsert(score));
        }

        JBlock nodeBlock = nodeScope.getBlock();

        nodeBlock._return(scoreExpr);
    }

    /**
     * Emits an {@code if} statement for {@code predicate} into the current block of {@code context}.
     *
     * <p>Supported predicates are {@code SimplePredicate}, {@code SimpleSetPredicate}, {@code True} and
     * {@code False}. For value comparisons, the model's missing value strategy applies:
     * <ul>
     *   <li>{@code NONE}: for reference-typed variables, a not-missing check is conjoined into the
     *   condition.</li>
     *   <li>{@code NULL_PREDICTION}: an early {@code return NULL_RESULT} is emitted when the value is
     *   missing.</li>
     * </ul>
     *
     * @return a scope wrapping the {@code then} block of the emitted {@code if}
     * @throws UnsupportedElementException if the predicate type is unsupported
     * @throws UnsupportedAttributeException if the operator or missing value strategy is unsupported
     */
    public static Scope translatePredicate(TreeModel treeModel, Predicate predicate, Map<FieldName, FieldInfo> fieldInfos, TranslationContext context) {
        JBlock block = context.block();

        OperableRef operableRef;
        JExpression valueExpr;

        if (predicate instanceof SimplePredicate) {
            SimplePredicate simplePredicate = (SimplePredicate) predicate;

            FieldInfo fieldInfo = getFieldInfo(simplePredicate, fieldInfos);

            operableRef = context.ensureOperableVariable(fieldInfo);

            SimplePredicate.Operator operator = simplePredicate.getOperator();

            // Missing-value tests return directly and bypass the missing value strategy handling below
            switch (operator) {
                case IS_MISSING:
                    return createBranch(block, operableRef.isMissing());
                case IS_NOT_MISSING:
                    return createBranch(block, operableRef.isNotMissing());
                default:
                    break;
            }

            Object value = simplePredicate.getValue();

            switch (operator) {
                case EQUAL:
                    valueExpr = operableRef.equalTo(value, context);
                    break;
                case NOT_EQUAL:
                    valueExpr = operableRef.notEqualTo(value, context);
                    break;
                case LESS_THAN:
                    valueExpr = operableRef.lessThan(value, context);
                    break;
                case LESS_OR_EQUAL:
                    valueExpr = operableRef.lessOrEqual(value, context);
                    break;
                case GREATER_OR_EQUAL:
                    valueExpr = operableRef.greaterOrEqual(value, context);
                    break;
                case GREATER_THAN:
                    valueExpr = operableRef.greaterThan(value, context);
                    break;
                default:
                    throw new UnsupportedAttributeException(predicate, operator);
            }
        } else if (predicate instanceof SimpleSetPredicate) {
            SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate) predicate;

            FieldInfo fieldInfo = getFieldInfo(simpleSetPredicate, fieldInfos);

            operableRef = context.ensureOperableVariable(fieldInfo);

            // NOTE(review): assumes the array has been converted to a ComplexArray upstream; any other Array
            // implementation fails here with a ClassCastException
            ComplexArray complexArray = (ComplexArray) simpleSetPredicate.getArray();

            Collection<?> values = complexArray.getValue();

            SimpleSetPredicate.BooleanOperator booleanOperator = simpleSetPredicate.getBooleanOperator();
            switch (booleanOperator) {
                case IS_IN:
                    valueExpr = operableRef.isIn(values, context);
                    break;
                case IS_NOT_IN:
                    valueExpr = operableRef.isNotIn(values, context);
                    break;
                default:
                    throw new UnsupportedAttributeException(predicate, booleanOperator);
            }
        } else if (predicate instanceof True) {
            return createBranch(block, JExpr.TRUE);
        } else if (predicate instanceof False) {
            return createBranch(block, JExpr.FALSE);
        } else {
            throw new UnsupportedElementException(predicate);
        }

        JVar variable = operableRef.getVariable();

        TreeModel.MissingValueStrategy missingValueStrategy = treeModel.getMissingValueStrategy();
        switch (missingValueStrategy) {
            case NONE: {
                boolean isNonMissing = context.isNonMissing(variable);

                // Only reference-typed variables can hold a missing (null) value
                if (!isNonMissing && operableRef.type().isReference()) {
                    valueExpr = operableRef.isNotMissing().cand(valueExpr);
                }

                Scope result = createBranch(block, valueExpr);

                // Inside the branch the guard has passed, so the variable is known to be non-missing there
                if (!isNonMissing) {
                    result.markNonMissing(variable);
                }

                return result;
            }
            case NULL_PREDICTION: {
                // Emit the missing check once, then record the variable so later predicates skip it
                if (!context.isNonMissing(variable)) {
                    context._returnIf(operableRef.isMissing(), NULL_RESULT);

                    context.markNonMissing(variable);
                }

                return createBranch(block, valueExpr);
            }
            default:
                throw new UnsupportedAttributeException(treeModel, missingValueStrategy);
        }
    }

    /**
     * Enriches {@code fieldInfos} with usage counts and value encoders, based on the given tree nodes.
     *
     * <p>Each body object is cast to {@link Node}. For every field:
     * <ul>
     *   <li>The active-field usage count is recorded.</li>
     *   <li>Continuous integer, float and double fields get an {@link FpPrimitiveEncoder}. Integer fields
     *   only keep it when it is a {@link TermFrequencyEncoder}; otherwise their encoder is set to
     *   {@code null}.</li>
     *   <li>Categorical fields with known discrete values get an {@link OrdinalEncoder}.</li>
     * </ul>
     * Term frequency encoders that share a text field are assigned indices into a common vocabulary list,
     * which is then attached to every such encoder.
     *
     * @return the same {@code fieldInfos} map, updated in place
     */
    public static Map<FieldName, FieldInfo> enhanceFieldInfos(Set<? extends PMMLObject> bodyObjects, Map<FieldName, FieldInfo> fieldInfos) {
        CountingActiveFieldFinder countingActiveFieldFinder = new CountingActiveFieldFinder();
        DiscreteValueFinder discreteValueFinder = new DiscreteValueFinder();

        for (PMMLObject bodyObject : bodyObjects) {
            Node node = (Node) bodyObject;

            countingActiveFieldFinder.applyTo(node);
            discreteValueFinder.applyTo((Visitable) node);
        }

        Map<FieldName, Set<Object>> discreteFieldValues = discreteValueFinder.getFieldValues();

        // Text field -> distinct term token lists, in first-seen order; the position is the encoder index
        ArrayListMultimap<FieldName, List<String>> textFieldVocabularies = ArrayListMultimap.create();

        Set<Map.Entry<FieldName, FieldInfo>> entries = fieldInfos.entrySet();

        for (Map.Entry<FieldName, FieldInfo> entry : entries) {
            FieldName name = entry.getKey();
            FieldInfo fieldInfo = entry.getValue();

            Field<?> field = fieldInfo.getField();

            OpType opType = field.getOpType();
            DataType dataType = field.getDataType();

            fieldInfo.updateCount(countingActiveFieldFinder.getCount(name));

            switch (opType) {
                case CONTINUOUS:
                    switch (dataType) {
                        case INTEGER:
                        case FLOAT:
                        case DOUBLE: {
                            FpPrimitiveEncoder encoder = FpPrimitiveEncoder.create(fieldInfo);

                            if (encoder instanceof TermFrequencyEncoder) {
                                TermFrequencyEncoder termFrequencyEncoder = (TermFrequencyEncoder) encoder;

                                FunctionInvocation.Tf tf = termFrequencyEncoder.getTf(fieldInfo);

                                List<List<String>> tokens = textFieldVocabularies.get(tf.getTextField());

                                int index = tokens.indexOf(tf.getTermTokens());
                                if (index < 0) {
                                    index = tokens.size();

                                    tokens.add(tf.getTermTokens());
                                }

                                termFrequencyEncoder.setIndex(index);
                            } else if (dataType == DataType.INTEGER) {
                                // Plain integer fields are left without an encoder
                                encoder = null;
                            }

                            fieldInfo.setEncoder(encoder);
                            break;
                        }
                        default:
                            break;
                    }
                    break;
                case CATEGORICAL: {
                    Set<Object> values = discreteFieldValues.get(name);

                    if (values != null && !values.isEmpty()) {
                        OrdinalEncoder encoder = OrdinalEncoder.create(fieldInfo, values);

                        fieldInfo.setEncoder(encoder);
                    }
                    break;
                }
                default:
                    break;
            }
        }

        // Second pass: every term frequency encoder sees the complete vocabulary of its text field
        for (Map.Entry<FieldName, FieldInfo> entry : entries) {
            FieldInfo fieldInfo = entry.getValue();

            Encoder encoder = fieldInfo.getEncoder();
            if (!(encoder instanceof TermFrequencyEncoder)) {
                continue;
            }

            TermFrequencyEncoder termFrequencyEncoder = (TermFrequencyEncoder) encoder;

            FunctionInvocation.Tf tf = termFrequencyEncoder.getTf(fieldInfo);

            termFrequencyEncoder.setVocabulary(textFieldVocabularies.get(tf.getTextField()));
        }

        return fieldInfos;
    }

    /**
     * Declares the text index support members on the owner class, once per text index and text field
     * pair.
     *
     * <p>The members are a {@code private static final TextIndex} field initialized from a localized copy
     * of the text index, and the encoder's vocabulary registered as string lists with a
     * {@link JBinaryFileInitializer}. If a field with the derived name already exists, nothing is
     * generated.
     */
    public static void ensureTextIndexFields(FieldInfo fieldInfo, TermFrequencyEncoder encoder, TranslationContext context) {
        JDefinedClass owner = context.getOwner();

        FunctionInvocation.Tf tf = encoder.getTf(fieldInfo);

        TextIndex textIndex = tf.getTextIndex();
        FieldName name = tf.getTextField();

        String textIndexName = IdentifierUtil.create("textIndex", textIndex, name);

        JFieldVar textIndexVar = owner.fields().get(textIndexName);
        if (textIndexVar != null) {
            return;
        }

        JBinaryFileInitializer resourceInitializer = new JBinaryFileInitializer(IdentifierUtil.create(TextIndex.class.getSimpleName(), textIndex) + ".data", context);

        TextIndex localTextIndex = TextIndexUtil.toLocalTextIndex(textIndex, name);

        owner.field(JMod.PRIVATE | JMod.STATIC | JMod.FINAL, context.ref(TextIndex.class), textIndexName, PMMLObjectUtil.createObject(localTextIndex, context));

        @SuppressWarnings("unchecked")
        List<String>[] terms = encoder.getVocabulary().stream().toArray(List[]::new);

        // Called for its side effect of registering the term lists with the resource initializer
        resourceInitializer.initStringLists(IdentifierUtil.create("terms", textIndex, name), terms);
    }

    /**
     * Applies {@code scoreManager} as a {@code Function<Node, S>} to {@code node}.
     *
     * <p>NOTE(review): score managers passed to this class are expected to also implement
     * {@link Function}; any other {@link ArrayManager} fails here with a ClassCastException.
     */
    @SuppressWarnings("unchecked")
    private static <S> S computeScore(ArrayManager<S> scoreManager, Node node) {
        Function<Node, S> scoreFunction = (Function<Node, S>) scoreManager;

        return scoreFunction.apply(node);
    }

    /**
     * Builds a value map that pairs each category, in order, with the matching element of
     * {@code scoreVar}, wrapped by the context's value factory.
     */
    private static ValueMapBuilder createScoreDistribution(Object[] categories, JVar scoreVar, TranslationContext context) {
        ValueMapBuilder valueMapBuilder = new ValueMapBuilder(context).construct("values");

        ValueFactoryRef valueFactoryRef = context.getValueFactoryVariable();

        for (int i = 0; i < categories.length; i++) {
            JInvocation valueExpr = valueFactoryRef.newValue(scoreVar.component(JExpr.lit(i)));

            valueMapBuilder.update("put", categories[i], valueExpr);
        }

        return valueMapBuilder;
    }

    /**
     * Appends {@code if (testExpr)} to {@code block} and returns a scope over its {@code then} block.
     */
    private static Scope createBranch(JBlock block, JExpression testExpr) {
        JBlock thenBlock = block._if(testExpr)._then();

        return new Scope(thenBlock);
    }
}

