/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.dmg.pmml.CategoricalPredictor;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NumericPredictor;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PredictorTerm;
import org.dmg.pmml.RegressionModel;
import org.dmg.pmml.RegressionNormalizationMethodType;
import org.dmg.pmml.RegressionTable;
import org.jpmml.evaluator.ClassificationMap;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.MissingResultException;
import org.jpmml.evaluator.ModelManagerEvaluationContext;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.ParameterUtil;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.RegressionModelManager;
import org.jpmml.manager.UnsupportedFeatureException;

/**
 * {@link Evaluator} for PMML {@code RegressionModel} elements.
 *
 * <p>Supports the {@code regression} and {@code classification} mining functions; any other
 * function name is rejected with an {@link UnsupportedFeatureException}.
 */
public class RegressionModelEvaluator
extends RegressionModelManager
implements Evaluator {
    public RegressionModelEvaluator(PMML pmml) {
        super(pmml);
    }

    public RegressionModelEvaluator(PMML pmml, RegressionModel regressionModel) {
        super(pmml, regressionModel);
    }

    /**
     * Prepares a raw input value for this model by delegating to {@link ParameterUtil#prepare}
     * with the field's {@code DataField} and {@code MiningField} definitions.
     *
     * @param name the input field name
     * @param value the raw user-supplied value
     * @return the prepared value
     */
    @Override
    public Object prepare(FieldName name, Object value) {
        return ParameterUtil.prepare(this.getDataField(name), this.getMiningField(name), value);
    }

    /**
     * Scores one record.
     *
     * @param arguments input field values, keyed by field name
     * @return the result of {@link OutputUtil#evaluate} applied to the target predictions
     * @throws InvalidResultException if the model is marked as not scorable
     * @throws UnsupportedFeatureException if the mining function is neither regression nor
     *     classification
     */
    @Override
    public Map<FieldName, ?> evaluate(Map<FieldName, ?> arguments) {
        RegressionModel regressionModel = this.getModel();
        if (!regressionModel.isScorable()) {
            throw new InvalidResultException(regressionModel);
        }
        ModelManagerEvaluationContext context = new ModelManagerEvaluationContext(this, arguments);
        // Wildcard value type: the two branches return maps with different value types
        // (numbers vs. ClassificationMap), so no single concrete value type fits both.
        Map<FieldName, ?> predictions;
        MiningFunctionType miningFunction = regressionModel.getFunctionName();
        switch (miningFunction) {
            case REGRESSION:
                predictions = this.evaluateRegression(context);
                break;
            case CLASSIFICATION:
                predictions = this.evaluateClassification(context);
                break;
            default:
                throw new UnsupportedFeatureException(regressionModel, miningFunction);
        }
        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Evaluates the model as a regression: the single regression table is evaluated and the
     * result (when not missing) is passed through the model's normalization method.
     *
     * @throws InvalidFeatureException if the model does not have exactly one regression table
     */
    private Map<FieldName, ? extends Number> evaluateRegression(ModelManagerEvaluationContext context) {
        RegressionModel regressionModel = this.getModel();
        List<RegressionTable> regressionTables = this.getRegressionTables();
        if (regressionTables.size() != 1) {
            throw new InvalidFeatureException(regressionModel);
        }
        Double value = RegressionModelEvaluator.evaluateRegressionTable(regressionTables.get(0), context);
        // A null value signals a missing input; it is handed to TargetUtil unnormalized.
        if (value != null) {
            value = RegressionModelEvaluator.normalizeRegressionResult(regressionModel, value);
        }
        return TargetUtil.evaluateRegression(value, context);
    }

    /**
     * Evaluates the model as a classification: each regression table yields a raw score for its
     * target category, and the scores are then converted to probabilities using the model's
     * normalization method.
     *
     * @throws InvalidFeatureException if there are no regression tables, or a table lacks a
     *     target category
     * @throws MissingResultException if a table cannot be evaluated because of a missing input
     * @throws UnsupportedFeatureException if the target field is not categorical
     */
    private Map<FieldName, ? extends ClassificationMap> evaluateClassification(ModelManagerEvaluationContext context) {
        RegressionModel regressionModel = this.getModel();
        List<RegressionTable> regressionTables = this.getRegressionTables();
        if (regressionTables.isEmpty()) {
            throw new InvalidFeatureException(regressionModel);
        }
        ClassificationMap result = new ClassificationMap(ClassificationMap.Type.PROBABILITY);
        // Largest raw score seen; used to shift exponents so the softmax sum cannot overflow.
        double maxValue = Double.NEGATIVE_INFINITY;
        for (RegressionTable regressionTable : regressionTables) {
            String category = regressionTable.getTargetCategory();
            if (category == null) {
                throw new InvalidFeatureException(regressionTable);
            }
            Double value = RegressionModelEvaluator.evaluateRegressionTable(regressionTable, context);
            if (value == null) {
                throw new MissingResultException(regressionTable);
            }
            result.put(category, value);
            maxValue = Math.max(maxValue, value);
        }
        FieldName targetField = this.getTargetField();
        DataField dataField = this.getDataField(targetField);
        OpType opType = dataField.getOptype();
        if (opType != OpType.CATEGORICAL) {
            throw new UnsupportedFeatureException(dataField, opType);
        }
        // Softmax is invariant to subtracting a constant from every score; subtracting the
        // maximum keeps every exponent <= 0, so Math.exp never overflows to Infinity.
        double sumExp = 0.0;
        for (Double value : result.values()) {
            sumExp += Math.exp(value - maxValue);
        }
        for (Map.Entry<String, Double> entry : result.entrySet()) {
            entry.setValue(RegressionModelEvaluator.normalizeClassificationResult(regressionModel, entry.getValue(), maxValue, sumExp));
        }
        return TargetUtil.evaluateClassification(result, context);
    }

    /**
     * Computes the raw (unnormalized) value of one regression table: the intercept plus the
     * contributions of its numeric predictors, categorical predictors and predictor terms.
     *
     * <p>Missing-input handling differs by predictor kind:
     * <ul>
     *   <li>numeric predictor: a warning is recorded and {@code null} is returned;</li>
     *   <li>categorical predictor: a warning is recorded and the predictor is skipped;</li>
     *   <li>predictor term: a warning is recorded and {@code null} is returned.</li>
     * </ul>
     *
     * @return the table value, or {@code null} when a required input is missing
     * @throws InvalidFeatureException if a predictor term has no field references
     */
    private static Double evaluateRegressionTable(RegressionTable regressionTable, EvaluationContext context) {
        double result = regressionTable.getIntercept();

        // coefficient * x^exponent
        List<NumericPredictor> numericPredictors = regressionTable.getNumericPredictors();
        for (NumericPredictor numericPredictor : numericPredictors) {
            FieldName name = numericPredictor.getName();
            Object value = ExpressionUtil.evaluate(name, context);
            if (value == null) {
                context.addWarning("Missing argument \"" + name.getValue() + "\"");
                return null;
            }
            result += numericPredictor.getCoefficient() * Math.pow(((Number)value).doubleValue(), numericPredictor.getExponent());
        }

        // coefficient * (1 if the input equals the predictor's value, else 0)
        List<CategoricalPredictor> categoricalPredictors = regressionTable.getCategoricalPredictors();
        for (CategoricalPredictor categoricalPredictor : categoricalPredictors) {
            FieldName name = categoricalPredictor.getName();
            Object value = ExpressionUtil.evaluate(name, context);
            if (value == null) {
                context.addWarning("Missing argument \"" + name.getValue() + "\"");
                continue;
            }
            if (ParameterUtil.equals(value, categoricalPredictor.getValue())) {
                result += categoricalPredictor.getCoefficient();
            }
        }

        // coefficient * product of all referenced field values
        List<PredictorTerm> predictorTerms = regressionTable.getPredictorTerms();
        for (PredictorTerm predictorTerm : predictorTerms) {
            List<FieldRef> fieldRefs = predictorTerm.getFieldRefs();
            if (fieldRefs.isEmpty()) {
                throw new InvalidFeatureException(predictorTerm);
            }
            double product = predictorTerm.getCoefficient();
            for (FieldRef fieldRef : fieldRefs) {
                Object value = ExpressionUtil.evaluateFieldRef(fieldRef, context);
                if (value == null) {
                    context.addWarning("Missing argument \"" + fieldRef.getField().getValue() + "\"");
                    return null;
                }
                product *= ((Number)value).doubleValue();
            }
            result += product;
        }
        return result;
    }

    /**
     * Applies the model's normalization method to a regression result. SOFTMAX and LOGIT both
     * map to the logistic function here, since there is only a single output value.
     *
     * @throws UnsupportedFeatureException for any other normalization method
     */
    private static Double normalizeRegressionResult(RegressionModel regressionModel, Double value) {
        RegressionNormalizationMethodType regressionNormalizationMethod = regressionModel.getNormalizationMethod();
        switch (regressionNormalizationMethod) {
            case NONE:
                return value;
            case SOFTMAX:
            case LOGIT:
                return 1.0 / (1.0 + Math.exp(-value.doubleValue()));
            case EXP:
                return Math.exp(value);
            default:
                throw new UnsupportedFeatureException(regressionModel, regressionNormalizationMethod);
        }
    }

    /**
     * Converts one category's raw score into a probability using the model's normalization
     * method.
     *
     * @param value the raw score of this category
     * @param maxValue the largest raw score over all categories (softmax stabilization shift)
     * @param sumExp sum over all categories of {@code exp(score - maxValue)}
     * @throws UnsupportedFeatureException for any unsupported normalization method
     */
    private static Double normalizeClassificationResult(RegressionModel regressionModel, Double value, double maxValue, double sumExp) {
        RegressionNormalizationMethodType regressionNormalizationMethod = regressionModel.getNormalizationMethod();
        switch (regressionNormalizationMethod) {
            case NONE:
                return value;
            case SOFTMAX:
                // Same shift as used to build sumExp, so the ratio equals exp(v) / sum(exp(v_i)).
                return Math.exp(value - maxValue) / sumExp;
            case LOGIT:
                return 1.0 / (1.0 + Math.exp(-value.doubleValue()));
            case CLOGLOG:
                return 1.0 - Math.exp(-Math.exp(value));
            case LOGLOG:
                return Math.exp(-Math.exp(-value.doubleValue()));
            default:
                throw new UnsupportedFeatureException(regressionModel, regressionNormalizationMethod);
        }
    }
}

