/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.regression;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.dmg.pmml.DataField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasValue;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.TypeDefinitionField;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.regression.NumericPredictor;
import org.dmg.pmml.regression.PredictorTerm;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.InvalidFeatureException;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedFeatureException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.regression.RegressionModelUtil;

/**
 * Model evaluator for PMML {@code RegressionModel} elements.
 *
 * <p>Supports the {@code regression} mining function (exactly one regression table) and the
 * {@code classification} mining function (two or more regression tables, one per target
 * category). Only the {@code float} and {@code double} math contexts are accepted.
 */
public class RegressionModelEvaluator extends ModelEvaluator<RegressionModel> {

    /**
     * Creates an evaluator for the RegressionModel that {@code selectModel} picks from the
     * given PMML document.
     *
     * @param pmml the PMML document
     * @throws InvalidFeatureException if the selected model has no regression tables
     */
    public RegressionModelEvaluator(PMML pmml) {
        this(pmml, selectModel(pmml, RegressionModel.class));
    }

    /**
     * Creates an evaluator for the given RegressionModel.
     *
     * @param pmml the enclosing PMML document
     * @param regressionModel the model to evaluate
     * @throws InvalidFeatureException if the model has no regression tables
     */
    public RegressionModelEvaluator(PMML pmml, RegressionModel regressionModel) {
        super(pmml, regressionModel);

        if (!regressionModel.hasRegressionTables()) {
            throw new InvalidFeatureException(regressionModel);
        }
    }

    @Override
    public String getSummary() {
        return "Regression";
    }

    /**
     * Evaluates the model against the given context and post-processes the predictions
     * through {@link OutputUtil#evaluate}.
     *
     * @param context the evaluation context supplying input field values
     * @return the predictions as produced by {@code OutputUtil.evaluate}
     * @throws InvalidResultException if the model is marked as not scorable
     * @throws UnsupportedFeatureException if the math context is neither FLOAT nor DOUBLE,
     *         or the mining function is neither REGRESSION nor CLASSIFICATION
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        RegressionModel regressionModel = getModel();

        if (!regressionModel.isScorable()) {
            throw new InvalidResultException(regressionModel);
        }

        ValueFactory<?> valueFactory;

        MathContext mathContext = regressionModel.getMathContext();
        switch (mathContext) {
            case FLOAT:
            case DOUBLE:
                valueFactory = getValueFactory();
                break;
            default:
                throw new UnsupportedFeatureException(regressionModel, mathContext);
        }

        Map<FieldName, ?> predictions;

        MiningFunction miningFunction = regressionModel.getMiningFunction();
        switch (miningFunction) {
            case REGRESSION:
                predictions = evaluateRegression(valueFactory, context);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(valueFactory, context);
                break;
            default:
                throw new UnsupportedFeatureException(regressionModel, miningFunction);
        }

        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Returns this evaluator's target field after checking that it agrees with the optional
     * {@code targetFieldName} attribute of the RegressionModel element.
     *
     * @throws InvalidFeatureException if {@code targetFieldName} is set and names a different field
     */
    private TargetField getValidatedTargetField(RegressionModel regressionModel) {
        TargetField targetField = getTargetField();

        FieldName targetFieldName = regressionModel.getTargetFieldName();
        if (targetFieldName != null && !Objects.equals(targetField.getName(), targetFieldName)) {
            throw new InvalidFeatureException(regressionModel);
        }

        return targetField;
    }

