/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.security.SecureRandom;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;

/**
 * Static helpers for small dense matrices represented as {@code double[][]} arrays of rows
 * ({@code x[row][col]}).
 *
 * <p>Shapes are derived from the number of rows and the length of the first row; rows are assumed to
 * all have the same length (this is not validated).
 */
public final class MatrixUtils {
    /** Magnitude of the uniform noise added to every entry between inversion attempts in jitterInvert. */
    private static final double JITTER_DELTA = 1.0E-8;

    private MatrixUtils() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Converts one prediction input into a 1 x n row vector of its feature values.
     *
     * @param p the prediction input; each feature is read via {@code getValue().asNumber()}
     * @return a single-row matrix holding the features' numeric values in feature order
     */
    public static double[][] matrixFromPredictionInput(PredictionInput p) {
        return rowVector(p.getFeatures().stream().mapToDouble(f -> f.getValue().asNumber()).toArray());
    }

    /**
     * Converts a list of prediction inputs into a matrix with one row per input.
     *
     * @param ps the prediction inputs; each feature is read via {@code getValue().asNumber()}
     * @return a matrix whose i-th row holds the feature values of {@code ps.get(i)}
     */
    public static double[][] matrixFromPredictionInput(List<PredictionInput> ps) {
        return ps.stream()
                .map(p -> p.getFeatures().stream().mapToDouble(f -> f.getValue().asNumber()).toArray())
                .toArray(double[][]::new);
    }

    /**
     * Converts one prediction output into a 1 x n row vector of its output values.
     *
     * @param p the prediction output; each output is read via {@code getValue().asNumber()}
     * @return a single-row matrix holding the outputs' numeric values in output order
     */
    public static double[][] matrixFromPredictionOutput(PredictionOutput p) {
        return rowVector(p.getOutputs().stream().mapToDouble(o -> o.getValue().asNumber()).toArray());
    }

    /**
     * Converts a list of prediction outputs into a matrix with one row per output.
     *
     * @param ps the prediction outputs; each output is read via {@code getValue().asNumber()}
     * @return a matrix whose i-th row holds the output values of {@code ps.get(i)}
     */
    public static double[][] matrixFromPredictionOutput(List<PredictionOutput> ps) {
        return ps.stream()
                .map(p -> p.getOutputs().stream().mapToDouble(o -> o.getValue().asNumber()).toArray())
                .toArray(double[][]::new);
    }

    /**
     * Wraps a vector as a 1 x n matrix.
     *
     * <p>The returned matrix's only row is {@code v} itself, not a copy, so writes through either are
     * visible in both.
     *
     * @param v the vector
     * @return a single-row matrix whose row is {@code v}
     */
    public static double[][] rowVector(double[] v) {
        return new double[][] {v};
    }

    /**
     * Copies a vector into a new n x 1 matrix.
     *
     * @param v the vector
     * @return a matrix with {@code v.length} rows, each holding one element of {@code v}
     */
    public static double[][] columnVector(double[] v) {
        double[][] out = new double[v.length][1];
        for (int i = 0; i < v.length; ++i) {
            out[i][0] = v[i];
        }
        return out;
    }

    /**
     * Returns the shape of a matrix as {@code {rows, cols}}.
     *
     * @param x the matrix
     * @return {@code {x.length, x[0].length}}, or {@code {0, 0}} for a matrix with no rows
     */
    public static int[] getShape(double[][] x) {
        int rows = x.length;
        // An empty matrix has no first row to measure; report it as 0 x 0 instead of throwing.
        int cols = rows == 0 ? 0 : x[0].length;
        return new int[] {rows, cols};
    }

    /**
     * Extracts a single column as a new vector.
     *
     * @param x the matrix
     * @param i the column index
     * @return a new array holding {@code x[r][i]} for every row r
     * @throws IllegalArgumentException if {@code i} is negative or not less than the column count
     */
    public static double[] getCol(double[][] x, int i) {
        int cols = getShape(x)[1];
        if (i < 0 || i >= cols) {
            throw new IllegalArgumentException(
                    String.format("Column index %d out of bounds, matrix only has %d column(s)", i, cols));
        }
        return Arrays.stream(x).mapToDouble(row -> row[i]).toArray();
    }

