/*
 * Decompiled with CFR 0.152.
 */
package org.nlpub.watset.eval;

import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.nlpub.watset.eval.PrecisionRecall;

/**
 * Purity-style evaluation of a clustering against a set of gold-standard classes.
 *
 * <p>Both clusters and classes are represented as maps from an element to its weight. The two flags
 * select the variant that is computed:
 * <ul>
 *   <li>{@code normalized}: element weights are summed instead of elements being counted;</li>
 *   <li>{@code modified}: clusters with at most one element contribute nothing to the numerator.</li>
 * </ul>
 *
 * <p>{@link #evaluate} combines two instances into a precision/recall pair. The precision is
 * purity of the clusters against the classes, and the recall is purity of the classes against the
 * clusters. The local names {@code nmPU}/{@code niPU} there suggest normalized modified purity and
 * normalized inverse purity.
 *
 * @param <V> the element type
 */
public class NormalizedModifiedPurity<V> {
    /** Whether element weights are summed instead of elements being counted. */
    final boolean normalized;
    /** Whether clusters with at most one element are scored as zero. */
    final boolean modified;

    /**
     * Converts each cluster given as a collection of elements into a map from element to its number
     * of occurrences in that cluster (as a {@code double}).
     *
     * @param clusters the clusters, possibly containing repeated elements
     * @param <V>      the element type
     * @return one element-to-count map per input cluster, in the same order
     */
    public static <V> List<Map<V, Double>> transform(List<Collection<V>> clusters) {
        return clusters.stream()
                .map(cluster -> cluster.stream().collect(
                        Collectors.groupingBy(Function.identity(), Collectors.reducing(0.0, e -> 1.0, Double::sum))))
                .collect(Collectors.toList());
    }

    /**
     * Rescales element weights so that every element's weights sum to one across all the clusters.
     * Each weight is divided by that element's total weight over every cluster.
     *
     * <p>NOTE(review): an element whose total weight is zero yields {@code NaN} (0/0) here. Confirm
     * callers never pass all-zero weights for an element.
     *
     * @param clusters the weighted clusters
     * @param <V>      the element type
     * @return one normalized map per input cluster, in iteration order of {@code clusters}
     * @throws IllegalArgumentException if a cluster or the collection changes size (defensive checks)
     */
    public static <V> List<Map<V, Double>> normalize(Collection<Map<V, Double>> clusters) {
        // Total weight of every element across all the clusters.
        final Map<V, Double> counter = new HashMap<>();

        for (final Map<V, Double> cluster : clusters) {
            for (final Map.Entry<V, Double> entry : cluster.entrySet()) {
                counter.merge(entry.getKey(), entry.getValue(), Double::sum);
            }
        }

        final List<Map<V, Double>> normalized = clusters.stream().map(cluster -> {
            final Map<V, Double> normalizedCluster = cluster.entrySet().stream().collect(
                    Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue() / counter.get(entry.getKey())));

            if (cluster.size() != normalizedCluster.size()) {
                throw new IllegalArgumentException("Cluster size changed");
            }

            return normalizedCluster;
        }).collect(Collectors.toList());

        if (clusters.size() != normalized.size()) {
            throw new IllegalArgumentException("Collection size changed");
        }

        return normalized;
    }

    /**
     * Computes a precision/recall pair from two purity configurations.
     *
     * @param precision the configuration used to score {@code clusters} against {@code classes}
     * @param recall    the configuration used to score {@code classes} against {@code clusters}
     * @param clusters  the clusters to evaluate
     * @param classes   the gold-standard classes
     * @param <V>       the element type
     * @return the precision and recall values
     * @throws NullPointerException if {@code clusters} or {@code classes} is {@code null}
     */
    public static <V> PrecisionRecall evaluate(NormalizedModifiedPurity<V> precision, NormalizedModifiedPurity<V> recall, Collection<Map<V, Double>> clusters, Collection<Map<V, Double>> classes) {
        final double nmPU = precision.purity(Objects.requireNonNull(clusters), Objects.requireNonNull(classes));
        final double niPU = recall.purity(classes, clusters);
        return new PrecisionRecall(nmPU, niPU);
    }

    /** Creates the normalized and modified variant. */
    public NormalizedModifiedPurity() {
        this(true, true);
    }

    /**
     * Creates a purity measure with the given variant flags.
     *
     * @param normalized whether element weights are summed instead of elements being counted
     * @param modified   whether clusters with at most one element are scored as zero
     */
    public NormalizedModifiedPurity(boolean normalized, boolean modified) {
        this.normalized = normalized;
        this.modified = modified;
    }

    /**
     * Computes purity of {@code clusters} with respect to {@code classes}.
     *
     * <p>The numerator sums, over every cluster, its best {@link #delta} against any class. The
     * denominator is the total weight of all clusters in normalized mode, or the total number of
     * cluster entries otherwise.
     *
     * @param clusters the clusters being scored
     * @param classes  the reference classes
     * @return the purity value, or {@code 0.0} when the denominator is zero
     */
    public double purity(Collection<Map<V, Double>> clusters, Collection<Map<V, Double>> classes) {
        // A long count avoids int overflow when summing the sizes of many clusters.
        final double denominator = normalized
                ? clusters.parallelStream().mapToDouble(cluster -> sumOfWeights(cluster.values())).sum()
                : clusters.stream().mapToLong(Map::size).sum();

        if (denominator == 0.0) {
            return 0.0;
        }

        final double numerator = clusters.parallelStream().mapToDouble(cluster -> score(cluster, classes)).sum();

        return numerator / denominator;
    }

    /**
     * Scores a single cluster as its best overlap with any of the classes.
     *
     * @param cluster the cluster
     * @param classes the reference classes
     * @return the maximum {@link #delta} over all classes, or {@code 0.0} if there are none
     */
    public double score(Map<V, Double> cluster, Collection<Map<V, Double>> classes) {
        return classes.stream().mapToDouble(klass -> delta(cluster, klass)).max().orElse(0.0);
    }

    /**
     * Measures the overlap between a cluster and a class. The overlap is the number of cluster
     * elements present in the class or, in normalized mode, the sum of the cluster's weights for
     * those elements.
     *
     * @param cluster the cluster whose weights are used
     * @param klass   the class; only its key set is consulted
     * @return the overlap, or {@code 0.0} when the modified variant skips a cluster of size at most one
     */
    public double delta(Map<V, Double> cluster, Map<V, Double> klass) {
        if (modified && cluster.size() <= 1) {
            return 0.0;
        }

        // Filter the cluster's own entries instead of copying the whole map for every class.
        if (!normalized) {
            return cluster.keySet().stream().filter(klass::containsKey).count();
        }

        return cluster.entrySet().stream()
                .filter(entry -> klass.containsKey(entry.getKey()))
                .mapToDouble(Map.Entry::getValue)
                .sum();
    }

    /** Sums the given weights. */
    private static double sumOfWeights(Collection<Double> weights) {
        return weights.stream().mapToDouble(Double::doubleValue).sum();
    }
}

