package dulab.adap.common.algorithms;

import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

import javax.annotation.Nonnull;
import java.lang.Math;

/**
 * This class is a collection of methods used on the DoubleMatrix class from the jBlas package
 * @author Du-Lab Team <dulab.binf@gmail.com>
 */
public class MatrixUtils
{
    /**
     * Scales every column of the matrix, in place, by that column's standard deviation.
     * <p>
     * The standard deviation is the population one: the sum of squared deviations is divided by
     * {@code m.rows}, not {@code m.rows - 1}. Columns are only scaled, not centered; the column mean
     * is not subtracted.
     * <p>
     * If a column's standard deviation is not positive (e.g. all its elements are equal) or is NaN,
     * every element of that column is set to {@code 1.0}.
     * NOTE(review): a zero-variance column is overwritten with 1.0 rather than left unchanged —
     * confirm this is intended.
     *
     * @param m matrix to standardize; modified in place
     */
    public static void standardizeColumns(@Nonnull DoubleMatrix m) {
        for (int j = 0; j < m.columns; ++j)
        {
            final double std = columnStandardDeviation(m, j);

            for (int i = 0; i < m.rows; ++i)
                m.put(i, j, std > 0.0 ? m.get(i, j) / std : 1.0);
        }
    }

    /**
     * Calculates dot-product of a row from the first matrix by a column of the second matrix,
     * i.e. element {@code (row, col)} of the product {@code m1 * m2}.
     * Indices are passed straight to {@link DoubleMatrix#get(int, int)}; they are not validated here.
     *
     * @param m1 first matrix
     * @param m2 second matrix
     * @param row index of a row in the first matrix
     * @param col index of a column in the second matrix
     * @return dot-product of two vectors
     * @throws IncompatibleDimensionsException if {@code m1.columns != m2.rows}
     */
    public static double multiplyRowByColumn(@Nonnull DoubleMatrix m1, @Nonnull DoubleMatrix m2, int row, int col) {
        final int length = m1.columns;
        if (length != m2.rows)
            throw new IncompatibleDimensionsException(shapes(m1, m2));

        double sum = 0.0;
        for (int i = 0; i < length; ++i)
            sum += m1.get(row, i) * m2.get(i, col);

        return sum;
    }

    /**
     * Calculates dot-product of a column from the first matrix by a column from the second matrix,
     * i.e. element {@code (col1, col2)} of the product {@code m1^T * m2}.
     * Indices are passed straight to {@link DoubleMatrix#get(int, int)}; they are not validated here.
     *
     * @param m1 first matrix
     * @param m2 second matrix
     * @param col1 index of a column from the first matrix
     * @param col2 index of a column from the second matrix
     * @return dot-product of two vectors
     * @throws IncompatibleDimensionsException if {@code m1.rows != m2.rows}
     */
    public static double multiplyColumnByColumn(@Nonnull DoubleMatrix m1, @Nonnull DoubleMatrix m2, int col1, int col2) {
        final int length = m1.rows;
        if (length != m2.rows)
            throw new IncompatibleDimensionsException(shapes(m1, m2));

        double sum = 0.0;
        for (int i = 0; i < length; ++i)
            sum += m1.get(i, col1) * m2.get(i, col2);

        return sum;
    }

    /**
     * Calculates the Frobenius norm of the difference of two matrices: the l2-norm of
     * {@code m1 - m2} viewed as one long vector, sqrt(sum over i, j of (m1[i, j] - m2[i, j])^2).
     * The difference matrix itself is never allocated.
     *
     * @param m1 the first matrix
     * @param m2 the second matrix
     * @return Frobenius (element-wise l2) norm of the difference
     * @throws IncompatibleDimensionsException if the matrices do not have the same shape
     */
    public static double differenceNorm2(@Nonnull DoubleMatrix m1, @Nonnull DoubleMatrix m2) {
        final int numColumns = m1.columns;
        final int numRows = m1.rows;

        if (numColumns != m2.columns || numRows != m2.rows)
            throw new IncompatibleDimensionsException(shapes(m1, m2));

        double sum = 0.0;
        // The inner loop walks down a column: jblas stores matrices column-major,
        // so this visits the elements in memory order.
        for (int j = 0; j < numColumns; ++j)
            for (int i = 0; i < numRows; ++i) {
                final double v = m1.get(i, j) - m2.get(i, j);
                sum += v * v;
            }

        return Math.sqrt(sum);
    }

