/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.Random;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.utils.MatrixUtils;

/**
 * Unit tests for {@link MatrixUtils}: shape and column access, transposition, matrix
 * multiplication, and (jittered) inversion, checked against small hand-computed fixtures.
 */
class MatrixUtilsTest {
    /** Absolute tolerance used when comparing results of floating-point arithmetic. */
    private static final double DELTA = 1.0E-6;
    /**
     * Numeric tolerance forwarded to the inversion routines.
     * NOTE(review): presumably a near-zero pivot threshold — confirm against MatrixUtils.
     */
    private static final double INVERSION_TOLERANCE = 1.0E-9;

    private static final double[][] matOneElem = new double[][]{{5.0}};
    private static final double[][] matRowVector = new double[][]{{5.0, 6.0, 7.0}};
    private static final double[][] matColVector = new double[][]{{5.0}, {6.0}, {7.0}};
    /** matRowVector x matColVector: 25 + 36 + 49. */
    private static final double[][] vectorProdRowCol = new double[][]{{110.0}};
    /** matColVector x matRowVector: outer product. */
    private static final double[][] vectorProdColRow = new double[][]{{25.0, 30.0, 35.0}, {30.0, 36.0, 42.0}, {35.0, 42.0, 49.0}};
    private static final double[][] mat4X3 = new double[][]{{1.0, 2.0, 3.0}, {10.0, 5.0, -3.0}, {14.0, -6.6, 7.0}, {0.0, 5.0, -3.0}};
    /** Transpose of mat4X3. */
    private static final double[][] mat3X4 = new double[][]{{1.0, 10.0, 14.0, 0.0}, {2.0, 5.0, -6.6, 5.0}, {3.0, -3.0, 7.0, -3.0}};
    private static final double[][] mat3X5 = new double[][]{{1.0, 10.0, 3.0, -4.0, 0.0}, {10.0, 5.0, -3.0, 3.7, 1.0}, {14.0, -6.6, 7.0, 14.0, 3.0}};
    /** mat4X3 x mat3X5. */
    private static final double[][] mat43X35Product = new double[][]{{63.0, 0.2, 18.0, 45.4, 11.0}, {18.0, 144.8, -6.0, -63.5, -4.0}, {46.0, 60.8, 110.8, 17.58, 14.4}, {8.0, 44.8, -36.0, -23.5, -4.0}};
    private static final double[][] matSquareNonSingular = new double[][]{{1.0, 2.0, 3.0}, {10.0, 5.0, -3.0}, {14.0, -6.6, 7.0}};
    /** Inverse of matSquareNonSingular, rounded to 8 decimal places. */
    private static final double[][] matSNSInv = new double[][]{{-0.02464332, 0.05479896, 0.03404669}, {0.18158236, 0.05674449, -0.05350195}, {0.22049287, -0.05609598, 0.02431907}};
    /** Rank-deficient matrix (row3 = 2*row2 - row1), so it has no exact inverse. */
    private static final double[][] matSquareSingular = new double[][]{{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}};
    private static final double[][] identity = new double[][]{{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}};

    /**
     * Asserts that two matrices have the same number of rows and exactly equal rows.
     * Checking the row count first prevents an empty or truncated result from passing vacuously.
     *
     * @param expected the reference matrix
     * @param actual   the matrix produced by the code under test
     */
    private static void assertMatrixEquals(double[][] expected, double[][] actual) {
        Assertions.assertEquals(expected.length, actual.length, "row count");
        for (int row = 0; row < expected.length; ++row) {
            Assertions.assertArrayEquals(expected[row], actual[row], "row " + row);
        }
    }

    /**
     * Asserts that two matrices have the same number of rows and element-wise equal rows
     * within {@code delta}.
     *
     * @param expected the reference matrix
     * @param actual   the matrix produced by the code under test
     * @param delta    maximum allowed absolute difference per element
     */
    private static void assertMatrixEquals(double[][] expected, double[][] actual, double delta) {
        Assertions.assertEquals(expected.length, actual.length, "row count");
        for (int row = 0; row < expected.length; ++row) {
            Assertions.assertArrayEquals(expected[row], actual[row], delta, "row " + row);
        }
    }

    @Test
    void testShape() {
        int[] shape = MatrixUtils.getShape(mat3X5);
        Assertions.assertArrayEquals(new int[]{3, 5}, shape);
    }

    @Test
    void testGetCol() {
        double[] col = MatrixUtils.getCol(mat3X4, 1);
        Assertions.assertArrayEquals(new double[]{10.0, 5.0, -3.0}, col);
    }

    @Test
    void testGetColTooBig() {
        // Column index 10 is out of range for a 3x4 matrix.
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.getCol(mat3X4, 10));
    }

    @Test
    void testOneElemTranspose() {
        // A 1x1 matrix is its own transpose.
        assertMatrixEquals(matOneElem, MatrixUtils.transpose(matOneElem));
    }

    @Test
    void testVectorTranspose() {
        assertMatrixEquals(matColVector, MatrixUtils.transpose(matRowVector));
    }

    @Test
    void testMatrixTranspose() {
        assertMatrixEquals(mat4X3, MatrixUtils.transpose(mat3X4));
    }

    @Test
    void testMatMulNormal() {
        assertMatrixEquals(mat43X35Product, MatrixUtils.matrixMultiply(mat4X3, mat3X5), DELTA);
    }

    @Test
    void testMatMulWrongShape() {
        // Inner dimensions differ: (3x4) x (3x5).
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.matrixMultiply(mat3X4, mat3X5));
    }

    @Test
    void testVectorRowColMultiply() {
        assertMatrixEquals(vectorProdRowCol, MatrixUtils.matrixMultiply(matRowVector, matColVector), DELTA);
    }

    @Test
    void testVectorColRowMultiply() {
        assertMatrixEquals(vectorProdColRow, MatrixUtils.matrixMultiply(matColVector, matRowVector), DELTA);
    }

    @Test
    void testInvertNormal() {
        double[][] inv = MatrixUtils.invertSquareMatrix(matSquareNonSingular, INVERSION_TOLERANCE);
        assertMatrixEquals(matSNSInv, inv, DELTA);
    }

    @Test
    void testInvertSingular() {
        Assertions.assertThrows(ArithmeticException.class, () -> MatrixUtils.invertSquareMatrix(matSquareSingular, INVERSION_TOLERANCE));
    }

    /**
     * For 100 deterministic seeds, checks that multiplying the singular matrix by the result of
     * {@code jitterInvert} yields the identity within {@link #DELTA}.
     */
    @Test
    void testJitterInvert() {
        for (int run = 0; run < 100; ++run) {
            // new Random(seed) is equivalent to new Random() followed by setSeed(seed).
            Random random = new Random(run);
            double[][] inv = MatrixUtils.jitterInvert(matSquareSingular, 10, INVERSION_TOLERANCE, random);
            double[][] prod = MatrixUtils.matrixMultiply(matSquareSingular, inv);
            assertMatrixEquals(identity, prod, DELTA);
        }
    }
}

