/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import java.io.Serializable;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.dmg.pmml.CategoricalPredictor;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NumericPredictor;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PredictorTerm;
import org.dmg.pmml.RegressionModel;
import org.dmg.pmml.RegressionNormalizationMethodType;
import org.dmg.pmml.RegressionTable;
import org.jpmml.evaluator.ClassificationMap;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.ModelManagerEvaluationContext;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.ParameterUtil;
import org.jpmml.manager.RegressionModelManager;
import org.jpmml.manager.UnsupportedFeatureException;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Evaluator for a PMML {@link RegressionModel}.
 *
 * <p>Supports the {@code regression} and {@code classification} mining functions; any other
 * mining function is rejected with {@link UnsupportedFeatureException}.
 */
public class RegressionModelEvaluator
extends RegressionModelManager
implements Evaluator {
    public RegressionModelEvaluator(PMML pmml) {
        super(pmml);
    }

    public RegressionModelEvaluator(PMML pmml, RegressionModel regressionModel) {
        super(pmml, regressionModel);
    }

    /**
     * Creates an evaluator for the same PMML document and model as {@code parent}.
     *
     * @param parent the manager whose PMML and model are reused
     */
    public RegressionModelEvaluator(RegressionModelManager parent) {
        this(parent.getPmml(), parent.getModel());
    }

    /**
     * Prepares a raw input value for the named field by delegating to
     * {@link ParameterUtil#prepare} with that field's data field and mining field.
     */
    @Override
    public Object prepare(FieldName name, Object value) {
        return ParameterUtil.prepare(this.getDataField(name), this.getMiningField(name), value);
    }

    /**
     * Evaluates the model against the given input parameters and then applies the model's
     * output fields via {@link OutputUtil#evaluate}.
     *
     * @param parameters input values keyed by field name
     * @return the target prediction(s) plus any output fields
     * @throws UnsupportedFeatureException if the mining function is neither regression nor classification
     */
    @Override
    public Map<FieldName, ?> evaluate(Map<FieldName, ?> parameters) {
        RegressionModel regressionModel = this.getModel();
        ModelManagerEvaluationContext context = new ModelManagerEvaluationContext(this, parameters);

        // Wildcard type: a Map<FieldName, Double> and a Map<FieldName, ClassificationMap> must both fit here
        Map<FieldName, ? extends Serializable> predictions;

        MiningFunctionType miningFunction = regressionModel.getFunctionName();
        switch (miningFunction) {
            case REGRESSION:
                predictions = this.evaluateRegression(context);
                break;
            case CLASSIFICATION:
                predictions = this.evaluateClassification(context);
                break;
            default:
                throw new UnsupportedFeatureException(miningFunction);
        }
        return OutputUtil.evaluate(this, parameters, predictions);
    }

    /**
     * Scores the model's single regression table and applies the model's normalization method.
     *
     * @param context evaluation context supplying input field values
     * @return a singleton map from the target field to the normalized score, or to {@code null}
     *         when a numeric predictor's input value is missing
     * @throws EvaluationException if the model does not contain exactly one regression table
     */
    public Map<FieldName, Double> evaluateRegression(EvaluationContext context) {
        RegressionModel regressionModel = this.getModel();

        List<RegressionTable> regressionTables = this.getRegressionTables();
        if (regressionTables.size() != 1) {
            throw new EvaluationException();
        }

        FieldName name = this.getTarget();

        Double value = RegressionModelEvaluator.evaluateRegressionTable(regressionTables.get(0), context);
        if (value == null) {
            // Missing numeric input: propagate a missing prediction rather than unboxing null
            // inside the normalization step
            return Collections.<FieldName, Double>singletonMap(name, null);
        }

        RegressionNormalizationMethodType regressionNormalizationMethod = regressionModel.getNormalizationMethod();
        return Collections.singletonMap(name, RegressionModelEvaluator.normalizeRegressionResult(regressionNormalizationMethod, value));
    }

    /**
     * Scores every regression table (one per target category) and normalizes the scores with
     * the model's normalization method.
     *
     * @param context evaluation context supplying input field values
     * @return a singleton map from the target field to a map of category to normalized score
     * @throws EvaluationException if there are no regression tables, the target has no data field,
     *         or a numeric predictor's input value is missing
     * @throws UnsupportedFeatureException if the target field is not categorical
     */
    public Map<FieldName, ClassificationMap> evaluateClassification(EvaluationContext context) {
        RegressionModel regressionModel = this.getModel();

        List<RegressionTable> regressionTables = this.getRegressionTables();
        if (regressionTables.isEmpty()) {
            throw new EvaluationException();
        }

        FieldName name = this.getTarget();

        // Validate the target before scoring anything; only categorical targets are supported
        DataField dataField = this.getDataField(name);
        if (dataField == null) {
            throw new EvaluationException();
        }
        OpType opType = dataField.getOptype();
        switch (opType) {
            case CATEGORICAL:
                break;
            default:
                throw new UnsupportedFeatureException(opType);
        }

        // All raw scores are needed up front because softmax normalizes across categories
        double[] scores = new double[regressionTables.size()];
        double maxScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < scores.length; i++) {
            Double score = RegressionModelEvaluator.evaluateRegressionTable(regressionTables.get(i), context);
            if (score == null) {
                throw new EvaluationException();
            }
            scores[i] = score;
            maxScore = Math.max(maxScore, scores[i]);
        }

        // Softmax is invariant under a common shift; subtracting the maximum keeps Math.exp
        // from overflowing to Infinity (which would otherwise yield NaN probabilities)
        double sumExp = 0.0;
        for (double score : scores) {
            sumExp += Math.exp(score - maxScore);
        }

        RegressionNormalizationMethodType regressionNormalizationMethod = regressionModel.getNormalizationMethod();

        ClassificationMap values = new ClassificationMap();
        for (int i = 0; i < scores.length; i++) {
            String targetCategory = regressionTables.get(i).getTargetCategory();
            values.put(targetCategory, RegressionModelEvaluator.normalizeClassificationResult(regressionNormalizationMethod, scores[i], maxScore, sumExp));
        }
        return Collections.singletonMap(name, values);
    }

    /**
     * Computes the linear score of one regression table:
     * intercept + sum(coefficient * x^exponent) over numeric predictors
     * + sum(coefficient) over categorical predictors whose value matches the input.
     *
     * @return the score, or {@code null} if any numeric predictor's input value is missing
     * @throws UnsupportedFeatureException if the table contains predictor (interaction) terms
     */
    private static Double evaluateRegressionTable(RegressionTable regressionTable, EvaluationContext context) {
        // Predictor terms are not implemented; reject them before evaluating any inputs
        List<PredictorTerm> predictorTerms = regressionTable.getPredictorTerms();
        if (!predictorTerms.isEmpty()) {
            throw new UnsupportedFeatureException(predictorTerms.get(0));
        }

        double result = 0.0;
        result += regressionTable.getIntercept();

        for (NumericPredictor numericPredictor : regressionTable.getNumericPredictors()) {
            Object value = ExpressionUtil.evaluate(numericPredictor.getName(), context);
            if (value == null) {
                return null;
            }
            double x = ((Number)value).doubleValue();
            result += numericPredictor.getCoefficient() * Math.pow(x, numericPredictor.getExponent().doubleValue());
        }

        for (CategoricalPredictor categoricalPredictor : regressionTable.getCategoricalPredictors()) {
            Object value = ExpressionUtil.evaluate(categoricalPredictor.getName(), context);
            // A missing categorical input is skipped (contributes nothing)
            if (value != null && ParameterUtil.equals(value, categoricalPredictor.getValue())) {
                result += categoricalPredictor.getCoefficient();
            }
        }
        return result;
    }

    /**
     * Applies a normalization method to a single regression score.
     * SOFTMAX is mapped to the logistic function here, the same as LOGIT.
     *
     * @throws UnsupportedFeatureException for methods other than NONE, SOFTMAX, LOGIT and EXP
     */
    private static Double normalizeRegressionResult(RegressionNormalizationMethodType regressionNormalizationMethod, double value) {
        switch (regressionNormalizationMethod) {
            case NONE:
                return value;
            case SOFTMAX:
            case LOGIT:
                return 1.0 / (1.0 + Math.exp(-value));
            case EXP:
                return Math.exp(value);
            default:
                throw new UnsupportedFeatureException(regressionNormalizationMethod);
        }
    }

    /**
     * Applies a normalization method to one category's score.
     *
     * @param value    the raw score for the category
     * @param maxValue the largest raw score across all categories (softmax shift)
     * @param sumExp   sum over all categories of exp(score - maxValue)
     * @throws UnsupportedFeatureException for methods other than NONE, SOFTMAX, LOGIT, CLOGLOG and LOGLOG
     */
    private static Double normalizeClassificationResult(RegressionNormalizationMethodType regressionNormalizationMethod, double value, double maxValue, double sumExp) {
        switch (regressionNormalizationMethod) {
            case NONE:
                return value;
            case SOFTMAX:
                return Math.exp(value - maxValue) / sumExp;
            case LOGIT:
                return 1.0 / (1.0 + Math.exp(-value));
            case CLOGLOG:
                return 1.0 - Math.exp(-Math.exp(value));
            case LOGLOG:
                return Math.exp(-Math.exp(-value));
            default:
                throw new UnsupportedFeatureException(regressionNormalizationMethod);
        }
    }
}

