package dulab.adap.workflow.decomposition;

import org.dulab.javanmf.algorithms.Constraint;
import org.dulab.javanmf.updaterules.MUpdateRule;
import org.ejml.data.DMatrixRMaj;

/**
 * A javanmf {@link Constraint} that reshapes the rows of a matrix in place.
 *
 * <p>{@link #apply(DMatrixRMaj)} performs three steps, in order:
 * <ol>
 *   <li>Clips every row so it is non-decreasing up to its maximum and non-increasing after it.</li>
 *   <li>Overwrites the last row with 1.0.</li>
 *   <li>Rescales every row whose maximum is positive so that its maximum equals 1.0.</li>
 * </ol>
 *
 * <p>NOTE(review): the parameter name {@code wt} suggests this is applied to the transposed
 * factor W^T (one component per row, with the last row acting as a constant component) —
 * confirm against the caller.
 */
public class ComponentConstraint implements Constraint {

    /**
     * Applies the constraint to {@code wt} in place.
     *
     * @param wt matrix to modify; it is mutated and also returned
     * @return the same instance {@code wt}, after modification
     * @throws IllegalStateException if some row of a non-empty matrix contains only NaN values
     */
    @Override
    public DMatrixRMaj apply(DMatrixRMaj wt) {
        // An empty matrix has nothing to constrain. Without this guard, a 0 x n matrix
        // makes flattenLastRow index row -1, and an n x 0 matrix makes getRowArgMax throw.
        if (wt.numRows == 0 || wt.numCols == 0) {
            return wt;
        }
        makeUnimodalRows(wt);
        flattenLastRow(wt);
        scaleRows(wt);
        return wt;
    }

    /**
     * Sets every value in the last row to 1.0. Does nothing if the matrix has no rows.
     *
     * @param m matrix
     */
    private void flattenLastRow(DMatrixRMaj m) {
        if (m.numRows == 0) {
            return;
        }
        int lastRow = m.numRows - 1;
        for (int j = 0; j < m.numCols; ++j) {
            m.unsafe_set(lastRow, j, 1.0);
        }
    }

    /**
     * Scales every row so that its maximum equals 1.0.
     *
     * <p>Rows whose maximum is zero or negative are left unchanged.
     *
     * @param m matrix
     */
    private void scaleRows(DMatrixRMaj m) {
        for (int i = 0; i < m.numRows; ++i) {
            double maxValue = m.unsafe_get(i, getRowArgMax(m, i));

            // Dividing by zero would produce infinities/NaN, and dividing by a negative
            // value would flip the row's sign, so such rows are skipped.
            if (maxValue <= 0.0) {
                continue;
            }

            for (int j = 0; j < m.numCols; ++j) {
                m.unsafe_set(i, j, m.unsafe_get(i, j) / maxValue);
            }
        }
    }

    /**
     * Modifies each row so it has a single peak.
     *
     * <p>Values left of the row's maximum are clipped to their right neighbour, and values
     * right of it are clipped to their left neighbour. After this, the row is non-decreasing
     * up to the maximum and non-increasing after it (plateaus are possible).
     *
     * @param m matrix
     */
    private void makeUnimodalRows(DMatrixRMaj m) {
        for (int i = 0; i < m.numRows; ++i) {
            int maxIndex = getRowArgMax(m, i);

            // Walk outward from the peak so each entry is compared with an already-clipped neighbour.
            for (int j = maxIndex - 1; j >= 0; --j) {
                m.unsafe_set(i, j, Math.min(m.unsafe_get(i, j), m.unsafe_get(i, j + 1)));
            }

            for (int j = maxIndex + 1; j < m.numCols; ++j) {
                m.unsafe_set(i, j, Math.min(m.unsafe_get(i, j), m.unsafe_get(i, j - 1)));
            }
        }
    }

    /**
     * Returns the column index of the largest value in a given row.
     *
     * <p>NaN entries are never selected. Among equal maxima, the first occurrence wins.
     *
     * @param m matrix
     * @param row index of a row
     * @return the column index of the maximum value
     * @throws IllegalStateException if the row has no non-NaN value (including a row with no columns)
     */
    private int getRowArgMax(DMatrixRMaj m, int row) {
        double maxValue = Double.NEGATIVE_INFINITY;
        int maxIndex = -1;
        for (int j = 0; j < m.numCols; ++j) {
            double value = m.unsafe_get(row, j);
            // Seed with the first non-NaN entry instead of a sentinel, so rows made only of
            // -Double.MAX_VALUE or -Infinity still have a maximum. NaN never compares
            // greater than anything, so it is never chosen afterwards either.
            boolean better = (maxIndex < 0) ? !Double.isNaN(value) : value > maxValue;
            if (better) {
                maxValue = value;
                maxIndex = j;
            }
        }

        if (maxIndex < 0) {
            throw new IllegalStateException("Cannot find maximum value");
        }

        return maxIndex;
    }
}
