/* 
 * Copyright (C) 2016 Du-Lab Team <dulab.binf@gmail.com>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */

package dulab.adap.common.algorithms.machineleanring;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math3.exception.TooManyIterationsException;
import org.apache.commons.math3.linear.DefaultRealMatrixChangingVisitor;
import org.apache.commons.math3.linear.OpenMapRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.stat.StatUtils;

/**
 * Markov Graph Clustering (MCL) on a similarity graph.
 * <p>
 * The similarity matrix is first normalized column by column. It is then
 * repeatedly expanded (squared) and inflated (raised element-wise to a power
 * and re-normalized). This stops once two consecutive iterates differ by less
 * than a tolerance in the Frobenius norm. Rows of the converged matrix that
 * still contain an entry larger than {@link #EPS} are reported as attractors.
 *
 * @author Du-Lab Team <dulab.binf@gmail.com>
 */


public class MarkovGraphClustering 
{
    /** Entries not exceeding this value are treated as zero when looking for attractors. */
    private static final double EPS = 1e-3;
    
    /**
     * Finds all attractors in the graph by performing Markov Graph Clustering.
     * 
     * @param similarity a square matrix with 1s on the diagonal; only its
     *        non-zero entries are copied into the sparse working matrix
     * @param params clustering parameters: {@code maxIteration} bounds the
     *        number of expand/inflate rounds, {@code inflationCoefficient} is
     *        the inflation power and {@code tolerance} is the convergence
     *        threshold on the Frobenius norm of the change between rounds
     * @return indices of attractor rows in increasing order; an empty list
     *         when {@code similarity} is empty
     * @throws IllegalArgumentException if {@code similarity} is not square
     * @throws TooManyIterationsException if the iteration does not converge
     *         within {@code params.maxIteration} rounds
     */
    
    public List <Integer> findAttractors(double[][] similarity, 
            MarkovGraphClusteringParameters params)
            throws TooManyIterationsException
    {
        final int size = similarity.length;
        
        // The sparse matrix constructor rejects zero dimensions, and a graph
        // without nodes trivially has no attractors.
        if (size == 0)
            return new ArrayList<> ();
        
        RealMatrix similarityMatrix = new OpenMapRealMatrix(size, size);
        for (int i = 0; i < size; ++i)
        {
            if (similarity[i].length != size)
                throw new IllegalArgumentException(
                        "Similarity matrix must be square: row " + i + " has "
                        + similarity[i].length + " columns, expected " + size);
            
            for (int j = 0; j < size; ++j)
                if (similarity[i][j] != 0.0)
                    similarityMatrix.setEntry(i, j, similarity[i][j]);
        }
        
        scale(similarityMatrix);
        
        for (int i = 0; i < params.maxIteration; ++i)
        {
            // expand() returns a new matrix and inflate() mutates only that new
            // matrix, so the previous iterate stays intact without a copy.
            RealMatrix previous = similarityMatrix;
            
            similarityMatrix = expand(previous);
            inflate(similarityMatrix, params.inflationCoefficient);
            
            double norm = previous.subtract(similarityMatrix).getFrobeniusNorm();
            
            if (norm < params.tolerance)
                return getAttractors(similarityMatrix);
        }
        
        throw new TooManyIterationsException(params.maxIteration);
    }
    
    /** 
     * Divides, in place, each element of a matrix by the sum of absolute
     * values (L1 norm) of its column.
     * <p>
     * Columns whose L1 norm is zero are left unchanged. Dividing them would
     * produce 0/0 = NaN, which would then spread through every subsequent
     * matrix product.
     * 
     * @param matrix a square matrix, modified in place
     */
    
    private void scale(RealMatrix matrix)
    {
        int size = matrix.getColumnDimension();
        
        final double[] columnSum = new double[size];
        for (int i = 0; i < size; ++i)
            columnSum[i] = matrix.getColumnVector(i).getL1Norm();
        
        matrix.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
            @Override
            public double visit(int row, int column, double value) {
                double sum = columnSum[column];
                return sum == 0.0 ? value : value / sum;
            }
        });
    }      
    
    /**
     * Calculates the square of a matrix (the MCL expansion step).
     * 
     * @param matrix a square matrix; not modified
     * @return a new matrix equal to {@code matrix * matrix}
     */
    
    private RealMatrix expand(RealMatrix matrix) {
        return matrix.multiply(matrix);
    }
    
    /**
     * Raises each element of a matrix to power r, then re-normalizes its
     * columns (the MCL inflation step). The matrix is modified in place.
     * 
     * @param matrix a square matrix, modified in place
     * @param r a number greater than 1
     */
    
    private void inflate(RealMatrix matrix, final double r)
    {
        matrix.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
            @Override
            public double visit(int row, int column, double value) {
                return Math.pow(value, r);
            }
        });
        
        scale(matrix);
    }
    
    /**
     * Finds the indices of all rows whose largest element exceeds {@link #EPS}.
     * 
     * @param matrix a square matrix
     * @return list of row indices in increasing order
     */
    
    private List <Integer> getAttractors(RealMatrix matrix)
    {
        List <Integer> result = new ArrayList<> ();
        
        int size = matrix.getRowDimension();
        
        for (int i = 0; i < size; ++i)
            if (StatUtils.max(matrix.getRowVector(i).toArray()) > EPS) 
                result.add(i);
        
        return result;
    }
}
