/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import com.google.common.collect.BiMap;
import com.google.common.collect.Maps;
import java.io.Serializable;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.ActivationFunctionType;
import org.dmg.pmml.Connection;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NeuralInput;
import org.dmg.pmml.NeuralLayer;
import org.dmg.pmml.NeuralNetwork;
import org.dmg.pmml.NeuralOutput;
import org.dmg.pmml.Neuron;
import org.dmg.pmml.NnNormalizationMethodType;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.jpmml.evaluator.ArgumentUtil;
import org.jpmml.evaluator.ClassificationMap;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.MissingFieldException;
import org.jpmml.evaluator.ModelManagerEvaluationContext;
import org.jpmml.evaluator.NeuronClassificationMap;
import org.jpmml.evaluator.NormalizationUtil;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.manager.NeuralNetworkManager;
import org.jpmml.manager.UnsupportedFeatureException;

/**
 * Evaluator for PMML {@link NeuralNetwork} models.
 *
 * <p>Evaluation feeds the values of the neural inputs through every neural
 * layer (weighted sum plus bias, activation function, optional per-layer
 * normalization) and then maps the outputs of the neurons referenced by the
 * neural outputs back onto target fields, either as regression values or as
 * classification maps.
 */
public class NeuralNetworkEvaluator extends NeuralNetworkManager implements Evaluator {
    /** Lazily initialized cache of the entity registry; see {@link #getEntityRegistry()}. */
    private BiMap<String, Entity> entities = null;

    /** Per-value transformation for SIMPLEMAX normalization: the value itself. */
    private static final Normalizer SIMPLEMAX_NORMALIZER = new Normalizer() {

        @Override
        public double apply(double value) {
            return value;
        }
    };

    /**
     * Per-value transformation for SOFTMAX normalization: {@code exp(value)}.
     *
     * <p>NOTE(review): {@code Math.exp} overflows to infinity for large inputs, in
     * which case the normalized outputs become {@code NaN}/{@code 0}; consider a
     * max-shifted formulation if such activations can occur.
     */
    private static final Normalizer SOFTMAX_NORMALIZER = new Normalizer() {

        @Override
        public double apply(double value) {
            return Math.exp(value);
        }
    };

    /**
     * Creates an evaluator for the given PMML document.
     *
     * @param pmml the PMML document, passed on to the superclass constructor
     */
    public NeuralNetworkEvaluator(PMML pmml) {
        super(pmml);
    }

    /**
     * Creates an evaluator for a specific neural network of the given PMML document.
     *
     * @param pmml the PMML document, passed on to the superclass constructor
     * @param neuralNetwork the neural network model, passed on to the superclass constructor
     */
    public NeuralNetworkEvaluator(PMML pmml, NeuralNetwork neuralNetwork) {
        super(pmml, neuralNetwork);
    }

    /**
     * Returns the entity registry, obtaining it from the superclass on first use
     * and returning the cached instance afterwards. No synchronization is performed.
     *
     * @return the cached entity registry
     */
    @Override
    public BiMap<String, Entity> getEntityRegistry() {
        if (this.entities == null) {
            this.entities = super.getEntityRegistry();
        }
        return this.entities;
    }

    /**
     * Prepares a raw argument value for the named field by delegating to
     * {@link ArgumentUtil#prepare} with the field's data field and mining field.
     *
     * @param name the field name
     * @param value the raw value
     * @return the prepared field value
     */
    @Override
    public FieldValue prepare(FieldName name, Object value) {
        return ArgumentUtil.prepare(this.getDataField(name), this.getMiningField(name), value);
    }