    /**
     * Calculates l2-norm of the positive elements of a column; zero and negative elements are ignored.
     *
     * @param m matrix
     * @param col index of a column
     * @return l2-norm of positive elements of the vector (0.0 if there are none)
     */
    public static double columnPositiveNorm2(@Nonnull DoubleMatrix m, int col) {
        double sum = 0.0;
        for (int i = 0; i < m.rows; ++i) {
            final double v = m.get(i, col);
            if (v > 0.0)
                sum += v * v;
        }
        return Math.sqrt(sum);
    }

    /**
     * Calculates l2-norm of the negative elements of a column; zero and positive elements are ignored.
     *
     * @param m matrix
     * @param col index of a column
     * @return l2-norm of negative elements of the vector (0.0 if there are none)
     */
    public static double columnNegativeNorm2(@Nonnull DoubleMatrix m, int col) {
        double sum = 0.0;
        for (int i = 0; i < m.rows; ++i) {
            final double v = m.get(i, col);
            if (v < 0.0)
                sum += v * v;
        }
        return Math.sqrt(sum);
    }

    /**
     * Computes the population standard deviation (divisor {@code m.rows}) of every column.
     * Uses the same computation as {@link #standardizeColumns(DoubleMatrix)}, so both agree exactly.
     *
     * @param m matrix (not modified)
     * @return a 1 x {@code m.columns} row vector whose j-th element is the standard deviation of
     *         column j (NaN for every column if the matrix has no rows)
     */
    public static DoubleMatrix columnStandardDeviations(@Nonnull DoubleMatrix m) {
        final DoubleMatrix result = new DoubleMatrix(1, m.columns);
        for (int j = 0; j < m.columns; ++j)
            result.put(0, j, columnStandardDeviation(m, j));
        return result;
    }

    /**
     * Computes the population standard deviation of one column.
     * <p>
     * The naive formula E[x^2] - E[x]^2 loses almost all precision when the mean is large relative
     * to the spread, and can even round to a negative value. Shifting every value by the column's
     * first element before accumulating avoids most of that loss. It also makes the variance exactly
     * 0.0 when all (finite) elements are equal.
     *
     * @param m matrix
     * @param col index of a column
     * @return population standard deviation of the column, or NaN if the matrix has no rows
     */
    private static double columnStandardDeviation(@Nonnull DoubleMatrix m, int col) {
        final int n = m.rows;
        if (n == 0)
            return Double.NaN;

        final double shift = m.get(0, col);
        double sum = 0.0;
        double sum2 = 0.0;
        for (int i = 0; i < n; ++i) {
            final double v = m.get(i, col) - shift;
            sum += v;
            sum2 += v * v;
        }

        // Rounding can still leave a tiny negative result; clamp it so sqrt does not return NaN.
        // Math.max keeps NaN (e.g. from NaN/infinite inputs) as NaN.
        final double variance = Math.max(0.0, (sum2 - sum * sum / n) / n);
        return Math.sqrt(variance);
    }

    /** Formats the shapes of two matrices as {@code "r1xc1 and r2xc2"} for exception messages. */
    private static String shapes(@Nonnull DoubleMatrix m1, @Nonnull DoubleMatrix m2) {
        return m1.rows + "x" + m1.columns + " and " + m2.rows + "x" + m2.columns;
    }

    // ----------------------
    // ----- Exceptions -----
    // ----------------------

    /**
     * Exception that is raised when dimensions of two matrices are incompatible
     */
    public static class IncompatibleDimensionsException extends IllegalArgumentException {
        public IncompatibleDimensionsException() {
            super("Incompatible dimensions of the matrices");
        }

        /**
         * @param details description of the offending dimensions, appended to the standard message
         */
        public IncompatibleDimensionsException(String details) {
            super("Incompatible dimensions of the matrices: " + details);
        }
    }
}
