/*
 * Decompiled with CFR 0.152.
 */
package org.nlpub.watset.graph;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import org.jgrapht.Graph;
import org.jgrapht.GraphTests;
import org.jgrapht.Graphs;
import org.jgrapht.alg.interfaces.ClusteringAlgorithm;
import org.nlpub.watset.graph.ClusteringAlgorithmBuilder;
import org.nlpub.watset.graph.NodeWeighting;
import org.nlpub.watset.graph.NodeWeightings;
import org.nlpub.watset.util.Maximizer;

/**
 * Chinese Whispers clustering of an undirected graph.
 *
 * <p>Every vertex starts with its own unique integer label. In each round the vertices are visited
 * in a random order (shuffled with the supplied {@link Random}), and each vertex adopts the label
 * with the highest score among its neighbours. A label's score is the sum of the
 * {@link NodeWeighting} values of the neighbours carrying that label. The winning label is chosen
 * with {@link Maximizer#argrandmax}, which also receives the {@link Random} (presumably to break
 * ties). Rounds stop after {@code iterations} rounds, or as soon as a round changes no label.
 * Vertices that end up sharing a label form one cluster.
 *
 * <p>The clustering is computed lazily on the first {@link #getClustering()} call and then cached.
 * NOTE(review): that lazy initialization is not synchronized.
 *
 * @param <V> the vertex type
 * @param <E> the edge type
 */
