/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.MatrixUtils;

/**
 * Unit tests for {@link MatrixUtils}: conversion from prediction inputs/outputs, vector construction,
 * column extraction, transposition, element-wise arithmetic, multiplication and (jitter) inversion.
 */
class MatrixUtilsTest {
    /** Absolute tolerance used for floating-point matrix comparisons. */
    private static final double EPSILON = 1.0E-6;

    final double[][] matOneElem = new double[][]{{5.0}};
    final double[] vector = new double[]{5.0, 6.0, 7.0};
    // matRowVector / matColVector hold the same values as `vector`, shaped 1x3 and 3x1.
    final double[][] matRowVector = new double[][]{{5.0, 6.0, 7.0}};
    final double[][] matColVector = new double[][]{{5.0}, {6.0}, {7.0}};
    final double[][] vectorProdRowCol = new double[][]{{110.0}};
    final double[][] vectorProdColRow = new double[][]{{25.0, 30.0, 35.0}, {30.0, 36.0, 42.0}, {35.0, 42.0, 49.0}};
    final double[][] mat4X3 = new double[][]{{1.0, 2.0, 3.0}, {10.0, 5.0, -3.0}, {14.0, -6.6, 7.0}, {0.0, 5.0, -3.0}};
    final double[][] mat3X4 = new double[][]{{1.0, 10.0, 14.0, 0.0}, {2.0, 5.0, -6.6, 5.0}, {3.0, -3.0, 7.0, -3.0}};
    final double[][] mat3X5 = new double[][]{{1.0, 10.0, 3.0, -4.0, 0.0}, {10.0, 5.0, -3.0, 3.7, 1.0}, {14.0, -6.6, 7.0, 14.0, 3.0}};
    // Columns 0, 1, 3 of mat3X5.
    final double[][] mat3X5get013 = new double[][]{{1.0, 10.0, -4.0}, {10.0, 5.0, 3.7}, {14.0, -6.6, 14.0}};
    // Columns 0, 3, 1, 3, 0 of mat3X5 (duplicates allowed).
    final double[][] mat3X5get03130 = new double[][]{{1.0, -4.0, 10.0, -4.0, 1.0}, {10.0, 3.7, 5.0, 3.7, 10.0}, {14.0, 14.0, -6.6, 14.0, 14.0}};
    // mat4X3 x mat3X5.
    final double[][] mat43X35Product = new double[][]{{63.0, 0.2, 18.0, 45.4, 11.0}, {18.0, 144.8, -6.0, -63.5, -4.0}, {46.0, 60.8, 110.8, 17.58, 14.4}, {8.0, 44.8, -36.0, -23.5, -4.0}};
    final double[][] matSquareNonSingular = new double[][]{{1.0, 2.0, 3.0}, {10.0, 5.0, -3.0}, {14.0, -6.6, 7.0}};
    // Expected inverse of matSquareNonSingular, given to 8 decimal places.
    final double[][] matSNSInv = new double[][]{{-0.02464332, 0.05479896, 0.03404669}, {0.18158236, 0.05674449, -0.05350195}, {0.22049287, -0.05609598, 0.02431907}};
    final double[][] matSquareSingular = new double[][]{{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}};
    final double[][] identity = new double[][]{{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}};
    // identity with `vector` added to every row.
    final double[][] matIdentityPlusVector = new double[][]{{6.0, 6.0, 7.0}, {5.0, 7.0, 7.0}, {5.0, 6.0, 8.0}};
    final double[][] mssPlusIdentity = new double[][]{{2.0, 2.0, 3.0}, {4.0, 6.0, 6.0}, {7.0, 8.0, 10.0}};
    final double[][] mssMinusIdentity = new double[][]{{0.0, 2.0, 3.0}, {4.0, 4.0, 6.0}, {7.0, 8.0, 8.0}};
    // Expected sum(matSquareSingular, Axis.ROW): the per-column totals.
    final double[] mssSumRow = new double[]{12.0, 15.0, 18.0};
    // Expected sum(matSquareSingular, Axis.COLUMN): the per-row totals.
    final double[] mssSumCol = new double[]{6.0, 15.0, 24.0};
    // Seeded so that any failure in the jitter-inversion tests is reproducible.
    final Random rn = new Random(0L);

    /**
     * Asserts that two matrices are exactly equal: same number of rows and identical rows.
     *
     * @param expected the expected matrix
     * @param actual the matrix produced by the code under test
     */
    private static void assertMatrixEquals(double[][] expected, double[][] actual) {
        Assertions.assertEquals(expected.length, actual.length, "row count");
        for (int i = 0; i < expected.length; ++i) {
            Assertions.assertArrayEquals(expected[i], actual[i], "row " + i);
        }
    }