    /**
     * Scores a regression model: evaluates its single regression table and applies the
     * model's normalization method to the raw result.
     *
     * @throws InvalidFeatureException if the model does not have exactly one regression table,
     *         the target field name is inconsistent, or the normalization method is SIMPLEMAX
     * @throws UnsupportedFeatureException for any other unrecognized normalization method
     */
    private <V extends Number> Map<FieldName, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext context) {
        RegressionModel regressionModel = getModel();

        TargetField targetField = getValidatedTargetField(regressionModel);

        List<RegressionTable> regressionTables = regressionModel.getRegressionTables();
        if (regressionTables.size() != 1) {
            throw new InvalidFeatureException(regressionModel);
        }

        RegressionTable regressionTable = regressionTables.get(0);

        Value<V> result = evaluateRegressionTable(valueFactory, regressionTable, context);
        // A null result means some required input value was missing
        if (result == null) {
            return TargetUtil.evaluateRegressionDefault(valueFactory, targetField);
        }

        RegressionModel.NormalizationMethod normalizationMethod = regressionModel.getNormalizationMethod();
        switch (normalizationMethod) {
            case NONE:
            case SOFTMAX:
            case LOGIT:
            case EXP:
            case PROBIT:
            case CLOGLOG:
            case LOGLOG:
            case CAUCHIT:
                RegressionModelUtil.normalizeRegressionResult(result, normalizationMethod);
                break;
            case SIMPLEMAX:
                // SIMPLEMAX is only meaningful across several outputs, i.e. for classification
                throw new InvalidFeatureException(regressionModel);
            default:
                throw new UnsupportedFeatureException(regressionModel, normalizationMethod);
        }

        return TargetUtil.evaluateRegression(targetField, result);
    }

    /**
     * Scores a classification model: evaluates one regression table per target category and
     * converts the raw scores into a probability distribution according to the model's
     * normalization method and the target field's operational type.
     *
     * @throws InvalidFeatureException if the target field is continuous, there are fewer than two
     *         regression tables, the table count differs from a non-empty list of declared target
     *         categories, a table lacks a target category or uses an undeclared or duplicate one,
     *         or the normalization method is not applicable to the target's operational type
     * @throws UnsupportedFeatureException for an unrecognized operational type or normalization method
     */
    private <V extends Number> Map<FieldName, ? extends Classification<V>> evaluateClassification(ValueFactory<V> valueFactory, EvaluationContext context) {
        RegressionModel regressionModel = getModel();

        TargetField targetField = getValidatedTargetField(regressionModel);

        DataField dataField = targetField.getDataField();

        OpType opType = dataField.getOpType();
        switch (opType) {
            case CONTINUOUS:
                throw new InvalidFeatureException(dataField);
            case CATEGORICAL:
            case ORDINAL:
                break;
            default:
                throw new UnsupportedFeatureException(dataField, opType);
        }

        List<RegressionTable> regressionTables = regressionModel.getRegressionTables();
        if (regressionTables.size() < 2) {
            throw new InvalidFeatureException(regressionModel);
        }

        // Category checks are only enforced when the target field declares categories
        List<String> targetCategories = FieldValueUtil.getTargetCategories((TypeDefinitionField) dataField);
        if (!targetCategories.isEmpty() && targetCategories.size() != regressionTables.size()) {
            throw new InvalidFeatureException(regressionModel);
        }

        ValueMap<String, V> values = new ValueMap<>(2 * regressionTables.size());

        for (RegressionTable regressionTable : regressionTables) {
            String targetCategory = regressionTable.getTargetCategory();
            if (targetCategory == null) {
                throw new InvalidFeatureException(regressionTable);
            }

            if (!targetCategories.isEmpty() && !targetCategories.contains(targetCategory)) {
                throw new InvalidFeatureException(regressionTable);
            }

            // A repeated category would silently overwrite an earlier table's score and shrink
            // the distribution, which also skews the binomial/multinomial dispatch below
            if (values.containsKey(targetCategory)) {
                throw new InvalidFeatureException(regressionTable);
            }

            Value<V> value = evaluateRegressionTable(valueFactory, regressionTable, context);
            // A null score means some required input value was missing
            if (value == null) {
                return TargetUtil.evaluateClassificationDefault(valueFactory, targetField);
            }

            values.put(targetCategory, value);
        }

        boolean categorical = (opType == OpType.CATEGORICAL);

        RegressionModel.NormalizationMethod normalizationMethod = regressionModel.getNormalizationMethod();
        switch (normalizationMethod) {
            case NONE:
                if (!categorical) {
                    RegressionModelUtil.computeOrdinalProbabilities(values, normalizationMethod);
                } else if (values.size() == 2) {
                    RegressionModelUtil.computeBinomialProbabilities(values, normalizationMethod);
                } else {
                    RegressionModelUtil.computeMultinomialProbabilities(values, normalizationMethod);
                }
                break;
            case SOFTMAX:
            case SIMPLEMAX:
                if (!categorical) {
                    throw new InvalidFeatureException(regressionModel);
                }
                // The second table always scores 0 (see isDefault), and softmax over (z, 0)
                // coincides with the logistic function of z, so LOGIT is used instead
                if (normalizationMethod == RegressionModel.NormalizationMethod.SOFTMAX
                        && values.size() == 2
                        && isDefault(regressionTables.get(1))) {
                    RegressionModelUtil.computeBinomialProbabilities(values, RegressionModel.NormalizationMethod.LOGIT);
                } else {
                    RegressionModelUtil.computeMultinomialProbabilities(values, normalizationMethod);
                }
                break;
            case LOGIT:
            case PROBIT:
            case CLOGLOG:
            case LOGLOG:
            case CAUCHIT:
                if (!categorical) {
                    RegressionModelUtil.computeOrdinalProbabilities(values, normalizationMethod);
                } else if (values.size() == 2) {
                    RegressionModelUtil.computeBinomialProbabilities(values, normalizationMethod);
                } else if (normalizationMethod == RegressionModel.NormalizationMethod.LOGIT) {
                    // values.size() > 2 here; only LOGIT extends to the multinomial case
                    RegressionModelUtil.computeMultinomialProbabilities(values, normalizationMethod);
                } else {
                    throw new InvalidFeatureException(regressionModel);
                }
                break;
            case EXP:
                throw new InvalidFeatureException(regressionModel);
            default:
                throw new UnsupportedFeatureException(regressionModel, normalizationMethod);
        }

        ProbabilityDistribution<V> result = new ProbabilityDistribution<>(values);

        return TargetUtil.evaluateClassification(targetField, result);
    }

    /**
     * Computes the raw score of a single regression table: the sum of numeric predictor terms,
     * matching categorical predictor coefficients, predictor-term products and the intercept.
     *
     * @return the score, or {@code null} if a numeric predictor's or predictor term's input
     *         value is missing (a missing categorical input merely contributes nothing)
     */
    private <V extends Number> Value<V> evaluateRegressionTable(ValueFactory<V> valueFactory, RegressionTable regressionTable, EvaluationContext context) {
        Value<V> result = valueFactory.newValue();

        if (regressionTable.hasNumericPredictors()) {
            List<NumericPredictor> numericPredictors = regressionTable.getNumericPredictors();

            for (NumericPredictor numericPredictor : numericPredictors) {
                FieldValue value = context.evaluate(numericPredictor.getName());
                if (value == null) {
                    return null;
                }

                int exponent = numericPredictor.getExponent();
                if (exponent != 1) {
                    result.add(numericPredictor.getCoefficient(), value.asNumber(), exponent);
                } else {
                    result.add(numericPredictor.getCoefficient(), value.asNumber());
                }
            }
        }

        if (regressionTable.hasCategoricalPredictors()) {
            // Name of the field whose remaining (consecutive) predictors should be skipped,
            // either because one of them already matched or because its value is missing
            FieldName matchedName = null;

            List<CategoricalPredictor> categoricalPredictors = regressionTable.getCategoricalPredictors();

            for (CategoricalPredictor categoricalPredictor : categoricalPredictors) {
                FieldName name = categoricalPredictor.getName();

                if (matchedName != null) {
                    if (matchedName.equals(name)) {
                        continue;
                    }

                    matchedName = null;
                }

                FieldValue value = context.evaluate(name);
                // A missing categorical input contributes nothing; skip this field's other predictors
                if (value == null) {
                    matchedName = name;

                    continue;
                }

                if (value.equals((HasValue<?>) categoricalPredictor)) {
                    matchedName = name;

                    result.add(categoricalPredictor.getCoefficient());
                }
            }
        }

        if (regressionTable.hasPredictorTerms()) {
            // Reused across terms to avoid reallocating
            List<Number> factors = new ArrayList<>();

            List<PredictorTerm> predictorTerms = regressionTable.getPredictorTerms();

            for (PredictorTerm predictorTerm : predictorTerms) {
                factors.clear();

                List<FieldRef> fieldRefs = predictorTerm.getFieldRefs();

                for (FieldRef fieldRef : fieldRefs) {
                    FieldValue value = ExpressionUtil.evaluate((Expression) fieldRef, context);
                    if (value == null) {
                        return null;
                    }

                    factors.add(value.asNumber());
                }

                result.add(predictorTerm.getCoefficient(), factors);
            }
        }

        result.add(regressionTable.getIntercept());

        return result;
    }

    /**
     * Returns {@code true} if the table has no extensions, no predictors of any kind and a zero
     * intercept. Such a table always scores exactly 0.
     */
    private static boolean isDefault(RegressionTable regressionTable) {
        if (regressionTable.hasExtensions()) {
            return false;
        }

        if (regressionTable.hasNumericPredictors() || regressionTable.hasCategoricalPredictors() || regressionTable.hasPredictorTerms()) {
            return false;
        }

        return regressionTable.getIntercept() == 0.0;
    }
}

