/*
 * Copyright (C) 2017 Du-Lab Team <dulab.binf@gmail.com>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */

package dulab.adap.common.algorithms.machineleanring;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.exception.MaxCountExceededException;

import javax.annotation.Nonnull;
import java.util.*;

/**
 * This class performs K-Medoids clustering of data points given a pairwise similarity matrix.
 * <p>
 * Unlike the classic distance-based formulation, the objective here is <em>maximized</em>: the cost of a
 * configuration is the weighted sum of similarities between every data point and the medoid of its cluster.
 *
 * @author Du-Lab Team <dulab.binf@gmail.com>
 */
public class KMedoidsClustering
{
    /**
     * Pairs a data-point index with its total similarity; used to rank candidates for the initial medoids.
     * Static so that instances do not capture a reference to the enclosing clustering object.
     */
    private static final class ArgSortItem
    {
        private final int index;
        private final double value;

        private ArgSortItem(int index, double value)
        {
            this.index = index;
            this.value = value;
        }
    }

    private final double[][] similarityMatrix;
    private final double[] weights;
    private final int numberOfClusters;
    private final int maxIterations;
    private final int size;

    /**
     * Constructor without weights (every data point gets weight 1.0)
     *
     * @param similarityMatrix symmetric square matrix of similarities
     *                         with values from 0.0 (no similarity) to 1.0 (best similarity)
     *
     * @param numberOfClusters number of clusters
     * @param maxIterations maximum number of iterations
     * @throws IllegalArgumentException if number of clusters is greater than size of the similarity matrix,
     *                                  if number of clusters is negative (or zero for a non-empty matrix),
     *                                  or if a row of the similarity matrix is missing or too short
     */

    public KMedoidsClustering(@Nonnull double[][] similarityMatrix,
                              int numberOfClusters, int maxIterations)
            throws IllegalArgumentException
    {
        if (similarityMatrix.length < numberOfClusters)
            throw new IllegalArgumentException("Cannot split " + similarityMatrix.length + " points into "
                    + numberOfClusters + " clusters");

        // Zero clusters only makes sense when there are no points at all; otherwise getCost() would
        // index an empty medoid array, and a negative count would fail when allocating the medoids.
        if (numberOfClusters < 0 || (numberOfClusters == 0 && similarityMatrix.length > 0))
            throw new IllegalArgumentException("Number of clusters must be positive, got " + numberOfClusters);

        // Every row is indexed with column indices in [0, length), so each row must be at least that long.
        for (int i = 0; i < similarityMatrix.length; ++i)
            if (similarityMatrix[i] == null || similarityMatrix[i].length < similarityMatrix.length)
                throw new IllegalArgumentException("Row " + i + " of the similarity matrix is missing or has fewer than "
                        + similarityMatrix.length + " elements");

        this.similarityMatrix = similarityMatrix;
        this.numberOfClusters = numberOfClusters;
        this.maxIterations = maxIterations;
        this.size = similarityMatrix.length;

        this.weights = new double[this.size];
        Arrays.fill(this.weights, 1.0);
    }

    /**
     * Constructor with weights
     *
     * @param similarityMatrix symmetric square matrix of similarities
     *                         with values from 0.0 (no similarity) to 1.0 (best similarity)
     *
     * @param weights array of weights, one per data point; only the first {@code similarityMatrix.length}
     *                values are used (the array is copied)
     * @param numberOfClusters number of clusters
     * @param maxIterations maximum number of iterations
     * @throws IllegalArgumentException if the unweighted constructor rejects the arguments, or if
     *                                  {@code weights} has fewer elements than there are data points
     */

    public KMedoidsClustering(@Nonnull double[][] similarityMatrix, @Nonnull double[] weights,
                              int numberOfClusters, int maxIterations)
            throws IllegalArgumentException
    {
        this(similarityMatrix, numberOfClusters, maxIterations);

        if (weights.length < this.size)
            throw new IllegalArgumentException("Expected at least " + this.size + " weights, got " + weights.length);

        // Defensive copy: later changes to the caller's array do not affect the clustering.
        System.arraycopy(weights, 0, this.weights, 0, this.size);
    }