    /**
     * Asserts that two matrices have the same number of rows and that every row matches within {@code delta}.
     *
     * @param expected the expected matrix
     * @param actual the matrix produced by the code under test
     * @param delta the maximum absolute difference allowed per element
     */
    private static void assertMatrixEquals(double[][] expected, double[][] actual, double delta) {
        Assertions.assertEquals(expected.length, actual.length, "row count");
        for (int i = 0; i < expected.length; ++i) {
            Assertions.assertArrayEquals(expected[i], actual[i], delta, "row " + i);
        }
    }

    /** Builds one numerical feature named "f" per value in {@code row}. */
    private static List<Feature> numericalFeatures(double[] row) {
        List<Feature> features = new ArrayList<>(row.length);
        for (double value : row) {
            features.add(FeatureFactory.newNumericalFeature("f", value));
        }
        return features;
    }

    /** Builds one NUMBER output named "o" (score 0.0) per value in {@code row}. */
    private static List<Output> numericalOutputs(double[] row) {
        List<Output> outputs = new ArrayList<>(row.length);
        for (double value : row) {
            outputs.add(new Output("o", Type.NUMBER, new Value(value), 0.0));
        }
        return outputs;
    }

    /** Returns a new matrix whose entries are those of {@code matrix} multiplied by {@code factor}. */
    private static double[][] scaled(double[][] matrix, double factor) {
        double[][] result = new double[matrix.length][];
        for (int i = 0; i < matrix.length; ++i) {
            result[i] = new double[matrix[i].length];
            for (int j = 0; j < matrix[i].length; ++j) {
                result[i][j] = matrix[i][j] * factor;
            }
        }
        return result;
    }

    @Test
    void testPICreation() {
        PredictionInput pi = new PredictionInput(numericalFeatures(mat3X5[0]));
        double[][] converted = MatrixUtils.matrixFromPredictionInput(pi);
        Assertions.assertArrayEquals(mat3X5[0], converted[0]);
    }

    @Test
    void testPIListCreation() {
        List<PredictionInput> inputs = new ArrayList<>();
        for (double[] row : mat3X5) {
            inputs.add(new PredictionInput(numericalFeatures(row)));
        }
        double[][] converted = MatrixUtils.matrixFromPredictionInput(inputs);
        assertMatrixEquals(mat3X5, converted);
    }

    @Test
    void testPOCreation() {
        PredictionOutput po = new PredictionOutput(numericalOutputs(mat3X5[0]));
        double[][] converted = MatrixUtils.matrixFromPredictionOutput(po);
        Assertions.assertArrayEquals(mat3X5[0], converted[0]);
    }

    @Test
    void testPOListCreation() {
        List<PredictionOutput> outputs = new ArrayList<>();
        for (double[] row : mat3X5) {
            outputs.add(new PredictionOutput(numericalOutputs(row)));
        }
        double[][] converted = MatrixUtils.matrixFromPredictionOutput(outputs);
        assertMatrixEquals(mat3X5, converted);
    }

    @Test
    void testRowVectorCreation() {
        // A row vector is 1xN, so every element of its single row must be compared, not just the first.
        assertMatrixEquals(matRowVector, MatrixUtils.rowVector(vector));
    }

    @Test
    void testColVectorCreation() {
        assertMatrixEquals(matColVector, MatrixUtils.columnVector(vector));
    }

    @Test
    void testShape() {
        Assertions.assertArrayEquals(new int[]{3, 5}, MatrixUtils.getShape(mat3X5));
    }

