/* 
 * Copyright (C) 2016 Du-Lab Team <dulab.binf@gmail.com>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */

package dulab.adap.common.algorithms.machineleanring;

import dulab.adap.common.types.Graph;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.MatrixEntry;

import org.apache.commons.math3.exception.TooManyIterationsException;

/**
 * Markov Graph Clustering (MCL).
 * <p>
 * {@link #iterate} alternates expansion (squaring of a column-stochastic
 * matrix) and inflation (element-wise power followed by column
 * normalization) until two consecutive matrices are close enough;
 * {@link #getCommunities} reads communities off the resulting matrix and
 * {@link #getModularity} scores that partition.
 *
 * @author Du-Lab Team <dulab.binf@gmail.com>
 */
public class MarkovGraphClusteringV2 
{
    /**
     * Tolerance used when comparing matrix entries for equality and as the
     * smallest value an entry must have to seed a community.
     */
    private static final double EPS = 1e-3;
    
    /**
     * Runs Markov Graph Clustering on a similarity matrix.
     * <p>
     * The matrix is first column-normalized. Each iteration then squares it
     * (expansion), raises every entry to {@code params.inflationCoefficient}
     * and column-normalizes it again (inflation). Iteration stops as soon as
     * the Frobenius norm of the difference between two consecutive matrices
     * is below {@code params.tolerance}.
     * <p>
     * <b>Note:</b> the result is written back into {@code similarity}; the
     * input array is overwritten and the same array is returned.
     * 
     * @param similarity a square matrix with 1s on the diagonal; overwritten
     *                   with the converged matrix
     * @param params clustering parameters
     * @return {@code similarity}, now holding the converged matrix
     * @throws TooManyIterationsException if the matrix does not converge
     *         within {@code params.maxIteration} iterations
     */
    public double[][] iterate(double[][] similarity, 
            MarkovGraphClusteringParameters params)
            throws TooManyIterationsException
    {
        // Nothing to cluster, and DenseMatrix cannot be built from an
        // empty array.
        if (similarity.length == 0) return similarity;
        
        Matrix current = new DenseMatrix(similarity);
        scale(current);
        
        boolean converged = false;
        
        for (int iteration = 0; iteration < params.maxIteration; ++iteration)
        {
            // expand() returns a new matrix and leaves its argument intact,
            // so the previous matrix can be kept without copying it.
            Matrix previous = current;
            
            current = expand(previous);
            inflate(current, params.inflationCoefficient);
            
            // add() stores (previous - current) into `previous`, which is not
            // needed after this point.
            double norm = previous.add(-1.0, current)
                    .norm(Matrix.Norm.Frobenius);
            
            if (norm < params.tolerance)
            {
                converged = true;
                break;
            }
        }
        
        if (!converged)
            throw new TooManyIterationsException(params.maxIteration);
        
        for (MatrixEntry e : current)
            similarity[e.row()][e.column()] = e.get();
        
        return similarity;
    }
    
    /** 
     * Divides each element of a matrix by the sum of its column, in place.
     * <p>
     * A column whose sum is exactly zero is left unchanged instead of being
     * filled with NaN/Infinity by a division by zero.
     * 
     * @param matrix the matrix to normalize; modified in place
     */
    private void scale(Matrix matrix)
    {
        final int rows = matrix.numRows();
        final int columns = matrix.numColumns();
        
        // Column-by-column traversal matches DenseMatrix's column-major
        // storage and needs a single pass per column.
        for (int j = 0; j < columns; ++j)
        {
            double sum = 0.0;
            for (int i = 0; i < rows; ++i)
                sum += matrix.get(i, j);
            
            if (sum == 0.0) continue;
            
            for (int i = 0; i < rows; ++i)
                matrix.set(i, j, matrix.get(i, j) / sum);
        }
    }      
    
    /**
     * Computes the square of a matrix (the expansion step).
     * 
     * @param matrix a square matrix; not modified
     * @return a new matrix equal to {@code matrix * matrix}
     */
    private Matrix expand(Matrix matrix)
    {
        // mult(B, C) stores A*B into C, so a fresh matrix is enough;
        // copying the input first would be wasted work.
        Matrix square = new DenseMatrix(matrix.numRows(), matrix.numColumns());
        return matrix.mult(matrix, square);
    }
    
    /**
     * Inflation step: raises each element of the matrix to the power
     * {@code r}, then re-normalizes every column to sum to 1.
     * 
     * @param matrix a square matrix; modified in place
     * @param r the inflation exponent, a number greater than 1
     */
    private void inflate(Matrix matrix, double r) 
    {
        for (MatrixEntry e : matrix)
            e.set(Math.pow(e.get(), r));
        
        scale(matrix);
    }
    
    /**
     * Finds communities in a stochastic matrix.
     * <p>
     * Every not-yet-processed entry with a value of at least {@link #EPS}
     * seeds a community. All rows holding the same value (within
     * {@code EPS}) in that entry's column are collected. For each such row,
     * an edge is added from the row index to every column index holding that
     * value in the row. Entries turned into edges are not used as seeds again.
     * 
     * @param stochastic a square matrix, e.g. the result of {@link #iterate}
     * @return list of communities, one graph per community
     */
    public List <Graph <Integer>> getCommunities(final double[][] stochastic)
    {
        int size = stochastic.length;
        
        List <Graph <Integer>> communities = new ArrayList <> ();
        
        boolean[][] isProcessed = new boolean[size][size];
        
        for (int row = 0; row < size; ++row)
            for (int column = 0; column < size; ++column)
            {   
                double value = stochastic[row][column];

                if (isProcessed[row][column] || value < EPS) continue;

                Graph <Integer> community = new Graph <> ();
                
                // Find all entries in the column with the same value
                List <Integer> rowIndices = new ArrayList <> ();

                for (int i = 0; i < size; ++i)
                    if (Math.abs(stochastic[i][column] - value) < EPS)
                        rowIndices.add(i);

                // For each found row, add an edge to every column holding
                // the same value and mark those entries as processed
                for (int i : rowIndices)
                    for (int j = 0; j < size; ++j)
                        if (Math.abs(stochastic[i][j] - value) < EPS) {
                            community.addEdge(i, j);
                            isProcessed[i][j] = true;
                        }

                communities.add(community);
            }
        
        return communities;
    }
    
    /**
     * Computes the modularity of the communities found in {@code stochastic}
     * with respect to the graph described by {@code similarity}.
     * 
     * @param similarity square similarity matrix. It is passed to the
     *                   {@code Graph} constructor together with {@link #EPS};
     *                   presumably EPS acts as an edge threshold, so confirm
     *                   this in {@code Graph}.
     * @param stochastic matrix the communities are extracted from via
     *                   {@link #getCommunities}
     * @return the value of {@code Graph.modularity} for those communities
     */
    public double getModularity(final double[][] similarity, 
            final double[][] stochastic)
    {
        int size = similarity.length;
        
        Integer[] nodes = new Integer[size];
        for (int i = 0; i < size; ++i) nodes[i] = i;
        
        Graph <Integer> graph = new Graph <> (nodes, similarity, EPS);
        
        List <Graph <Integer>> communities = getCommunities(stochastic);
        
        return graph.modularity(communities);
    }
}