public class ChineseWhispers<V, E>
implements ClusteringAlgorithm<V> {
    /** The graph to cluster; checked to be undirected at construction. */
    protected final Graph<V, E> graph;
    /** Computes the contribution of a neighbour to the score of its label. */
    protected final NodeWeighting<V, E> weighting;
    /** Upper bound on the number of label-propagation rounds. */
    protected final int iterations;
    /** Source of randomness for node shuffling and label selection. */
    protected final Random random;
    /** Cached result of {@link #getClustering()}; {@code null} until first computed. */
    protected ClusteringAlgorithm.Clustering<V> clustering;

    /**
     * Creates a builder pre-populated with the default weighting, iteration count and a fresh
     * {@link Random}.
     *
     * @param <V> the vertex type
     * @param <E> the edge type
     * @return a new builder
     */
    public static <V, E> Builder<V, E> builder() {
        return new Builder<>();
    }

    /**
     * Creates an instance of the algorithm.
     *
     * @param graph      the undirected graph to cluster
     * @param weighting  the node weighting used to score neighbour labels
     * @param iterations the maximum number of rounds
     * @param random     the random number generator
     * @throws IllegalArgumentException if {@code graph} is not undirected
     *                                  (raised by {@link GraphTests#requireUndirected})
     * @throws NullPointerException     if {@code weighting} or {@code random} is {@code null}
     */
    public ChineseWhispers(Graph<V, E> graph, NodeWeighting<V, E> weighting, int iterations, Random random) {
        this.graph = GraphTests.requireUndirected(graph);
        this.weighting = Objects.requireNonNull(weighting);
        this.iterations = iterations;
        this.random = Objects.requireNonNull(random);
    }

    /**
     * Returns the clustering, computing it on the first call and reusing it afterwards.
     *
     * @return the clustering of the graph
     */
    @Override
    public ClusteringAlgorithm.Clustering<V> getClustering() {
        if (Objects.isNull(this.clustering)) {
            this.clustering = new Implementation<V, E>(this.graph, this.weighting, this.iterations, this.random).compute();
        }
        return this.clustering;
    }

    /**
     * A single run of the algorithm, holding the mutable per-run state (labels and step count).
     *
     * @param <V> the vertex type
     * @param <E> the edge type
     */
    protected static class Implementation<V, E> {
        protected final Graph<V, E> graph;
        protected final NodeWeighting<V, E> weighting;
        protected final int iterations;
        protected final Random random;
        /** Current label of every vertex. */
        protected final Map<V, Integer> labels;
        /** Number of completed rounds in which at least one label changed. */
        protected int steps;

        /**
         * Prepares a run. No computation happens until {@link #compute()} is called.
         *
         * @param graph      the graph to cluster
         * @param weighting  the node weighting
         * @param iterations the maximum number of rounds
         * @param random     the random number generator
         */
        public Implementation(Graph<V, E> graph, NodeWeighting<V, E> weighting, int iterations, Random random) {
            this.graph = graph;
            this.weighting = weighting;
            this.iterations = iterations;
            this.random = random;
            this.labels = new HashMap<>(graph.vertexSet().size());
        }

        /**
         * Runs label propagation and groups the vertices by their final label.
         *
         * @return the resulting clustering; each cluster is the set of vertices sharing a label
         */
        public ClusteringAlgorithm.Clustering<V> compute() {
            final List<V> nodes = new ArrayList<>(this.graph.vertexSet());

            // Every vertex starts in its own singleton class. The label map is keyed by the
            // vertex itself, so this works for any vertex type, not only Integer.
            int i = 0;
            for (final V node : nodes) {
                this.labels.put(node, i++);
            }

            this.steps = 0;
            while (this.steps < this.iterations) {
                Collections.shuffle(nodes, this.random);
                // A round that changes nothing is a fixed point, so further rounds are pointless.
                if (this.step(nodes) == 0) {
                    break;
                }
                ++this.steps;
            }

            final Map<Integer, Set<V>> groups = this.labels.entrySet().stream()
                    .collect(Collectors.groupingBy(Map.Entry::getValue,
                            Collectors.mapping(Map.Entry::getKey, Collectors.toSet())));

            final List<Set<V>> clusters = new ArrayList<>(groups.values());
            return new ClusteringAlgorithm.ClusteringImpl<>(clusters);
        }

        /**
         * Performs one round: each vertex, in the given order, adopts the best-scoring label among
         * its neighbours. A vertex with no neighbours keeps its current label.
         *
         * @param nodes the vertices in visiting order
         * @return the number of vertices whose label changed during this round
         */
        protected int step(List<V> nodes) {
            int changed = 0;

            for (final V node : nodes) {
                final Map<Integer, Double> scores = this.score(node);
                final Optional<Map.Entry<Integer, Double>> best =
                        Maximizer.argrandmax(scores.entrySet(), Map.Entry::getValue, this.random);

                final int updated = best.map(Map.Entry::getKey).orElseGet(() -> this.labels.get(node));
                final int previous = this.labels.put(node, updated);

                if (previous != updated) {
                    ++changed;
                }
            }

            return changed;
        }

        /**
         * Scores each label present among the neighbours of {@code node}.
         *
         * @param node the vertex whose neighbourhood is scored
         * @return a map from neighbour label to the summed weighting of neighbours carrying it;
         *         empty when {@code node} has no incident edges
         */
        protected Map<Integer, Double> score(V node) {
            final Set<E> edges = this.graph.edgesOf(node);
            final Map<Integer, Double> weights = new HashMap<>(edges.size());

            for (final E edge : edges) {
                final V neighbor = Graphs.getOppositeVertex(this.graph, edge, node);
                final int label = this.labels.get(neighbor);
                weights.merge(label, this.weighting.apply(this.graph, this.labels, node, neighbor), Double::sum);
            }

            return weights;
        }

        /**
         * @return the maximum number of rounds this run was configured with
         */
        public int getIterations() {
            return this.iterations;
        }

        /**
         * @return the number of rounds in which at least one label changed
         */
        public int getSteps() {
            return this.steps;
        }
    }

    /**
     * Builder for {@link ChineseWhispers}.
     *
     * <p>The same {@link Random} instance is handed to every algorithm instance this builder
     * creates.
     *
     * @param <V> the vertex type
     * @param <E> the edge type
     */
    public static class Builder<V, E>
    implements ClusteringAlgorithmBuilder<V, E, ChineseWhispers<V, E>> {
        /** Default maximum number of rounds. */
        public static final int ITERATIONS = 20;

        private NodeWeighting<V, E> weighting = NodeWeightings.top();
        private int iterations = ITERATIONS;
        private Random random = new Random();

        /**
         * Creates the algorithm for the given graph using the current settings.
         *
         * @param graph the undirected graph to cluster
         * @return a configured algorithm instance
         */
        @Override
        public ChineseWhispers<V, E> apply(Graph<V, E> graph) {
            return new ChineseWhispers<V, E>(graph, this.weighting, this.iterations, this.random);
        }

        /**
         * @param weighting the node weighting
         * @return this builder
         * @throws NullPointerException if {@code weighting} is {@code null}
         */
        public Builder<V, E> setWeighting(NodeWeighting<V, E> weighting) {
            this.weighting = Objects.requireNonNull(weighting);
            return this;
        }

        /**
         * @param iterations the maximum number of rounds (a non-positive value runs no rounds)
         * @return this builder
         */
        public Builder<V, E> setIterations(int iterations) {
            this.iterations = iterations;
            return this;
        }

        /**
         * @param random the random number generator
         * @return this builder
         * @throws NullPointerException if {@code random} is {@code null}
         */
        public Builder<V, E> setRandom(Random random) {
            this.random = Objects.requireNonNull(random);
            return this;
        }
    }
}

