/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.dmg.pmml.CategoricalPredictor;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NumericPredictor;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.PredictorTerm;
import org.dmg.pmml.RegressionModel;
import org.dmg.pmml.RegressionNormalizationMethodType;
import org.dmg.pmml.RegressionTable;
import org.dmg.pmml.TypeDefinitionField;
import org.jpmml.evaluator.ArgumentUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InvalidFeatureException;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.NormalDistributionUtil;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedFeatureException;

/**
 * Evaluator for PMML {@code RegressionModel} elements.
 *
 * <p>Supports the {@code regression} mining function (exactly one {@code RegressionTable}) and the
 * {@code classification} mining function (one {@code RegressionTable} per target category, with a
 * categorical or ordinal target field). Any other mining function is rejected with
 * {@link UnsupportedFeatureException}.
 */
public class RegressionModelEvaluator
extends ModelEvaluator<RegressionModel> {
    public RegressionModelEvaluator(PMML pmml) {
        super(pmml, RegressionModel.class);
    }

    public RegressionModelEvaluator(PMML pmml, RegressionModel regressionModel) {
        super(pmml, regressionModel);
    }

    @Override
    public String getSummary() {
        return "Regression";
    }

    /**
     * Evaluates the model against the given context and post-processes the raw predictions
     * through {@link OutputUtil#evaluate}.
     *
     * @throws InvalidResultException if the model is marked as not scorable
     * @throws UnsupportedFeatureException if the mining function is neither regression nor
     *     classification
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        RegressionModel regressionModel = getModel();
        if (!regressionModel.isScorable()) {
            throw new InvalidResultException(regressionModel);
        }

        // Wildcard-typed: the two branches return maps with different value types.
        Map<FieldName, ?> predictions;

        MiningFunctionType miningFunction = regressionModel.getFunctionName();
        switch (miningFunction) {
            case REGRESSION:
                predictions = evaluateRegression(context);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(context);
                break;
            default:
                throw new UnsupportedFeatureException(regressionModel, miningFunction);
        }

        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Returns the model's explicit target field name, falling back to the evaluator's default
     * target field when the model does not declare one.
     */
    private FieldName resolveTargetField(RegressionModel regressionModel) {
        FieldName targetField = regressionModel.getTargetFieldName();
        if (targetField == null) {
            targetField = getTargetField();
        }
        return targetField;
    }

    /**
     * Scores a regression model: evaluates its single regression table and applies the
     * model-level normalization method.
     *
     * <p>If the table cannot be evaluated because of a missing input, the target's default
     * regression result is returned instead.
     *
     * @throws InvalidFeatureException if the model does not contain exactly one regression table
     */
    private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext context) {
        RegressionModel regressionModel = getModel();
        FieldName targetField = resolveTargetField(regressionModel);

        List<RegressionTable> regressionTables = regressionModel.getRegressionTables();
        if (regressionTables.size() != 1) {
            throw new InvalidFeatureException(regressionModel);
        }

        RegressionTable regressionTable = regressionTables.get(0);

        Double result = evaluateRegressionTable(regressionTable, context);
        if (result == null) {
            return TargetUtil.evaluateRegressionDefault(context);
        }

        result = normalizeRegressionResult(result);

        return TargetUtil.evaluateRegression(Collections.singletonMap(targetField, result), context);
    }

    /**
     * Scores a classification model: evaluates one regression table per target category, then
     * converts the per-category scores into a {@link ProbabilityDistribution}.
     *
     * <p>If any table cannot be evaluated because of a missing input, the target's default
     * classification result is returned instead.
     *
     * @throws InvalidFeatureException if the target field is unknown or continuous, if there are no
     *     regression tables, if the number of tables disagrees with the declared target categories,
     *     or if a table lacks a target category
     * @throws UnsupportedFeatureException if the target field has an unsupported operational type
     */
    private Map<FieldName, ? extends Classification> evaluateClassification(ModelEvaluationContext context) {
        RegressionModel regressionModel = getModel();
        FieldName targetField = resolveTargetField(regressionModel);

        DataField dataField = getDataField(targetField);
        if (dataField == null) {
            // The target refers to a field that is not declared; report it as a model defect
            // instead of failing with a NullPointerException below.
            throw new InvalidFeatureException(regressionModel);
        }

        OpType opType = dataField.getOpType();
        switch (opType) {
            case CONTINUOUS:
                throw new InvalidFeatureException(dataField);
            case CATEGORICAL:
            case ORDINAL:
                break;
            default:
                throw new UnsupportedFeatureException(dataField, opType);
        }

        List<RegressionTable> regressionTables = regressionModel.getRegressionTables();
        if (regressionTables.size() < 1) {
            throw new InvalidFeatureException(regressionModel);
        }

        List<String> targetCategories = ArgumentUtil.getTargetCategories(dataField);
        // An empty category list means the data field does not enumerate its values; only
        // cross-check the table count when categories are declared.
        if (targetCategories.size() > 0 && targetCategories.size() != regressionTables.size()) {
            throw new InvalidFeatureException(dataField);
        }

        // Insertion order follows the order of the regression tables.
        Map<String, Double> values = new LinkedHashMap<String, Double>();

        for (RegressionTable regressionTable : regressionTables) {
            String targetCategory = regressionTable.getTargetCategory();
            if (targetCategory == null) {
                throw new InvalidFeatureException(regressionTable);
            }

            Double value = evaluateRegressionTable(regressionTable, context);
            if (value == null) {
                return TargetUtil.evaluateClassificationDefault(context);
            }

            values.put(targetCategory, value);
        }

        switch (opType) {
            case CATEGORICAL:
                computeCategoricalProbabilities(values);
                break;
            case ORDINAL:
                computeOrdinalProbabilities(values, targetCategories);
                break;
            default:
                // Unreachable: the operational type was validated above.
                throw new UnsupportedFeatureException(dataField, opType);
        }

        ProbabilityDistribution result = new ProbabilityDistribution();
        result.putAll(values);

        return TargetUtil.evaluateClassification(Collections.singletonMap(targetField, result), context);
    }

    /**
     * Computes the linear score of a single regression table: intercept, plus numeric
     * predictors ({@code coefficient * x^exponent}), plus categorical predictor indicators, plus
     * predictor-term products.
     *
     * @return the score, or {@code null} if a numeric predictor or a predictor-term field has no
     *     value (a warning is recorded in the context first); a missing categorical predictor
     *     value is recorded as a warning and contributes nothing to the score
     * @throws InvalidFeatureException if a predictor term has no field references
     */
    private Double evaluateRegressionTable(RegressionTable regressionTable, EvaluationContext context) {
        double result = regressionTable.getIntercept();

        List<NumericPredictor> numericPredictors = regressionTable.getNumericPredictors();
        for (NumericPredictor numericPredictor : numericPredictors) {
            FieldName name = numericPredictor.getName();

            FieldValue value = ExpressionUtil.evaluate(name, context);
            if (value == null) {
                context.addWarning("Missing argument \"" + name.getValue() + "\"");
                return null;
            }

            result += numericPredictor.getCoefficient() * Math.pow(value.asNumber().doubleValue(), numericPredictor.getExponent());
        }

        List<CategoricalPredictor> categoricalPredictors = regressionTable.getCategoricalPredictors();
        for (CategoricalPredictor categoricalPredictor : categoricalPredictors) {
            FieldName name = categoricalPredictor.getName();

            FieldValue value = ExpressionUtil.evaluate(name, context);
            if (value == null) {
                // Unlike numeric predictors, a missing categorical value is skipped, not fatal.
                context.addWarning("Missing argument \"" + name.getValue() + "\"");
                continue;
            }

            boolean equals = value.equalsString(categoricalPredictor.getValue());

            result += categoricalPredictor.getCoefficient() * (equals ? 1.0 : 0.0);
        }

        List<PredictorTerm> predictorTerms = regressionTable.getPredictorTerms();
        for (PredictorTerm predictorTerm : predictorTerms) {
            double product = predictorTerm.getCoefficient();

            List<FieldRef> fieldRefs = predictorTerm.getFieldRefs();
            if (fieldRefs.size() < 1) {
                throw new InvalidFeatureException(predictorTerm);
            }

            for (FieldRef fieldRef : fieldRefs) {
                FieldValue value = ExpressionUtil.evaluateFieldRef(fieldRef, context);
                if (value == null) {
                    // Report the missing input the same way the numeric predictor path does.
                    context.addWarning("Missing argument \"" + fieldRef.getField().getValue() + "\"");
                    return null;
                }

                product *= value.asNumber().doubleValue();
            }

            result += product;
        }

        return result;
    }

    /**
     * Applies the model's normalization method to a regression score.
     *
     * <p>{@code none} returns the score unchanged; {@code softmax} and {@code logit} both apply
     * the logistic function {@code 1 / (1 + e^-y)}; {@code exp} returns {@code e^y}.
     *
     * @throws UnsupportedFeatureException for any other normalization method
     */
    private Double normalizeRegressionResult(Double value) {
        RegressionModel regressionModel = getModel();

        RegressionNormalizationMethodType regressionNormalizationMethod = regressionModel.getNormalizationMethod();
        switch (regressionNormalizationMethod) {
            case NONE:
                return value;
            case SOFTMAX:
            case LOGIT:
                return 1.0 / (1.0 + Math.exp(-value.doubleValue()));
            case EXP:
                return Math.exp(value);
            default:
                throw new UnsupportedFeatureException(regressionModel, regressionNormalizationMethod);
        }
    }

    /**
     * Converts per-category scores of a categorical target into probabilities, in place.
     *
     * <p>{@code none} leaves the scores untouched; {@code simplemax} delegates to
     * {@link Classification#normalize}; {@code softmax} delegates to
     * {@link Classification#normalizeSoftMax}. Any other method first maps every score through
     * {@link #normalizeClassificationResult} and then applies {@link Classification#normalize}.
     */
    private void computeCategoricalProbabilities(Map<String, Double> values) {
        RegressionModel regressionModel = getModel();

        RegressionNormalizationMethodType regressionNormalizationMethod = regressionModel.getNormalizationMethod();
        switch (regressionNormalizationMethod) {
            case NONE:
                return;
            case SIMPLEMAX:
                Classification.normalize(values);
                return;
            case SOFTMAX:
                Classification.normalizeSoftMax(values);
                return;
            default:
                break;
        }

        for (Map.Entry<String, Double> entry : values.entrySet()) {
            entry.setValue(normalizeClassificationResult(entry.getValue()));
        }

        Classification.normalize(values);
    }

    /**
     * Converts per-category scores of an ordinal target into probabilities, in place.
     *
     * <p>{@code none} leaves the scores untouched. Otherwise every score is mapped through
     * {@link #normalizeClassificationResult}, treated as a cumulative probability, and converted to
     * per-category probabilities by {@link #calculateCategoryProbabilities}.
     *
     * @throws UnsupportedFeatureException for {@code softmax} and {@code simplemax}
     */
    private void computeOrdinalProbabilities(Map<String, Double> values, List<String> targetCategories) {
        RegressionModel regressionModel = getModel();

        RegressionNormalizationMethodType regressionNormalizationMethod = regressionModel.getNormalizationMethod();
        switch (regressionNormalizationMethod) {
            case NONE:
                return;
            case SOFTMAX:
            case SIMPLEMAX:
                throw new UnsupportedFeatureException(regressionModel, regressionNormalizationMethod);
            default:
                break;
        }

        for (Map.Entry<String, Double> entry : values.entrySet()) {
            entry.setValue(normalizeClassificationResult(entry.getValue()));
        }

        calculateCategoryProbabilities(values, targetCategories);
    }

    /**
     * Maps a raw classification score through the inverse link function selected by the model's
     * normalization method: {@code logit}, {@code probit}, {@code cloglog}, {@code loglog} or
     * {@code cauchit}.
     *
     * @throws UnsupportedFeatureException for any other normalization method
     */
    private Double normalizeClassificationResult(Double value) {
        RegressionModel regressionModel = getModel();

        RegressionNormalizationMethodType regressionNormalizationMethod = regressionModel.getNormalizationMethod();
        switch (regressionNormalizationMethod) {
            case LOGIT:
                return 1.0 / (1.0 + Math.exp(-value.doubleValue()));
            case PROBIT:
                return NormalDistributionUtil.cumulativeProbability(value);
            case CLOGLOG:
                return 1.0 - Math.exp(-Math.exp(value));
            case LOGLOG:
                return Math.exp(-Math.exp(-value.doubleValue()));
            case CAUCHIT:
                // 1.0 / Math.PI is a compile-time constant (0.3183098861837907).
                return 0.5 + (1.0 / Math.PI) * Math.atan(value);
            default:
                throw new UnsupportedFeatureException(regressionModel, regressionNormalizationMethod);
        }
    }

    /**
     * Converts cumulative probabilities into per-category probabilities, in place.
     *
     * <p>For every category except the last, in {@code categories} order, the stored value is
     * treated as a cumulative probability and replaced by its difference from the previous
     * category's cumulative probability. The last category receives {@code 1 - offset}, where
     * {@code offset} is the cumulative probability of the second-to-last category. When
     * {@code categories} has fewer than two entries, {@code map} is left unchanged.
     *
     * @param map category to cumulative probability; updated in place
     * @param categories the categories in ascending ordinal order
     * @throws EvaluationException if a non-last category has no value, a value that is not a number
     *     or exceeds 1, or a value smaller than the preceding cumulative probability
     */
    public static void calculateCategoryProbabilities(Map<String, Double> map, List<String> categories) {
        double offset = 0.0;

        int lastIndex = categories.size() - 1;
        for (int i = 0; i < lastIndex; ++i) {
            String category = categories.get(i);

            Double cumulativeProbability = map.get(category);
            // The negated comparison also rejects NaN, which "> 1.0" would let through.
            if (cumulativeProbability == null || !(cumulativeProbability <= 1.0)) {
                throw new EvaluationException();
            }

            double probability = cumulativeProbability - offset;
            if (probability < 0.0) {
                throw new EvaluationException();
            }

            map.put(category, probability);

            offset = cumulativeProbability;
        }

        if (categories.size() > 1) {
            String category = categories.get(lastIndex);

            map.put(category, 1.0 - offset);
        }
    }
}

