/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.MatrixUtilsExtensions;

/**
 * Tests for {@link MatrixUtilsExtensions}: conversion of prediction inputs/outputs into
 * vectors and matrices, column selection, row-wise reductions, vector differences,
 * matrix products, inversion, and the small vector/array helpers (minPos, variance, swap).
 */
class MatrixUtilsExtensionsTest {
    // Tolerance used where results come from floating-point computation.
    private static final double TOLERANCE = 1.0E-4;

    // 3x5 matrix used as the source for conversion and column-selection tests.
    double[][] mat3X5 = {
            {1.0, 10.0, 3.0, -4.0, 0.0},
            {10.0, 5.0, -3.0, 3.7, 1.0},
            {14.0, -6.6, 7.0, 14.0, 3.0}
    };
    // Columns 0, 1 and 3 of mat3X5.
    double[][] mat3X5get013 = {
            {1.0, 10.0, -4.0},
            {10.0, 5.0, 3.7},
            {14.0, -6.6, 14.0}
    };
    // Columns 0, 3, 1, 3, 0 of mat3X5 (repeated indices are allowed).
    double[][] mat3X5get03130 = {
            {1.0, -4.0, 10.0, -4.0, 1.0},
            {10.0, 3.7, 5.0, 3.7, 10.0},
            {14.0, 14.0, -6.6, 14.0, 14.0}
    };
    // Square matrix together with the inverse expected from safeInvert (see testInvertNormal).
    double[][] matSquareNonSingular = {
            {1.0, 2.0, 3.0},
            {10.0, 5.0, -3.0},
            {14.0, -6.6, 7.0}
    };
    double[][] matSNSInv = {
            {-0.02464332, 0.05479896, 0.03404669},
            {0.18158236, 0.05674449, -0.05350195},
            {0.22049287, -0.05609598, 0.02431907}
    };
    // Square matrix with linearly dependent rows (row1 + row3 = 2 * row2).
    double[][] matSquareSingular = {
            {1.0, 2.0, 3.0},
            {4.0, 5.0, 6.0},
            {7.0, 8.0, 9.0}
    };
    // Each entry is the total of one column of matSquareSingular (1+4+7, 2+5+8, 3+6+9).
    double[] mssSumRow = {12.0, 15.0, 18.0};

    RealVector v = MatrixUtils.createRealVector(new double[]{1.0, 2.0, 3.0});
    // Must be declared after matSquareSingular: field initializers run in textual order.
    RealMatrix mssMatrix = MatrixUtils.createRealMatrix(this.matSquareSingular);
    // mssMatrix with v subtracted from every row.
    RealMatrix rowDiffResult = MatrixUtils.createRealMatrix(new double[][]{
            {0.0, 0.0, 0.0},
            {3.0, 3.0, 3.0},
            {6.0, 6.0, 6.0}
    });
    // mssMatrix with v subtracted from every column.
    RealMatrix colDiffResult = MatrixUtils.createRealMatrix(new double[][]{
            {0.0, 1.0, 2.0},
            {2.0, 3.0, 4.0},
            {4.0, 5.0, 6.0}
    });
    // mssMatrix with rows 0 and 2 exchanged.
    RealMatrix swapResult = MatrixUtils.createRealMatrix(new double[][]{
            {7.0, 8.0, 9.0},
            {4.0, 5.0, 6.0},
            {1.0, 2.0, 3.0}
    });
    // v with entries 0 and 2 exchanged.
    RealVector swapResultV = MatrixUtils.createRealVector(new double[]{3.0, 2.0, 1.0});
    // 2x3 and 3x3 operands and their 2x3 matrix product.
    RealMatrix dotInput1 = MatrixUtils.createRealMatrix(new double[][]{
            {0.0, 1.0, 2.0},
            {3.0, 4.0, 5.0}
    });
    RealMatrix dotInput2 = MatrixUtils.createRealMatrix(new double[][]{
            {0.0, 1.0, 2.0},
            {3.0, 4.0, 5.0},
            {6.0, 7.0, 8.0}
    });
    RealMatrix dotResult = MatrixUtils.createRealMatrix(new double[][]{
            {15.0, 18.0, 21.0},
            {42.0, 54.0, 66.0}
    });
    RealVector vMix = MatrixUtils.createRealVector(new double[]{-3.0, -2.0, -1.0, 1.0, 2.0, 3.0});
    RealVector allNeg = MatrixUtils.createRealVector(new double[]{-3.0, -2.0, -1.0});
    RealVector varInput = MatrixUtils.createRealVector(new double[]{0.0, 4.0, 16.0, 2.0, -128.0, -4.0});

