/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.Arrays;
import java.util.Random;
import java.util.stream.IntStream;

public class MatrixUtils {
    private MatrixUtils() {
        throw new IllegalStateException("Utility class");
    }

    static int[] getShape(double[][] x) {
        int rows = x.length;
        int cols = x[0].length;
        return new int[]{rows, cols};
    }

    static double[] getCol(double[][] x, int i) throws IllegalArgumentException {
        int cols = MatrixUtils.getShape(x)[1];
        if (cols <= i) {
            throw new IllegalArgumentException(String.format("Column index %d too large, matrix only has %d column(s)", i, cols));
        }
        return IntStream.range(0, x.length).mapToDouble(rowIdx -> x[rowIdx][i]).toArray();
    }

    static double[][] transpose(double[][] x) {
        int[] shape = MatrixUtils.getShape(x);
        double[][] transposed = new double[shape[1]][shape[0]];
        for (int i = 0; i < shape[0]; ++i) {
            for (int j = 0; j < shape[1]; ++j) {
                transposed[j][i] = x[i][j];
            }
        }
        return transposed;
    }

    static double[][] matrixMultiply(double[][] a, double[][] b) throws IllegalArgumentException {
        int[] bShape;
        int[] aShape = MatrixUtils.getShape(a);
        if (aShape[1] != (bShape = MatrixUtils.getShape(b))[0]) {
            throw new IllegalArgumentException("# columns of matrix A must match # rows of matrix B" + String.format("Matrix A shape:  %d x %d, ", aShape[0], aShape[1]) + String.format("Matrix B shape:  %d x %d,", bShape[0], bShape[1]));
        }
        double[][] product = new double[aShape[0]][bShape[1]];
        for (int i = 0; i < aShape[0]; ++i) {
            for (int j = 0; j < bShape[1]; ++j) {
                for (int k = 0; k < aShape[1]; ++k) {
                    double[] dArray = product[i];
                    int n = j;
                    dArray[n] = dArray[n] + a[i][k] * b[k][j];
                }
            }
        }
        return product;
    }

    private static int findPivot(double[][] x, boolean[] pivotsUsed) {
        double maxAbs = 0.0;
        int pivot = 0;
        int size = MatrixUtils.getShape(x)[0];
        for (int diagIdx = 0; diagIdx < size; ++diagIdx) {
            double abs = Math.abs(x[diagIdx][diagIdx]);
            if (!(abs > maxAbs) || pivotsUsed[diagIdx]) continue;
            pivot = diagIdx;
            maxAbs = abs;
        }
        return pivot;
    }

    public static double[][] invertSquareMatrix(double[][] x, double zeroThreshold) {
        int size = MatrixUtils.getShape(x)[0];
        double[][] copy = new double[size][size];
        for (int i = 0; i < size; ++i) {
            copy[i] = Arrays.copyOf(x[i], size);
        }
        boolean[] pivotsUsed = new boolean[size];
        Arrays.fill(pivotsUsed, false);
        for (int iterations = 0; iterations < size; ++iterations) {
            int pivot = MatrixUtils.findPivot(copy, pivotsUsed);
            double pivotVal = copy[pivot][pivot];
            if (Math.abs(pivotVal) < zeroThreshold) {
                throw new ArithmeticException("Matrix is singular and cannot be inverted");
            }
            copy[pivot][pivot] = 1.0;
            pivotsUsed[pivot] = true;
            int i = 0;
            while (i < size) {
                double[] dArray = copy[pivot];
                int n = i++;
                dArray[n] = dArray[n] / pivotVal;
            }
            for (i = 0; i < size; ++i) {
                if (i == pivot) continue;
                double rowValueAtPivot = copy[i][pivot];
                copy[i][pivot] = 0.0;
                for (int j = 0; j < size; ++j) {
                    double[] dArray = copy[i];
                    int n = j;
                    dArray[n] = dArray[n] - copy[pivot][j] * rowValueAtPivot;
                }
            }
        }
        return copy;
    }

    public static double[][] jitterInvert(double[][] x, int numRetries, double zeroThreshold, Random random) {
        for (int jitterTries = 0; jitterTries < numRetries; ++jitterTries) {
            try {
                double[][] xInv = MatrixUtils.invertSquareMatrix(x, zeroThreshold);
                return xInv;
            }
            catch (ArithmeticException e) {
                MatrixUtils.jitterMatrix(x, 1.0E-8, random);
                continue;
            }
        }
        throw new ArithmeticException("Matrix is singular and could not be inverted via jittering");
    }

    private static void jitterMatrix(double[][] x, double delta, Random random) {
        for (int i = 0; i < x.length; ++i) {
            int j = 0;
            while (j < x[0].length) {
                double[] dArray = x[i];
                int n = j++;
                dArray[n] = dArray[n] + delta * random.nextDouble();
            }
        }
    }
}

