/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator;

import com.google.common.base.Function;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.jpmml.evaluator.ClassificationAggregator;
import org.jpmml.evaluator.DoubleVector;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.HasProbability;

/**
 * Aggregates the category probabilities of several {@link HasProbability} contributors.
 *
 * <p>For every category declared by a contributor, the (optionally weighted) probability is recorded
 * under that category. The per-contributor weights are recorded separately so that
 * {@link #weightedAverageMap()} can normalize by their sum.
 *
 * <p>{@link #maxMap(Collection)} and {@link #medianMap(Collection)} need to trace values back to the
 * contributors that produced them. They therefore require the aggregator to have been created with a
 * positive capacity, which is the only case in which contributors are retained.
 */
public class ProbabilityAggregator
extends ClassificationAggregator<String> {
    /** Contributors in insertion order; {@code null} unless constructed with a positive capacity. */
    private List<HasProbability> hasProbabilities = null;
    /** One weight per {@link #add(HasProbability, double)} call, in insertion order. */
    private DoubleVector weights = null;

    /**
     * Creates an aggregator that does not retain contributors.
     * {@link #maxMap(Collection)} and {@link #medianMap(Collection)} are unavailable on it.
     */
    public ProbabilityAggregator() {
        this(0);
    }

    /**
     * Creates an aggregator.
     *
     * @param capacity passed to the superclass; when positive, contributors are also retained so that
     *        {@link #maxMap(Collection)} and {@link #medianMap(Collection)} can be used
     */
    public ProbabilityAggregator(int capacity) {
        super(capacity);
        if (capacity > 0) {
            this.hasProbabilities = new ArrayList<HasProbability>(capacity);
        }
        this.weights = new DoubleVector(0);
    }

    /**
     * Adds a contributor with a weight of {@code 1.0}.
     *
     * @param hasProbability the contributor
     */
    public void add(HasProbability hasProbability) {
        this.add(hasProbability, 1.0);
    }

    /**
     * Adds a contributor: each of its declared categories receives its probability scaled by
     * {@code weight}, and the weight itself is recorded for normalization.
     *
     * @param hasProbability the contributor
     * @param weight a non-negative weight
     * @throws IllegalArgumentException if {@code weight} is negative or NaN
     */
    @Override
    public void add(HasProbability hasProbability, double weight) {
        // "!(weight >= 0)" also rejects NaN, which a plain "weight < 0" test would silently accept
        if (!(weight >= 0.0)) {
            throw new IllegalArgumentException("Invalid weight: " + weight);
        }
        if (this.hasProbabilities != null) {
            this.hasProbabilities.add(hasProbability);
        }
        Set<String> categories = hasProbability.getCategoryValues();
        for (String category : categories) {
            Double probability = hasProbability.getProbability(category);
            // inherited per-key add: records the value under its category
            this.add(category, ProbabilityAggregator.weigh(probability, weight));
        }
        this.weights.add(weight);
    }

    /**
     * Equivalent to {@link #weightedAverageMap()}.
     *
     * @return the weighted average probability per category
     */
    public Map<String, Double> averageMap() {
        return this.weightedAverageMap();
    }

    /**
     * Computes, for every category, the sum of its recorded (weighted) values divided by the sum of
     * all contributor weights.
     *
     * <p>A contributor that does not declare a category still counts toward the denominator, so it
     * effectively contributes zero to that category. If all weights are zero the results are not finite.
     *
     * @return the weighted average probability per category
     */
    public Map<String, Double> weightedAverageMap() {
        // computed once, before transform() applies the function to each category
        final double denominator = this.weights.sum();
        Function<DoubleVector, Double> function = new Function<DoubleVector, Double>() {

            public Double apply(DoubleVector values) {
                return values.sum() / denominator;
            }
        };
        return this.transform(function);
    }

    /**
     * Selects, among {@code categories}, the one with the highest single recorded value. Returns the
     * average distribution of the contributors that produced exactly that value.
     *
     * @param categories candidate categories, in tie-breaking priority order
     * @return the averaged distribution, or an empty map if none of {@code categories} has values
     * @throws IllegalStateException if contributors are not retained (capacity was not positive)
     * @throws EvaluationException if {@code categories} is {@code null} or empty
     */
    public Map<String, Double> maxMap(Collection<String> categories) {
        if (this.hasProbabilities == null) {
            throw new IllegalStateException();
        }
        Function<DoubleVector, Double> function = new Function<DoubleVector, Double>() {

            public Double apply(DoubleVector values) {
                return values.max();
            }
        };
        Map<String, Double> maxValues = this.transform(function);
        Map.Entry<String, Double> winner = ProbabilityAggregator.getWinner(maxValues, categories);
        if (winner == null) {
            return Collections.emptyMap();
        }
        String category = winner.getKey();
        double maxProbability = winner.getValue();
        List<HasProbability> contributors = new ArrayList<HasProbability>();
        // Iterate over contributors rather than over get(category): that vector only holds entries for
        // contributors that declared the category, so its indices need not line up with hasProbabilities.
        for (int i = 0; i < this.hasProbabilities.size(); ++i) {
            Double probability = this.contributionOf(i, category);
            if (probability == null || probability.doubleValue() != maxProbability) {
                continue;
            }
            contributors.add(this.hasProbabilities.get(i));
        }
        return ProbabilityAggregator.averageMap(contributors);
    }

    /**
     * Selects, among {@code categories}, the one with the highest median recorded value. Returns the
     * average distribution of the contributors whose value for it is closest to that median (all ties kept).
     *
     * @param categories candidate categories, in tie-breaking priority order
     * @return the averaged distribution, or an empty map if none of {@code categories} has values
     * @throws IllegalStateException if contributors are not retained (capacity was not positive)
     * @throws EvaluationException if {@code categories} is {@code null} or empty
     */
    public Map<String, Double> medianMap(Collection<String> categories) {
        if (this.hasProbabilities == null) {
            throw new IllegalStateException();
        }
        Function<DoubleVector, Double> function = new Function<DoubleVector, Double>() {

            public Double apply(DoubleVector values) {
                return values.median();
            }
        };
        Map<String, Double> medianValues = this.transform(function);
        Map.Entry<String, Double> winner = ProbabilityAggregator.getWinner(medianValues, categories);
        if (winner == null) {
            return Collections.emptyMap();
        }
        String category = winner.getKey();
        double medianProbability = winner.getValue();
        List<HasProbability> contributors = new ArrayList<HasProbability>();
        double minDifference = Double.MAX_VALUE;
        // Iterate over contributors (not get(category)) so each value is attributed to its real producer.
        for (int i = 0; i < this.hasProbabilities.size(); ++i) {
            Double probability = this.contributionOf(i, category);
            if (probability == null) {
                continue;
            }
            double difference = Math.abs(medianProbability - probability);
            if (difference < minDifference) {
                // strictly closer than every candidate so far: start a new candidate set
                contributors.clear();
                minDifference = difference;
            }
            if (difference <= minDifference) {
                contributors.add(this.hasProbabilities.get(i));
            }
        }
        return ProbabilityAggregator.averageMap(contributors);
    }

    /**
     * Returns the value that the {@code index}-th contributor recorded for {@code category}. Returns
     * {@code null} if that contributor does not declare the category (and therefore recorded nothing).
     */
    private Double contributionOf(int index, String category) {
        HasProbability contributor = this.hasProbabilities.get(index);
        if (!contributor.getCategoryValues().contains(category)) {
            return null;
        }
        return ProbabilityAggregator.weigh(contributor.getProbability(category), this.weights.get(index));
    }

    /**
     * Scales a probability by a weight, exactly as {@link #add(HasProbability, double)} records it.
     * The multiplication is skipped for a weight of exactly {@code 1.0}.
     */
    private static double weigh(Double probability, double weight) {
        return weight != 1.0 ? probability * weight : probability;
    }

    /**
     * Finds the entry with the largest value among {@code categories}. Categories with no value are
     * skipped. On ties, the category encountered first in {@code categories} wins.
     *
     * @return the winning entry, or {@code null} if no category has a value
     * @throws EvaluationException if {@code categories} is {@code null} or empty
     */
    private static Map.Entry<String, Double> getWinner(Map<String, Double> values, Collection<String> categories) {
        if (categories == null || categories.isEmpty()) {
            throw new EvaluationException();
        }
        Map.Entry<String, Double> maxEntry = null;
        for (String category : categories) {
            Double value = values.get(category);
            if (value == null || maxEntry != null && maxEntry.getValue().compareTo(value) >= 0) {
                continue;
            }
            maxEntry = new AbstractMap.SimpleEntry<String, Double>(category, value);
        }
        return maxEntry;
    }

    /**
     * Averages the distributions of the given contributors with unit weights.
     *
     * <p>A single contributor's probabilities are copied directly, in {@code getCategoryValues()} order.
     */
    private static Map<String, Double> averageMap(List<HasProbability> hasProbabilities) {
        if (hasProbabilities.size() == 1) {
            HasProbability hasProbability = hasProbabilities.get(0);
            Map<String, Double> result = new LinkedHashMap<String, Double>();
            Set<String> categories = hasProbability.getCategoryValues();
            for (String category : categories) {
                result.put(category, hasProbability.getProbability(category));
            }
            return result;
        }
        ProbabilityAggregator aggregator = new ProbabilityAggregator();
        for (HasProbability hasProbability : hasProbabilities) {
            aggregator.add(hasProbability);
        }
        return aggregator.averageMap();
    }
}