    /**
     * Builds a new matrix made of the selected columns, in the order given by {@code idxs}.
     *
     * @param x the matrix
     * @param idxs the column indices to select (duplicates allowed)
     * @return a matrix with {@code x}'s row count and {@code idxs.size()} columns
     * @throws IllegalArgumentException if {@code idxs} is empty or contains an out-of-range index
     */
    public static double[][] getCols(double[][] x, List<Integer> idxs) {
        if (idxs.isEmpty()) {
            throw new IllegalArgumentException("Empty column idxs passed to getCols");
        }
        int[] shape = getShape(x);
        // Validate and unbox the indices once, rather than once per row.
        int[] colIdxs = new int[idxs.size()];
        for (int c = 0; c < colIdxs.length; ++c) {
            int idx = idxs.get(c);
            if (idx < 0 || idx >= shape[1]) {
                throw new IllegalArgumentException(String.format(
                        "Column index %d out of bounds, matrix only has %d column(s)", idx, shape[1]));
            }
            colIdxs[c] = idx;
        }
        double[][] out = new double[shape[0]][colIdxs.length];
        for (int r = 0; r < shape[0]; ++r) {
            for (int c = 0; c < colIdxs.length; ++c) {
                out[r][c] = x[r][colIdxs[c]];
            }
        }
        return out;
    }

    /**
     * Element-wise sum of two equally shaped matrices.
     *
     * @param a the first matrix
     * @param b the second matrix
     * @return a new matrix with {@code a[i][j] + b[i][j]}
     * @throws IllegalArgumentException if the shapes differ
     */
    public static double[][] matrixSum(double[][] a, double[][] b) {
        int[] aShape = getShape(a);
        int[] bShape = getShape(b);
        if (!Arrays.equals(aShape, bShape)) {
            throw new IllegalArgumentException("Shape of matrix A must match shape of matrix B. "
                    + String.format("Matrix A shape:  %d x %d, ", aShape[0], aShape[1])
                    + String.format("Matrix B shape:  %d x %d", bShape[0], bShape[1]));
        }
        double[][] sum = new double[aShape[0]][aShape[1]];
        for (int i = 0; i < aShape[0]; ++i) {
            for (int j = 0; j < aShape[1]; ++j) {
                sum[i][j] = a[i][j] + b[i][j];
            }
        }
        return sum;
    }

    /**
     * Adds the vector {@code b} to every row of {@code a}.
     *
     * @param a the matrix
     * @param b the vector to add; its length must equal the column count of {@code a}
     * @return a new matrix with {@code a[i][j] + b[j]}
     * @throws IllegalArgumentException if {@code b}'s length does not match the column count
     */
    public static double[][] matrixRowSum(double[][] a, double[] b) {
        int rows = getShape(a)[0];
        double[][] bMat = rowVector(b);
        double[][] out = new double[rows][];
        for (int i = 0; i < rows; ++i) {
            // Delegating to matrixSum reuses its shape validation for each row.
            out[i] = matrixSum(rowVector(a[i]), bMat)[0];
        }
        return out;
    }

    /**
     * Element-wise difference {@code a - b} of two equally shaped matrices.
     *
     * @param a the minuend
     * @param b the subtrahend
     * @return a new matrix with {@code a[i][j] - b[i][j]}
     * @throws IllegalArgumentException if the shapes differ
     */
    public static double[][] matrixDifference(double[][] a, double[][] b) {
        return matrixSum(a, matrixMultiply(b, -1.0));
    }

    /**
     * Subtracts the vector {@code b} from every row of {@code a}.
     *
     * @param a the matrix
     * @param b the vector to subtract; its length must equal the column count of {@code a}
     * @return a new matrix with {@code a[i][j] - b[j]}
     * @throws IllegalArgumentException if {@code b}'s length does not match the column count
     */
    public static double[][] matrixRowDifference(double[][] a, double[] b) {
        double[] bNeg = Arrays.stream(b).map(v -> -v).toArray();
        return matrixRowSum(a, bNeg);
    }