    /**
     * Evaluates the model against the given arguments.
     *
     * @param arguments the argument values keyed by field name
     * @return the predictions after output processing by {@link OutputUtil#evaluate}
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException if the mining function is neither
     *         regression nor classification
     */
    @Override
    public Map<FieldName, ?> evaluate(Map<FieldName, ?> arguments) {
        NeuralNetwork neuralNetwork = this.getModel();
        if (!neuralNetwork.isScorable()) {
            throw new InvalidResultException(neuralNetwork);
        }

        ModelManagerEvaluationContext context = new ModelManagerEvaluationContext(this);
        context.pushFrame(arguments);

        // Wildcard value type: the two branches yield maps with different value types.
        Map<FieldName, ?> predictions;
        MiningFunctionType miningFunction = neuralNetwork.getFunctionName();
        switch (miningFunction) {
            case REGRESSION: {
                predictions = this.evaluateRegression(context);
                break;
            }
            case CLASSIFICATION: {
                predictions = this.evaluateClassification(context);
                break;
            }
            default: {
                throw new UnsupportedFeatureException(neuralNetwork, miningFunction);
            }
        }
        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Computes regression predictions. Each neural output must resolve to either a
     * {@link FieldRef} (the neuron output is used as-is) or a {@link NormContinuous}
     * (the neuron output is denormalized).
     *
     * @param context the evaluation context
     * @return the predictions after target processing by {@link TargetUtil#evaluateRegression}
     * @throws UnsupportedFeatureException if a neural output resolves to any other expression
     */
    private Map<FieldName, ? extends Number> evaluateRegression(ModelManagerEvaluationContext context) {
        LinkedHashMap<FieldName, Double> result = Maps.newLinkedHashMap();
        Map<String, Double> entityOutputs = this.evaluateRaw(context);

        List<NeuralOutput> neuralOutputs = this.getOrCreateNeuralOutputs();
        for (NeuralOutput neuralOutput : neuralOutputs) {
            String id = neuralOutput.getOutputNeuron();
            Expression expression = this.getExpression(neuralOutput.getDerivedField());

            if (expression instanceof FieldRef) {
                FieldRef fieldRef = (FieldRef) expression;
                result.put(fieldRef.getField(), entityOutputs.get(id));
            } else if (expression instanceof NormContinuous) {
                NormContinuous normContinuous = (NormContinuous) expression;
                Double value = NormalizationUtil.denormalize(normContinuous, entityOutputs.get(id));
                result.put(normContinuous.getField(), value);
            } else {
                throw new UnsupportedFeatureException(expression);
            }
        }
        return TargetUtil.evaluateRegression(result, context);
    }

    /**
     * Computes classification predictions. Each neural output must resolve to a
     * {@link NormDiscrete}; outputs that share a field are collected into the same
     * {@link NeuronClassificationMap}.
     *
     * @param context the evaluation context
     * @return the predictions after target processing by {@link TargetUtil#evaluateClassification}
     * @throws UnsupportedFeatureException if a neural output resolves to any other expression
     */
    private Map<FieldName, ? extends ClassificationMap> evaluateClassification(ModelManagerEvaluationContext context) {
        LinkedHashMap<FieldName, NeuronClassificationMap> result = Maps.newLinkedHashMap();
        BiMap<String, Entity> entityRegistry = this.getEntityRegistry();
        Map<String, Double> entityOutputs = this.evaluateRaw(context);

        List<NeuralOutput> neuralOutputs = this.getOrCreateNeuralOutputs();
        for (NeuralOutput neuralOutput : neuralOutputs) {
            String id = neuralOutput.getOutputNeuron();
            Expression expression = this.getExpression(neuralOutput.getDerivedField());
            if (!(expression instanceof NormDiscrete)) {
                throw new UnsupportedFeatureException(expression);
            }

            NormDiscrete normDiscrete = (NormDiscrete) expression;
            FieldName field = normDiscrete.getField();

            // One classification map per field, created on first encounter.
            NeuronClassificationMap values = result.get(field);
            if (values == null) {
                values = new NeuronClassificationMap();
                result.put(field, values);
            }

            Entity entity = entityRegistry.get(id);
            Double value = entityOutputs.get(id);
            values.put(entity, normDiscrete.getValue(), value);
        }
        return TargetUtil.evaluateClassification(result, context);
    }

    /**
     * Returns the expression of the given derived field, following chains of
     * {@link FieldRef}s for as long as the referenced field resolves to another
     * derived field.
     *
     * @param derivedField the derived field to inspect
     * @return the first non-{@code FieldRef} expression found, or the last
     *         {@code FieldRef} whose target does not resolve to a derived field
     */
    private Expression getExpression(DerivedField derivedField) {
        Expression expression = derivedField.getExpression();
        if (expression instanceof FieldRef) {
            FieldRef fieldRef = (FieldRef) expression;
            DerivedField referenced = this.resolveField(fieldRef.getField());
            if (referenced != null) {
                return this.getExpression(referenced);
            }
            return fieldRef;
        }
        return expression;
    }

    /**
     * Runs the forward pass of the network.
     *
     * <p>The returned map is keyed by neural input id and neuron id, in the order
     * in which the values were computed. A connection whose source id has no
     * computed output at that point fails with a {@link NullPointerException}
     * on unboxing.
     *
     * @param context the evaluation context used to evaluate the neural inputs
     * @return the outputs of all neural inputs and neurons, keyed by id
     * @throws MissingFieldException if a neural input evaluates to {@code null}
     */
    public Map<String, Double> evaluateRaw(EvaluationContext context) {
        LinkedHashMap<String, Double> result = Maps.newLinkedHashMap();

        List<NeuralInput> neuralInputs = this.getNeuralInputs();
        for (NeuralInput neuralInput : neuralInputs) {
            DerivedField derivedField = neuralInput.getDerivedField();
            FieldValue value = ExpressionUtil.evaluate(derivedField, context);
            if (value == null) {
                throw new MissingFieldException(derivedField.getName(), (PMMLObject) derivedField);
            }
            result.put(neuralInput.getId(), value.asNumber().doubleValue());
        }

        List<NeuralLayer> neuralLayers = this.getNeuralLayers();
        for (NeuralLayer neuralLayer : neuralLayers) {
            List<Neuron> neurons = neuralLayer.getNeurons();
            for (Neuron neuron : neurons) {
                // Weighted sum of the incoming connections, starting from the bias.
                double z = neuron.getBias();
                List<Connection> connections = neuron.getConnections();
                for (Connection connection : connections) {
                    double input = result.get(connection.getFrom());
                    z += input * connection.getWeight();
                }
                result.put(neuron.getId(), this.activation(z, neuralLayer));
            }
            // Normalization is applied once all neurons of the layer have been computed.
            this.normalizeNeuronOutputs(neuralLayer, result);
        }
        return result;
    }

    /**
     * Normalizes, in place, the outputs of the neurons of the given layer using the
     * layer's normalization method, falling back to the network's method when the
     * layer does not declare one.
     *
     * @param neuralLayer the layer whose neuron outputs are normalized
     * @param neuronOutputs the outputs keyed by id; updated in place
     * @throws UnsupportedFeatureException if the normalization method is not
     *         NONE, SIMPLEMAX or SOFTMAX
     */
    private void normalizeNeuronOutputs(NeuralLayer neuralLayer, Map<String, Double> neuronOutputs) {
        NeuralNetwork neuralNetwork = this.getModel();

        // Tracks the element that declared the method, for error reporting.
        PMMLObject locatable = neuralLayer;
        NnNormalizationMethodType normalizationMethod = neuralLayer.getNormalizationMethod();
        if (normalizationMethod == null) {
            locatable = neuralNetwork;
            normalizationMethod = neuralNetwork.getNormalizationMethod();
        }

        switch (normalizationMethod) {
            case NONE: {
                break;
            }
            case SIMPLEMAX: {
                this.normalizeNeuronOutputs(neuralLayer, SIMPLEMAX_NORMALIZER, neuronOutputs);
                break;
            }
            case SOFTMAX: {
                this.normalizeNeuronOutputs(neuralLayer, SOFTMAX_NORMALIZER, neuronOutputs);
                break;
            }
            default: {
                throw new UnsupportedFeatureException(locatable, normalizationMethod);
            }
        }
    }

    /**
     * Replaces the output of every neuron of the layer with
     * {@code normalizer.apply(output) / sum}, where {@code sum} is the sum of
     * {@code normalizer.apply(output)} over all neurons of the layer.
     *
     * @param neuralLayer the layer whose neuron outputs are normalized
     * @param normalizer the per-value transformation
     * @param neuronOutputs the outputs keyed by id; updated in place
     */
    private void normalizeNeuronOutputs(NeuralLayer neuralLayer, Normalizer normalizer, Map<String, Double> neuronOutputs) {
        List<Neuron> neurons = neuralLayer.getNeurons();

        double sum = 0.0;
        for (Neuron neuron : neurons) {
            Double output = neuronOutputs.get(neuron.getId());
            sum += normalizer.apply(output);
        }

        for (Neuron neuron : neurons) {
            Double output = neuronOutputs.get(neuron.getId());
            double normalizedOutput = normalizer.apply(output) / sum;
            neuronOutputs.put(neuron.getId(), normalizedOutput);
        }
    }

    /**
     * Applies the activation function of the given layer to {@code z}, falling
     * back to the network's activation function when the layer does not declare one.
     *
     * @param z the weighted input sum of a neuron
     * @param neuralLayer the layer that the neuron belongs to
     * @return the activated value
     * @throws UnsupportedFeatureException if the activation function is not handled
     */
    private double activation(double z, NeuralLayer neuralLayer) {
        NeuralNetwork neuralNetwork = this.getModel();

        // Tracks the element that declared the function, for error reporting.
        PMMLObject locatable = neuralLayer;
        ActivationFunctionType activationFunction = neuralLayer.getActivationFunction();
        if (activationFunction == null) {
            // The function comes from the network, so errors must point at the network.
            locatable = neuralNetwork;
            activationFunction = neuralNetwork.getActivationFunction();
        }

        switch (activationFunction) {
            case THRESHOLD: {
                Double threshold = neuralLayer.getThreshold();
                if (threshold == null) {
                    threshold = neuralNetwork.getThreshold();
                }
                return z > threshold ? 1.0 : 0.0;
            }
            case LOGISTIC: {
                return 1.0 / (1.0 + Math.exp(-z));
            }
            case TANH: {
                // Math.tanh saturates at +/-1; the explicit exp-based quotient
                // evaluates to NaN (-Infinity / Infinity) for strongly negative z.
                return Math.tanh(z);
            }
            case IDENTITY: {
                return z;
            }
            case EXPONENTIAL: {
                return Math.exp(z);
            }
            case RECIPROCAL: {
                return 1.0 / z;
            }
            case SQUARE: {
                return z * z;
            }
            case GAUSS: {
                return Math.exp(-(z * z));
            }
            case SINE: {
                return Math.sin(z);
            }
            case COSINE: {
                return Math.cos(z);
            }
            case ELLIOTT: {
                return z / (1.0 + Math.abs(z));
            }
            case ARCTAN: {
                return Math.atan(z);
            }
            default: {
                throw new UnsupportedFeatureException(locatable, activationFunction);
            }
        }
    }

    /** Per-value transformation applied before dividing by the layer-wide sum. */
    private static interface Normalizer {
        /**
         * Transforms a single neuron output.
         *
         * @param value the neuron output
         * @return the transformed value
         */
        public double apply(double value);
    }
}