    @Test
    void testGetColTooBig() {
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.getCol(mat3X4, 10));
    }

    @Test
    void testGetNegCol() {
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.getCol(mat3X4, -10));
    }

    @Test
    void testGetCol() {
        double[] col = MatrixUtils.getCol(mat3X4, 1);
        Assertions.assertArrayEquals(new double[]{10.0, 5.0, -3.0}, col);
    }

    @Test
    void testGetCols() {
        double[][] output = MatrixUtils.getCols(mat3X5, List.of(0, 1, 3));
        assertMatrixEquals(mat3X5get013, output);
    }

    @Test
    void testGetDupCols() {
        double[][] output = MatrixUtils.getCols(mat3X5, List.of(0, 3, 1, 3, 0));
        assertMatrixEquals(mat3X5get03130, output);
    }

    @Test
    void testGetColsTooBig() {
        List<Integer> testIdxs = List.of(0, 6);
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.getCols(mat3X5, testIdxs));
    }

    @Test
    void testGetNegCols() {
        List<Integer> testIdxs = List.of(0, -6);
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.getCols(mat3X5, testIdxs));
    }

    @Test
    void testGetNoCols() {
        List<Integer> testIdxs = List.of();
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.getCols(mat3X5, testIdxs));
    }

    @Test
    void testOneElemTranspose() {
        assertMatrixEquals(matOneElem, MatrixUtils.transpose(matOneElem));
    }

    @Test
    void testVectorTranspose() {
        assertMatrixEquals(matColVector, MatrixUtils.transpose(matRowVector));
    }

    @Test
    void testMatrixTranspose() {
        assertMatrixEquals(mat4X3, MatrixUtils.transpose(mat3X4));
    }

    @Test
    void testIntraSumRow() {
        double[] sum = MatrixUtils.sum(matSquareSingular, MatrixUtils.Axis.ROW);
        Assertions.assertArrayEquals(mssSumRow, sum, EPSILON);
    }

    @Test
    void testIntraSumCol() {
        double[] sum = MatrixUtils.sum(matSquareSingular, MatrixUtils.Axis.COLUMN);
        Assertions.assertArrayEquals(mssSumCol, sum, EPSILON);
    }

    @Test
    void testMatSum() {
        assertMatrixEquals(mssPlusIdentity, MatrixUtils.matrixSum(matSquareSingular, identity), EPSILON);
    }

    @Test
    void testMatSumWrongSizes() {
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.matrixSum(matSquareSingular, mat4X3));
    }

    @Test
    void testMatDiff() {
        assertMatrixEquals(mssMinusIdentity, MatrixUtils.matrixDifference(matSquareSingular, identity), EPSILON);
    }

    @Test
    void testMatRowSum() {
        assertMatrixEquals(matIdentityPlusVector, MatrixUtils.matrixRowSum(identity, vector), EPSILON);
    }

    @Test
    void testMatRowDiff() {
        assertMatrixEquals(identity, MatrixUtils.matrixRowDifference(matIdentityPlusVector, vector), EPSILON);
    }

    @Test
    void testMatMulScalar() {
        assertMatrixEquals(scaled(mat4X3, 3.0), MatrixUtils.matrixMultiply(mat4X3, 3.0), EPSILON);
    }

    @Test
    void testMatMulByZero() {
        assertMatrixEquals(scaled(mat4X3, 0.0), MatrixUtils.matrixMultiply(mat4X3, 0.0), EPSILON);
    }

    @Test
    void testMatMulNormal() {
        assertMatrixEquals(mat43X35Product, MatrixUtils.matrixMultiply(mat4X3, mat3X5), EPSILON);
    }

    @Test
    void testMatMulWrongShape() {
        Assertions.assertThrows(IllegalArgumentException.class, () -> MatrixUtils.matrixMultiply(mat3X4, mat3X5));
    }

    @Test
    void testVectorRowColMultiply() {
        assertMatrixEquals(vectorProdRowCol, MatrixUtils.matrixMultiply(matRowVector, matColVector), EPSILON);
    }

    @Test
    void testVectorColRowMultiply() {
        assertMatrixEquals(vectorProdColRow, MatrixUtils.matrixMultiply(matColVector, matRowVector), EPSILON);
    }

    @Test
    void testInvertNormal() {
        double[][] inv = MatrixUtils.jitterInvert(matSquareNonSingular, 1, 1.0E-9, rn);
        assertMatrixEquals(matSNSInv, inv, EPSILON);
    }

    @Test
    void testInvertSingular() {
        Assertions.assertThrows(ArithmeticException.class, () -> MatrixUtils.jitterInvert(matSquareSingular, 1, 1.0E-9, rn));
    }

    @Test
    void testJitterInvert() {
        // Repeated runs draw different jitters from the (seeded) generator.
        for (int run = 0; run < 100; ++run) {
            double[][] inv = MatrixUtils.jitterInvert(matSquareSingular, 10, 1.0E-9, rn);
            assertMatrixEquals(identity, MatrixUtils.matrixMultiply(matSquareSingular, inv), EPSILON);
        }
    }

    @Test
    void testSecureJitterInvert() {
        // Overload without an explicit Random; its randomness source is not controlled by this test.
        for (int run = 0; run < 100; ++run) {
            double[][] inv = MatrixUtils.jitterInvert(matSquareSingular, 10, 1.0E-9);
            assertMatrixEquals(identity, MatrixUtils.matrixMultiply(matSquareSingular, inv), EPSILON);
        }
    }
}