    /**
     * Matrix product {@code a x b}.
     *
     * @param a an m x n matrix
     * @param b an n x p matrix
     * @return the m x p product
     * @throws IllegalArgumentException if the column count of {@code a} differs from the row count of {@code b}
     */
    public static double[][] matrixMultiply(double[][] a, double[][] b) {
        int[] aShape = getShape(a);
        int[] bShape = getShape(b);
        if (aShape[1] != bShape[0]) {
            throw new IllegalArgumentException("# columns of matrix A must match # rows of matrix B. "
                    + String.format("Matrix A shape:  %d x %d, ", aShape[0], aShape[1])
                    + String.format("Matrix B shape:  %d x %d", bShape[0], bShape[1]));
        }
        double[][] product = new double[aShape[0]][bShape[1]];
        for (int i = 0; i < aShape[0]; ++i) {
            for (int j = 0; j < bShape[1]; ++j) {
                double dot = 0.0;
                for (int k = 0; k < aShape[1]; ++k) {
                    dot += a[i][k] * b[k][j];
                }
                product[i][j] = dot;
            }
        }
        return product;
    }

    /**
     * Multiplies every element of a matrix by a scalar.
     *
     * @param a the matrix
     * @param b the scalar
     * @return a new matrix with {@code a[i][j] * b}
     */
    public static double[][] matrixMultiply(double[][] a, double b) {
        int[] aShape = getShape(a);
        double[][] product = new double[aShape[0]][aShape[1]];
        for (int i = 0; i < aShape[0]; ++i) {
            for (int j = 0; j < aShape[1]; ++j) {
                product[i][j] = a[i][j] * b;
            }
        }
        return product;
    }

    /**
     * Sums a matrix along an axis.
     *
     * @param x the matrix
     * @param axis {@link Axis#ROW} collapses the rows, returning one total per column;
     *     any other value (including {@link Axis#COLUMN}) collapses the columns, returning one total per row
     * @return the totals, accumulated in index order
     */
    public static double[] sum(double[][] x, Axis axis) {
        int[] shape = getShape(x);
        if (axis == Axis.ROW) {
            double[] colTotals = new double[shape[1]];
            for (int i = 0; i < shape[0]; ++i) {
                for (int j = 0; j < shape[1]; ++j) {
                    colTotals[j] += x[i][j];
                }
            }
            return colTotals;
        }
        double[] rowTotals = new double[shape[0]];
        for (int i = 0; i < shape[0]; ++i) {
            for (int j = 0; j < shape[1]; ++j) {
                rowTotals[i] += x[i][j];
            }
        }
        return rowTotals;
    }

    /**
     * Returns the transpose of a matrix.
     *
     * @param x an m x n matrix
     * @return a new n x m matrix with {@code out[j][i] == x[i][j]}
     */
    public static double[][] transpose(double[][] x) {
        int[] shape = getShape(x);
        double[][] transposed = new double[shape[1]][shape[0]];
        for (int i = 0; i < shape[0]; ++i) {
            for (int j = 0; j < shape[1]; ++j) {
                transposed[j][i] = x[i][j];
            }
        }
        return transposed;
    }

    /**
     * Picks the not-yet-used diagonal position with the largest absolute value.
     *
     * @param x the square working matrix
     * @param pivotsUsed flags for diagonal positions already swept
     * @return the chosen diagonal index, or -1 if no unused position has a comparable (non-NaN) value
     */
    private static int findPivot(double[][] x, boolean[] pivotsUsed) {
        int pivot = -1;
        double maxAbs = -1.0;
        for (int d = 0; d < pivotsUsed.length; ++d) {
            // Only unused positions are eligible; a zero diagonal can still be chosen so the caller's
            // singularity check sees it instead of silently reusing an already-swept pivot.
            if (pivotsUsed[d]) {
                continue;
            }
            double abs = Math.abs(x[d][d]);
            if (abs > maxAbs) {
                pivot = d;
                maxAbs = abs;
            }
        }
        return pivot;
    }