    /** Builds a PredictionInput holding one numerical feature (named "f") per entry of {@code row}. */
    private static PredictionInput predictionInputOf(double[] row) {
        List<Feature> features = new ArrayList<>();
        for (double entry : row) {
            features.add(FeatureFactory.newNumericalFeature("f", entry));
        }
        return new PredictionInput(features);
    }

    /** Builds a PredictionOutput holding one NUMBER output (named "o", score 0.0) per entry of {@code row}. */
    private static PredictionOutput predictionOutputOf(double[] row) {
        List<Output> outputs = new ArrayList<>();
        for (double entry : row) {
            outputs.add(new Output("o", Type.NUMBER, new Value(entry), 0.0));
        }
        return new PredictionOutput(outputs);
    }

    /** Asserts that the first {@code expected.length} rows of {@code actual} equal {@code expected} exactly. */
    private static void assertRowsEqual(double[][] expected, RealMatrix actual) {
        for (int row = 0; row < expected.length; ++row) {
            Assertions.assertArrayEquals(expected[row], actual.getRow(row));
        }
    }

    /** Asserts that selecting {@code indices} from the 3x5 matrix is rejected with IllegalArgumentException. */
    private void assertGetColsRejected(List<Integer> indices) {
        RealMatrix source = MatrixUtils.createRealMatrix(this.mat3X5);
        Assertions.assertThrows(IllegalArgumentException.class,
                () -> MatrixUtilsExtensions.getCols(source, indices));
    }

    /** A single PredictionInput converts to a vector holding its feature values in order. */
    @Test
    void testPICreation() {
        PredictionInput input = predictionInputOf(this.mat3X5[0]);
        RealVector converted = MatrixUtilsExtensions.vectorFromPredictionInput(input);
        Assertions.assertArrayEquals(this.mat3X5[0], converted.toArray());
    }

    /** A list of PredictionInputs converts to a matrix with one row per input. */
    @Test
    void testPIListCreation() {
        List<PredictionInput> inputs = new ArrayList<>();
        for (double[] row : this.mat3X5) {
            inputs.add(predictionInputOf(row));
        }
        RealMatrix converted = MatrixUtilsExtensions.matrixFromPredictionInput(inputs);
        Assertions.assertArrayEquals(this.mat3X5, converted.getData());
    }

    /** A single PredictionOutput converts to a vector holding its output values in order. */
    @Test
    void testPOCreation() {
        PredictionOutput output = predictionOutputOf(this.mat3X5[0]);
        RealVector converted = MatrixUtilsExtensions.vectorFromPredictionOutput(output);
        Assertions.assertArrayEquals(this.mat3X5[0], converted.toArray());
    }

    /** A list of PredictionOutputs converts to a matrix with one row per output. */
    @Test
    void testPOListCreation() {
        List<PredictionOutput> outputs = new ArrayList<>();
        for (double[] row : this.mat3X5) {
            outputs.add(predictionOutputOf(row));
        }
        RealMatrix converted = MatrixUtilsExtensions.matrixFromPredictionOutput(outputs);
        Assertions.assertArrayEquals(this.mat3X5, converted.getData());
    }

    /** Selecting distinct columns returns them in the requested order. */
    @Test
    void testGetCols() {
        RealMatrix selected = MatrixUtilsExtensions.getCols(
                MatrixUtils.createRealMatrix(this.mat3X5), List.of(0, 1, 3));
        assertRowsEqual(this.mat3X5get013, selected);
    }

    /** Column indices may repeat; each occurrence yields its own output column. */
    @Test
    void testGetDupCols() {
        RealMatrix selected = MatrixUtilsExtensions.getCols(
                MatrixUtils.createRealMatrix(this.mat3X5), List.of(0, 3, 1, 3, 0));
        assertRowsEqual(this.mat3X5get03130, selected);
    }

    /** Index 6 is out of range for a 5-column matrix and must be rejected. */
    @Test
    void testGetColsTooBig() {
        assertGetColsRejected(List.of(0, 6));
    }

    /** Negative column indices must be rejected. */
    @Test
    void testGetNegCols() {
        assertGetColsRejected(List.of(0, -6));
    }

    /** An empty index list must be rejected. */
    @Test
    void testGetNoCols() {
        // Typed explicitly: the previous raw List forced an unchecked call into getCols.
        List<Integer> noIndices = List.of();
        assertGetColsRejected(noIndices);
    }

    /** rowSum of the 3x3 fixture matches the precomputed totals, within a small tolerance. */
    @Test
    void testIntraSumRow() {
        RealVector sum = MatrixUtilsExtensions.rowSum(this.mssMatrix);
        Assertions.assertArrayEquals(this.mssSumRow, sum.toArray(), 1.0E-6);
    }