    /**
     * Performs K-Medoids clustering of data points
     *
     * Step 1. Initialize: select the data points with the highest total similarity to all points as medoids.
     * Step 2. Associate each data item with the most similar medoid.
     * Step 3. For every non-medoid item, try making it the medoid of its own cluster and recompute the cost
     *         (weighted sum of similarities of points to their medoids). Keep the configuration with the
     *         highest cost and repeat. Stop when no swap increases the cost.
     *
     * @return labels for the data points: label {@code k} means the point belongs to the k-th cluster
     * @throws MaxCountExceededException if the configuration is still improving after maxIterations iterations
     */

    public int[] execute()
            throws MaxCountExceededException
    {
        // Initialize medoids, labels, cost, and count
        int[] medoids = initializeMedoids();
        int[] labels = assign(medoids);
        double prevCost = getCost(labels, medoids);
        int count = 0;

        while (true)
        {
            int[] newLabels = null;
            int[] newMedoids = null;

            for (int i = 0; i < size; ++i)
            {
                // The current medoid of i's cluster cannot be swapped with itself
                if (i == medoids[labels[i]]) continue;

                // Make i-th item the medoid of its own cluster
                int[] tryMedoids = medoids.clone();
                tryMedoids[labels[i]] = i;
                int[] tryLabels = assign(tryMedoids);

                // Keep the best (highest-cost) swap found so far in this pass
                double cost = getCost(tryLabels, tryMedoids);
                if (cost > prevCost) {
                    prevCost = cost;
                    newLabels = tryLabels;
                    newMedoids = tryMedoids;
                }
            }

            if (newLabels == null || newMedoids == null)
                // No swap improves the cost: the best medoids are found
                break;
            else {
                // Reassign medoids and labels and try again
                labels = newLabels;
                medoids = newMedoids;
            }

            if (++count >= maxIterations)
                throw new MaxCountExceededException(maxIterations);
        }

        return labels;
    }

    /**
     * Assigns labels to each data item. Labels correspond to the most similar medoid;
     * ties go to the medoid that appears first in {@code medoids}.
     *
     * @param medoids indices of medoids
     * @return labels where 0 corresponds to the first medoid, 1 to the second medoid, etc.
     */

    private int[] assign(int[] medoids)
    {
        int[] labels = new int[size];
        for (int i = 0; i < size; ++i)
        {
            double maxSimilarity = -Double.MAX_VALUE;
            for (int j = 0; j < medoids.length; ++j) {
                double similarity = similarityMatrix[i][medoids[j]];
                if (similarity > maxSimilarity) {
                    maxSimilarity = similarity;
                    labels[i] = j;
                }
            }
        }
        return labels;
    }

    /**
     * Calculates cost of clustering as the weighted sum of similarities of data points to their medoids.
     * Higher values indicate a better clustering.
     *
     * @param labels labels of data points
     * @param medoids indices of medoids
     * @return cost of clustering
     */

    private double getCost(int[] labels, int[] medoids) {
        double cost = 0.0;
        for (int i = 0; i < size; ++i)
            cost += similarityMatrix[i][medoids[labels[i]]] * weights[i];
        return cost;
    }

    /**
     * Picks as initial medoids the numberOfClusters data points with the highest total similarity
     * to all points. Ties keep the original index order (the sort is stable).
     *
     * @return indices of the initial medoids, ordered by decreasing total similarity
     */

    private int[] initializeMedoids()
    {
        List<ArgSortItem> lstSums = new ArrayList<>(size);
        for (int i = 0; i < size; ++i) {
            double sum = 0.0;
            for (int j = 0; j < size; ++j)
                sum += similarityMatrix[i][j];
            lstSums.add(new ArgSortItem(i, sum));
        }

        // Descending order of total similarity
        Collections.sort(lstSums, new Comparator<ArgSortItem>() {
            @Override
            public int compare(ArgSortItem o1, ArgSortItem o2) {
                return Double.compare(o2.value, o1.value);
            }
        });

        int[] medoids = new int[numberOfClusters];
        for (int i = 0; i < numberOfClusters; ++i)
            medoids[i] = lstSums.get(i).index;

        return medoids;
    }
}