    /**
     * Inverts a square matrix by in-place Gauss-Jordan elimination on a copy, pivoting on the diagonal.
     *
     * @param x the square matrix; it is not modified
     * @param zeroThreshold pivots with absolute value below this are treated as zero
     * @return the inverse
     * @throws IllegalArgumentException if {@code x} is not square
     * @throws ArithmeticException if a usable pivot cannot be found (matrix treated as singular)
     */
    private static double[][] invertSquareMatrix(double[][] x, double zeroThreshold) {
        int[] shape = getShape(x);
        if (shape[0] != shape[1]) {
            throw new IllegalArgumentException(String.format(
                    "Only square matrices can be inverted, matrix shape: %d x %d", shape[0], shape[1]));
        }
        int size = shape[0];
        double[][] inv = new double[size][];
        for (int i = 0; i < size; ++i) {
            inv[i] = Arrays.copyOf(x[i], size);
        }
        boolean[] pivotsUsed = new boolean[size];
        for (int iteration = 0; iteration < size; ++iteration) {
            int pivot = findPivot(inv, pivotsUsed);
            if (pivot < 0 || Math.abs(inv[pivot][pivot]) < zeroThreshold) {
                throw new ArithmeticException("Matrix is singular and cannot be inverted");
            }
            double pivotVal = inv[pivot][pivot];
            pivotsUsed[pivot] = true;
            double[] pivotRow = inv[pivot];
            // The pivot slot is replaced by 1 before scaling so it ends up holding 1 / pivotVal.
            pivotRow[pivot] = 1.0;
            for (int j = 0; j < size; ++j) {
                pivotRow[j] /= pivotVal;
            }
            for (int i = 0; i < size; ++i) {
                if (i == pivot) {
                    continue;
                }
                double[] row = inv[i];
                double factor = row[pivot];
                row[pivot] = 0.0;
                for (int j = 0; j < size; ++j) {
                    row[j] -= pivotRow[j] * factor;
                }
            }
        }
        return inv;
    }

    /**
     * Inverts a square matrix, retrying with small random perturbations if it is (near-)singular.
     *
     * <p>After each failed attempt, noise {@code 1e-8 * random.nextDouble()} is added to every entry of a
     * private working copy before retrying. The caller's matrix is never modified.
     *
     * @param x the square matrix to invert
     * @param numRetries maximum number of inversion attempts
     * @param zeroThreshold pivots with absolute value below this are treated as zero
     * @param random source of the jitter noise
     * @return the inverse of the (possibly jittered) matrix
     * @throws IllegalArgumentException if {@code x} is not square
     * @throws ArithmeticException if no attempt succeeds; the last attempt's failure is attached as the cause
     */
    public static double[][] jitterInvert(double[][] x, int numRetries, double zeroThreshold, Random random) {
        double[][] candidate = deepCopy(x);
        ArithmeticException lastFailure = null;
        for (int attempt = 0; attempt < numRetries; ++attempt) {
            try {
                return invertSquareMatrix(candidate, zeroThreshold);
            } catch (ArithmeticException e) {
                lastFailure = e;
                jitterMatrix(candidate, JITTER_DELTA, random);
            }
        }
        ArithmeticException failure =
                new ArithmeticException("Matrix is singular and could not be inverted via jittering");
        if (lastFailure != null) {
            failure.initCause(lastFailure);
        }
        throw failure;
    }

    /**
     * Same as {@link #jitterInvert(double[][], int, double, Random)} using a new {@link SecureRandom}
     * as the noise source.
     *
     * @param x the square matrix to invert
     * @param numRetries maximum number of inversion attempts
     * @param zeroThreshold pivots with absolute value below this are treated as zero
     * @return the inverse of the (possibly jittered) matrix
     */
    public static double[][] jitterInvert(double[][] x, int numRetries, double zeroThreshold) {
        return jitterInvert(x, numRetries, zeroThreshold, new SecureRandom());
    }

    /** Adds {@code delta * random.nextDouble()} to every entry of {@code x}, in place. */
    private static void jitterMatrix(double[][] x, double delta, Random random) {
        for (double[] row : x) {
            for (int j = 0; j < row.length; ++j) {
                row[j] += delta * random.nextDouble();
            }
        }
    }

    /** Returns a new outer array holding a clone of every row of {@code x}. */
    private static double[][] deepCopy(double[][] x) {
        return Arrays.stream(x).map(double[]::clone).toArray(double[][]::new);
    }

    /** Axis selector for {@link #sum(double[][], Axis)}. */
    public enum Axis {
        /** Collapse the rows, producing one total per column. */
        ROW,
        /** Collapse the columns, producing one total per row. */
        COLUMN
    }
}

