/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.xml.bind.JAXBException;
import org.dmg.pmml.BayesInput;
import org.dmg.pmml.BayesInputs;
import org.dmg.pmml.BayesOutput;
import org.dmg.pmml.ContinuousDistribution;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Extension;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.GaussianDistribution;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NaiveBayesModel;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.PairCounts;
import org.dmg.pmml.TargetValueCount;
import org.dmg.pmml.TargetValueCounts;
import org.dmg.pmml.TargetValueStat;
import org.dmg.pmml.TargetValueStats;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.ClassificationMap;
import org.jpmml.evaluator.DefaultClassificationMap;
import org.jpmml.evaluator.DiscretizationUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;
import org.jpmml.model.ExtensionUtil;

/**
 * Evaluator for PMML {@code NaiveBayesModel} elements.
 *
 * <p>Only the {@code CLASSIFICATION} mining function is supported; any other function name
 * results in an {@link UnsupportedFeatureException}. Per-target-value scores are accumulated
 * as sums of natural-log probabilities (one term per non-null input, plus the prior) and are
 * converted back into a normalized distribution at the end of classification.
 */
public class NaiveBayesModelEvaluator
extends ModelEvaluator<NaiveBayesModel> {
    /**
     * Cache of the effective {@link BayesInput} list of a model (see {@code parseBayesInputs}),
     * keyed weakly on the model instance.
     */
    private static final LoadingCache<NaiveBayesModel, List<BayesInput>> bayesInputCache = CacheBuilder.newBuilder().weakKeys().build(new CacheLoader<NaiveBayesModel, List<BayesInput>>(){

        @Override
        public List<BayesInput> load(NaiveBayesModel naiveBayesModel) {
            return NaiveBayesModelEvaluator.parseBayesInputs(naiveBayesModel);
        }
    });

    /**
     * Cache of per-field, per-target-value total counts (see {@code calculateCounts}),
     * keyed weakly on the model instance.
     */
    private static final LoadingCache<NaiveBayesModel, Map<FieldName, Map<String, Double>>> countCache = CacheBuilder.newBuilder().weakKeys().build(new CacheLoader<NaiveBayesModel, Map<FieldName, Map<String, Double>>>(){

        @Override
        public Map<FieldName, Map<String, Double>> load(NaiveBayesModel naiveBayesModel) {
            return NaiveBayesModelEvaluator.calculateCounts(naiveBayesModel);
        }
    });

    /**
     * Creates an evaluator for the first {@link NaiveBayesModel} found among the models of
     * the given PMML document.
     *
     * @param pmml the PMML document
     */
    public NaiveBayesModelEvaluator(PMML pmml) {
        this(pmml, (NaiveBayesModel)NaiveBayesModelEvaluator.find((List)pmml.getModels(), NaiveBayesModel.class));
    }

    /**
     * Creates an evaluator for the given model.
     *
     * @param pmml the enclosing PMML document
     * @param naiveBayesModel the model to evaluate
     */
    public NaiveBayesModelEvaluator(PMML pmml, NaiveBayesModel naiveBayesModel) {
        super(pmml, naiveBayesModel);
    }

    /**
     * @return a short human-readable description of this model type
     */
    public String getSummary() {
        return "Naive Bayes model";
    }

    /**
     * Evaluates the model in the given context.
     *
     * @param context the evaluation context supplying input field values
     * @return the output fields computed by {@code OutputUtil.evaluate} from the predictions
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException if the mining function is not {@code CLASSIFICATION}
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        NaiveBayesModel naiveBayesModel = (NaiveBayesModel)this.getModel();
        if (!naiveBayesModel.isScorable()) {
            throw new InvalidResultException((PMMLObject)naiveBayesModel);
        }
        Map<FieldName, ? extends ClassificationMap<?>> predictions;
        MiningFunctionType miningFunction = naiveBayesModel.getFunctionName();
        switch (miningFunction) {
            case CLASSIFICATION: {
                predictions = this.evaluateClassification(context);
                break;
            }
            default: {
                throw new UnsupportedFeatureException((PMMLObject)naiveBayesModel, (Enum)miningFunction);
            }
        }
        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Computes the class distribution for the target field of the {@code BayesOutput}.
     *
     * <p>Inputs that evaluate to {@code null} are skipped. Inputs with {@code TargetValueStats}
     * are scored with a Gaussian density; other inputs are (optionally discretized and then)
     * matched against their {@code PairCounts}. The prior counts are added last.
     *
     * @param context the evaluation context
     * @return the target field mapped to its classification result
     * @throws InvalidFeatureException if a derived field is not a {@code Discretize} expression
     * @throws EvaluationException if discretization yields no value
     */
    private Map<FieldName, ? extends ClassificationMap<?>> evaluateClassification(ModelEvaluationContext context) {
        NaiveBayesModel naiveBayesModel = (NaiveBayesModel)this.getModel();
        double threshold = naiveBayesModel.getThreshold();
        // Holds, per target value, the running sum of log-probabilities
        DefaultClassificationMap<String> result = new DefaultClassificationMap<String>();
        Map<FieldName, Map<String, Double>> countsMap = this.getCountsMap();
        List<BayesInput> bayesInputs = this.getValue(bayesInputCache);
        for (BayesInput bayesInput : bayesInputs) {
            FieldName name = bayesInput.getFieldName();
            FieldValue value = ExpressionUtil.evaluate(name, (EvaluationContext)context);
            if (value == null) {
                continue;
            }
            TargetValueStats targetValueStats = NaiveBayesModelEvaluator.getTargetValueStats(bayesInput);
            if (targetValueStats != null) {
                this.calculateContinuousProbabilities(value, targetValueStats, threshold, result);
                continue;
            }
            Map<String, Double> counts = countsMap.get(name);
            DerivedField derivedField = bayesInput.getDerivedField();
            if (derivedField != null) {
                Expression expression = derivedField.getExpression();
                if (!(expression instanceof Discretize)) {
                    throw new InvalidFeatureException((PMMLObject)derivedField);
                }
                Discretize discretize = (Discretize)expression;
                value = DiscretizationUtil.discretize(discretize, value);
                if (value == null) {
                    throw new EvaluationException();
                }
                value = FieldValueUtil.refine((Field)derivedField, value);
            }
            TargetValueCounts targetValueCounts = NaiveBayesModelEvaluator.getTargetValueCounts(bayesInput, value);
            // A value with no matching PairCounts contributes nothing
            if (targetValueCounts == null) {
                continue;
            }
            this.calculateDiscreteProbabilities(counts, targetValueCounts, threshold, result);
        }
        BayesOutput bayesOutput = naiveBayesModel.getBayesOutput();
        this.calculatePriorProbabilities(bayesOutput.getTargetValueCounts(), result);
        // Shift by the largest log-score before exponentiating so the biggest term becomes
        // exp(0) = 1; the shift cancels out in the normalization below
        Double max = Collections.max(result.values());
        for (Map.Entry<String, Double> entry : result.entrySet()) {
            entry.setValue(Math.exp(entry.getValue() - max));
        }
        result.normalizeValues();
        return TargetUtil.evaluateClassification(Collections.singletonMap(bayesOutput.getFieldName(), result), context);
    }

    /**
     * Adds, for every target value, the log of the Gaussian density of {@code value}
     * (floored at {@code threshold}) to {@code probabilities}.
     *
     * @throws InvalidFeatureException if a distribution is not a {@code GaussianDistribution}
     */
    private void calculateContinuousProbabilities(FieldValue value, TargetValueStats targetValueStats, double threshold, Map<String, Double> probabilities) {
        double x = value.asNumber().doubleValue();
        for (TargetValueStat targetValueStat : targetValueStats) {
            String targetValue = targetValueStat.getValue();
            ContinuousDistribution distribution = targetValueStat.getContinuousDistribution();
            if (!(distribution instanceof GaussianDistribution)) {
                throw new InvalidFeatureException((PMMLObject)targetValueStat);
            }
            GaussianDistribution gaussianDistribution = (GaussianDistribution)distribution;
            double mean = gaussianDistribution.getMean();
            double variance = gaussianDistribution.getVariance();
            double probability = Math.max(Math.exp(-Math.pow(x - mean, 2.0) / (2.0 * variance)) / Math.sqrt(Math.PI * 2 * variance), threshold);
            NaiveBayesModelEvaluator.updateSum(targetValue, Math.log(probability), probabilities);
        }
    }

    /**
     * Adds, for every target value, the log of the conditional probability
     * {@code count(value, target) / count(target)} (floored at {@code threshold})
     * to {@code probabilities}.
     *
     * @param counts per-target-value total counts for the input field
     */
    private void calculateDiscreteProbabilities(Map<String, Double> counts, TargetValueCounts targetValueCounts, double threshold, Map<String, Double> probabilities) {
        for (TargetValueCount targetValueCount : targetValueCounts) {
            String targetValue = targetValueCount.getValue();
            Double count = counts.get(targetValue);
            // A zero (or absent) total count would produce 0/0 = NaN, which Math.max does not
            // clamp and which would turn every class probability into NaN; treat it as a
            // zero probability so the threshold floor applies instead
            double probability = (count != null && count != 0.0) ? targetValueCount.getCount() / count : 0.0;
            probability = Math.max(probability, threshold);
            NaiveBayesModelEvaluator.updateSum(targetValue, Math.log(probability), probabilities);
        }
    }

    /**
     * Adds the log of each prior target value count to {@code probabilities}.
     */
    private void calculatePriorProbabilities(TargetValueCounts targetValueCounts, Map<String, Double> probabilities) {
        for (TargetValueCount targetValueCount : targetValueCounts) {
            String targetValue = targetValueCount.getValue();
            NaiveBayesModelEvaluator.updateSum(targetValue, Math.log(targetValueCount.getCount()), probabilities);
        }
    }

    /**
     * @return per-field, per-target-value total counts for this evaluator's model (cached)
     */
    protected Map<FieldName, Map<String, Double>> getCountsMap() {
        return this.getValue(countCache);
    }

    /**
     * For every input field, sums the counts of each target value over all of that field's
     * {@code PairCounts}.
     */
    private static Map<FieldName, Map<String, Double>> calculateCounts(NaiveBayesModel naiveBayesModel) {
        Map<FieldName, Map<String, Double>> result = Maps.newLinkedHashMap();
        List<BayesInput> bayesInputs = CacheUtil.getValue(naiveBayesModel, bayesInputCache);
        for (BayesInput bayesInput : bayesInputs) {
            FieldName name = bayesInput.getFieldName();
            Map<String, Double> counts = Maps.newLinkedHashMap();
            for (PairCounts pairCount : bayesInput.getPairCounts()) {
                TargetValueCounts targetValueCounts = pairCount.getTargetValueCounts();
                for (TargetValueCount targetValueCount : targetValueCounts) {
                    NaiveBayesModelEvaluator.updateSum(targetValueCount.getValue(), targetValueCount.getCount(), counts);
                }
            }
            result.put(name, counts);
        }
        return result;
    }

    /**
     * Collects the model's {@link BayesInput}s: those embedded in {@code Extension} elements
     * of {@code BayesInputs} first, followed by the regular {@code BayesInput} children.
     *
     * @throws InvalidFeatureException if an extension cannot be unmarshalled
     */
    private static List<BayesInput> parseBayesInputs(NaiveBayesModel naiveBayesModel) {
        List<BayesInput> result = Lists.newArrayList();
        BayesInputs bayesInputs = naiveBayesModel.getBayesInputs();
        for (Extension extension : bayesInputs.getExtensions()) {
            BayesInput bayesInput;
            try {
                bayesInput = (BayesInput)ExtensionUtil.getExtension(extension, BayesInput.class);
            }
            catch (JAXBException je) {
                throw new InvalidFeatureException((PMMLObject)extension);
            }
            // Extensions that do not hold a BayesInput are ignored
            if (bayesInput == null) {
                continue;
            }
            result.add(bayesInput);
        }
        result.addAll(bayesInputs.getBayesInputs());
        return result;
    }

    /**
     * Adds {@code value} to the entry for {@code key}, treating a missing entry as 0.
     */
    private static void updateSum(String key, Double value, Map<String, Double> counts) {
        Double count = counts.get(key);
        if (count == null) {
            count = 0.0;
        }
        counts.put(key, count + value);
    }

    private static TargetValueStats getTargetValueStats(BayesInput bayesInput) {
        return bayesInput.getTargetValueStats();
    }

    /**
     * @return the {@code TargetValueCounts} of the first {@code PairCounts} whose value matches
     *         {@code value} (compared via {@code FieldValue.equalsString}), or {@code null}
     */
    private static TargetValueCounts getTargetValueCounts(BayesInput bayesInput, FieldValue value) {
        for (PairCounts pairCount : bayesInput.getPairCounts()) {
            if (value.equalsString(pairCount.getValue())) {
                return pairCount.getTargetValueCounts();
            }
        }
        return null;
    }
}