    /** rowSum of the 3x3 fixture is exactly {12, 15, 18}. */
    @Test
    void rowSum() {
        double[] expected = {12.0, 15.0, 18.0};
        Assertions.assertArrayEquals(expected, MatrixUtilsExtensions.rowSum(this.mssMatrix).toArray());
    }

    /** rowSquareSum of the 3x3 fixture is exactly {66, 93, 126} (1+16+49, 4+25+64, 9+36+81). */
    @Test
    void rowSquareSum() {
        double[] expected = {66.0, 93.0, 126.0};
        Assertions.assertArrayEquals(expected, MatrixUtilsExtensions.rowSquareSum(this.mssMatrix).toArray());
    }

    /** With Axis.ROW the vector is subtracted from every row of the matrix. */
    @Test
    void rowDifference() {
        RealMatrix actual = MatrixUtilsExtensions.vectorDifference(
                this.mssMatrix, this.v, MatrixUtilsExtensions.Axis.ROW);
        Assertions.assertEquals(this.rowDiffResult, actual);
    }

    /** With Axis.COLUMN the vector is subtracted from every column of the matrix. */
    @Test
    void colDifference() {
        RealMatrix actual = MatrixUtilsExtensions.vectorDifference(
                this.mssMatrix, this.v, MatrixUtilsExtensions.Axis.COLUMN);
        Assertions.assertEquals(this.colDiffResult, actual);
    }

    /** matrixDot of a 2x3 and a 3x3 matrix yields their 2x3 matrix product. */
    @Test
    void matrixDot() {
        RealMatrix actual = MatrixUtilsExtensions.matrixDot(this.dotInput1, this.dotInput2);
        Assertions.assertEquals(this.dotResult, actual);
    }

    /** safeInvert of the non-singular fixture matches the precomputed inverse row by row. */
    @Test
    void testInvertNormal() {
        RealMatrix inverse = MatrixUtilsExtensions.safeInvert(
                MatrixUtils.createRealMatrix(this.matSquareNonSingular));
        for (int row = 0; row < inverse.getRowDimension(); ++row) {
            Assertions.assertArrayEquals(this.matSNSInv[row], inverse.getRow(row), TOLERANCE);
        }
    }

    /**
     * For the singular fixture the result X of safeInvert is not compared to fixed values;
     * instead the property A * X * A = A is checked, where A is the original matrix.
     */
    @Test
    void testInvertSingular() {
        RealMatrix original = this.mssMatrix;
        RealMatrix inverse = MatrixUtilsExtensions.safeInvert(original);
        RealMatrix roundTrip = original.multiply(inverse).multiply(original);
        // Bound the loop by the matrix whose rows are compared, not by the inverse.
        for (int row = 0; row < original.getRowDimension(); ++row) {
            Assertions.assertArrayEquals(original.getRow(row), roundTrip.getRow(row), TOLERANCE);
        }
    }

    /** minPos of a vector with mixed signs is its smallest positive entry. */
    @Test
    void testMinPos() {
        Assertions.assertEquals(1.0, MatrixUtilsExtensions.minPos(this.vMix), TOLERANCE);
    }

    /**
     * minPos of a vector without any positive entry is expected to be Double.MAX_VALUE.
     * (Despite the method name, the fixture here contains only negative values.)
     */
    @Test
    void testMinPosNoNeg() {
        Assertions.assertEquals(Double.MAX_VALUE, MatrixUtilsExtensions.minPos(this.allNeg), TOLERANCE);
    }

    /** The expected value is the population variance (divisor n) of the six fixture entries. */
    @Test
    void testVar() {
        Assertions.assertEquals(2443.22222222, MatrixUtilsExtensions.variance(this.varInput), TOLERANCE);
    }

    /** swap on a matrix exchanges rows 0 and 2 in place; a copy keeps the shared fixture intact. */
    @Test
    void testSwapRealMatrix() {
        RealMatrix working = this.mssMatrix.copy();
        MatrixUtilsExtensions.swap(working, 0, 2);
        Assertions.assertEquals(this.swapResult, working);
    }

    /** swap on a vector exchanges entries 0 and 2 in place; a copy keeps the shared fixture intact. */
    @Test
    void testSwapRealVector() {
        RealVector working = this.v.copy();
        MatrixUtilsExtensions.swap(working, 0, 2);
        Assertions.assertEquals(this.swapResultV, working);
    }

    /** swap on an int array exchanges elements 0 and 2 in place. */
    @Test
    void testSwapIntArr() {
        int[] values = {1, 2, 3};
        MatrixUtilsExtensions.swap(values, 0, 2);
        Assertions.assertArrayEquals(new int[]{3, 2, 1}, values);
    }
}

