/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.util;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.util.Randoms;
import java.text.NumberFormat;
import java.util.Arrays;

/**
 * Static utilities for sampling from multivariate normal and Wishart distributions, plus the
 * small dense linear-algebra helpers they rely on.
 *
 * <p>All matrices are square and stored as flat {@code double[]} arrays in row-major order:
 * entry {@code (row, col)} of a {@code dimension x dimension} matrix lives at index
 * {@code row * dimension + col}. Multivariate normals are parameterized by their precision
 * (inverse covariance) matrix rather than their covariance.
 */
public class MVNormal {

    /**
     * Computes the Cholesky factor of a symmetric positive-definite matrix.
     *
     * <p>Returns the lower-triangular matrix {@code L} with {@code L * L^T == input}. Entries
     * above the diagonal are zero. Only the lower triangle (including the diagonal) of
     * {@code input} is read. If {@code input} is not positive definite, the result contains
     * non-finite values (NaN or infinity) rather than an exception being thrown.
     *
     * @param input row-major {@code numRows x numRows} symmetric positive-definite matrix
     * @param numRows matrix dimension
     * @return row-major lower-triangular Cholesky factor
     */
    public static double[] cholesky(double[] input, int numRows) {
        double[] result = new double[input.length];
        for (int row = 0; row < numRows; row++) {
            int rowOffset = row * numRows;
            double sumRowSquared = 0.0;
            for (int col = 0; col < row; col++) {
                int colOffset = col * numRows;
                double dotProduct = 0.0;
                for (int i = 0; i < col; i++) {
                    dotProduct += result[rowOffset + i] * result[colOffset + i];
                }
                double value = (input[rowOffset + col] - dotProduct) / result[colOffset + col];
                result[rowOffset + col] = value;
                sumRowSquared += value * value;
            }
            result[rowOffset + row] = Math.sqrt(input[rowOffset + row] - sumRowSquared);
        }
        return result;
    }

    /**
     * Computes the Cholesky factor of a symmetric positive-definite matrix, skipping the leading
     * zeros of each row.
     *
     * <p>Produces the same lower-triangular factor as {@link #cholesky(double[], int)}. Leading
     * zeros in a row of {@code input} stay zero in the factor, so those entries are left at zero.
     * Dot products start at the row's first non-zero column. This makes the method cheaper for
     * banded matrices.
     *
     * @param input row-major {@code numRows x numRows} symmetric positive-definite matrix
     * @param numRows matrix dimension
     * @return row-major lower-triangular Cholesky factor
     */
    public static double[] bandCholesky(double[] input, int numRows) {
        double[] result = new double[input.length];
        for (int row = 0; row < numRows; row++) {
            int rowOffset = row * numRows;
            double sumRowSquared = 0.0;
            // Sentinel: equals `row` until the first non-zero input entry in this row is seen.
            int firstNonZero = row;
            for (int col = 0; col < row; col++) {
                if (firstNonZero == row) {
                    if (input[rowOffset + col] == 0.0) {
                        // Still in the leading-zero run: the factor entry stays 0.0.
                        continue;
                    }
                    firstNonZero = col;
                }
                int colOffset = col * numRows;
                double dotProduct = 0.0;
                for (int i = firstNonZero; i < col; i++) {
                    dotProduct += result[rowOffset + i] * result[colOffset + i];
                }
                double value = (input[rowOffset + col] - dotProduct) / result[colOffset + col];
                result[rowOffset + col] = value;
                sumRowSquared += value * value;
            }
            result[rowOffset + row] = Math.sqrt(input[rowOffset + row] - sumRowSquared);
        }
        return result;
    }

    /**
     * Builds a lower-triangular band matrix of ones.
     *
     * <p>Row {@code r} has ones in columns {@code max(0, r - bandwidth + 1)} through {@code r}
     * inclusive, and zeros elsewhere.
     *
     * @param dim matrix dimension
     * @param bandwidth number of ones per row, including the diagonal (fewer near the top)
     * @return row-major {@code dim x dim} matrix
     */
    public static double[] bandMatrixRoot(int dim, int bandwidth) {
        double[] result = new double[dim * dim];
        for (int row = 0; row < dim; row++) {
            int rowOffset = row * dim;
            for (int col = Math.max(0, row - bandwidth + 1); col <= row; col++) {
                result[rowOffset + col] = 1.0;
            }
        }
        return result;
    }

    /**
     * Draws one sample from a multivariate normal given its mean and precision matrix.
     *
     * @param mean mean vector; its length is the dimension
     * @param precision row-major symmetric positive-definite precision matrix
     * @param random source of randomness
     * @return the sampled vector
     */
    public static double[] nextMVNormal(double[] mean, double[] precision, Randoms random) {
        return nextMVNormalWithCholesky(mean, cholesky(precision, mean.length), random);
    }

    /**
     * Draws one sample from a multivariate normal whose precision matrix is given by its
     * Cholesky factor.
     *
     * <p>Draws {@code z} from {@code random.nextGaussian()}, solves {@code L^T x = z} and returns
     * {@code mean + x}. If {@code z} is standard normal, {@code x} has covariance
     * {@code (L L^T)^-1}, the inverse of the precision.
     *
     * @param mean mean vector; its length is the dimension
     * @param precisionLowerTriangular row-major lower-triangular Cholesky factor {@code L} of the
     *     precision matrix
     * @param random source of randomness
     * @return the sampled vector
     */
    public static double[] nextMVNormalWithCholesky(double[] mean, double[] precisionLowerTriangular, Randoms random) {
        int n = mean.length;
        double[] standardNormal = new double[n];
        for (int i = 0; i < n; i++) {
            standardNormal[i] = random.nextGaussian();
        }
        double[] result = solveWithBackSubstitution(standardNormal, precisionLowerTriangular);
        for (int i = 0; i < n; i++) {
            result[i] += mean[i];
        }
        return result;
    }

    /**
     * Draws a multivariate normal sample and conditions it so that its components sum to zero.
     *
     * <p>Let {@code Sigma = (L L^T)^-1}. A sample {@code x} is drawn as in
     * {@link #nextMVNormalWithCholesky}. The method then returns
     * {@code x - Sigma 1 * (1^T x) / (1^T Sigma 1)}, whose entries sum to zero up to rounding.
     *
     * @param mean mean vector; its length is the dimension
     * @param precisionLowerTriangular row-major lower-triangular Cholesky factor {@code L} of the
     *     precision matrix
     * @param random source of randomness
     * @return the conditioned sample
     */
    public static double[] nextZeroSumMVNormalWithCholesky(double[] mean, double[] precisionLowerTriangular, Randoms random) {
        int n = mean.length;
        double[] result = nextMVNormalWithCholesky(mean, precisionLowerTriangular, random);
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += result[i];
        }

        // rowSums = Sigma * 1 = L^-T (L^-1 1), i.e. the row sums of the covariance matrix.
        double[] ones = new double[n];
        Arrays.fill(ones, 1.0);
        double[] firstSolution = solveWithForwardSubstitution(ones, precisionLowerTriangular);
        double[] rowSums = solveWithBackSubstitution(firstSolution, precisionLowerTriangular);

        double sumOfRowSums = 0.0;
        for (int i = 0; i < n; i++) {
            sumOfRowSums += rowSums[i];
        }
        double inverseSumOfRowSums = 1.0 / sumOfRowSums;
        for (int i = 0; i < n; i++) {
            result[i] -= inverseSumOfRowSums * rowSums[i] * sum;
        }
        return result;
    }

    /**
     * Draws {@code n} independent samples from the same multivariate normal.
     *
     * <p>The Cholesky factor of {@code precision} is computed once and shared by all samples.
     *
     * @param n number of samples
     * @param mean mean vector; its length is the dimension
     * @param precision row-major symmetric positive-definite precision matrix
     * @param random source of randomness
     * @return an {@code n}-element array of sampled vectors
     */
    public static double[][] nextMVNormal(int n, double[] mean, double[] precision, Randoms random) {
        // The factorization is deterministic and consumes no randomness, so hoisting it out of
        // the loop does not change the samples produced.
        double[] precisionLowerTriangular = cholesky(precision, mean.length);
        double[][] result = new double[n][];
        for (int i = 0; i < n; i++) {
            result[i] = nextMVNormalWithCholesky(mean, precisionLowerTriangular, random);
        }
        return result;
    }

    /**
     * Draws one multivariate normal sample and wraps it in a {@link FeatureVector} over
     * {@code alphabet}.
     *
     * @param alphabet alphabet for the resulting feature vector
     * @param mean mean vector; its length is the dimension
     * @param precision row-major symmetric positive-definite precision matrix
     * @param random source of randomness
     * @return the sample as a feature vector
     */
    public static FeatureVector nextFeatureVector(Alphabet alphabet, double[] mean, double[] precision, Randoms random) {
        return new FeatureVector(alphabet, nextMVNormal(mean, precision, random));
    }

    /**
     * Samples a multivariate normal mean from its posterior, given a normal prior and observed
     * data with known precision.
     *
     * <p>The posterior precision is
     * {@code P = observations * precision + diag(priorPrecisionDiagonal)}. The posterior mean is
     * {@code P^-1 (diag(priorPrecisionDiagonal) priorMean + observations * precision * observedMean)}.
     * A single sample is then drawn from the normal with that mean and precision {@code P}.
     *
     * @param priorMean prior mean vector; its length is the dimension
     * @param priorPrecisionDiagonal diagonal of the (diagonal) prior precision
     * @param precision row-major precision of the observation distribution
     * @param observedMean sample mean of the observations
     * @param observations number of observations
     * @param random source of randomness
     * @return a sample of the mean from the posterior
     */
    public static double[] nextMVNormalPosterior(double[] priorMean, double[] priorPrecisionDiagonal, double[] precision, double[] observedMean, int observations, Randoms random) {
        int dimension = priorMean.length;

        double[] linearCombination = new double[dimension];
        for (int i = 0; i < dimension; i++) {
            linearCombination[i] = priorMean[i] * priorPrecisionDiagonal[i];
            double innerProduct = 0.0;
            for (int j = 0; j < dimension; j++) {
                innerProduct += precision[dimension * i + j] * observedMean[j];
            }
            linearCombination[i] += (double) observations * innerProduct;
        }

        double[] posteriorPrecision = new double[precision.length];
        for (int row = 0; row < dimension; row++) {
            for (int col = 0; col < dimension; col++) {
                int index = dimension * row + col;
                posteriorPrecision[index] = (double) observations * precision[index];
                if (row == col) {
                    posteriorPrecision[index] += priorPrecisionDiagonal[row];
                }
            }
        }

        double[] inversePosteriorPrecision = invertSPD(posteriorPrecision, dimension);
        double[] posteriorMean = new double[dimension];
        for (int row = 0; row < dimension; row++) {
            double innerProduct = 0.0;
            for (int col = 0; col < dimension; col++) {
                innerProduct += inversePosteriorPrecision[dimension * row + col] * linearCombination[col];
            }
            posteriorMean[row] = innerProduct;
        }
        return nextMVNormal(posteriorMean, posteriorPrecision, random);
    }

    /**
     * Solves {@code L^T x = b} for {@code x}, where {@code L} is lower triangular.
     *
     * <p>The transpose of {@code lowerTriangular} is used, so this is back substitution on an
     * upper-triangular system.
     *
     * @param b right-hand side; its length is the dimension
     * @param lowerTriangular row-major lower-triangular matrix {@code L}
     * @return the solution {@code x}
     */
    public static double[] solveWithBackSubstitution(double[] b, double[] lowerTriangular) {
        int n = b.length;
        double[] result = new double[n];
        for (int i = n - 1; i >= 0; i--) {
            double innerProduct = 0.0;
            for (int j = i + 1; j < n; j++) {
                // lowerTriangular[n * j + i] is L[j][i], i.e. (L^T)[i][j].
                innerProduct += result[j] * lowerTriangular[n * j + i];
            }
            result[i] = (b[i] - innerProduct) / lowerTriangular[n * i + i];
        }
        return result;
    }

    /**
     * Solves {@code L x = b} for {@code x} by forward substitution, where {@code L} is lower
     * triangular.
     *
     * @param b right-hand side; its length is the dimension
     * @param lowerTriangular row-major lower-triangular matrix {@code L}
     * @return the solution {@code x}
     */
    public static double[] solveWithForwardSubstitution(double[] b, double[] lowerTriangular) {
        int n = b.length;
        double[] result = new double[n];
        for (int i = 0; i < n; i++) {
            double innerProduct = 0.0;
            for (int j = 0; j < i; j++) {
                innerProduct += result[j] * lowerTriangular[n * i + j];
            }
            result[i] = (b[i] - innerProduct) / lowerTriangular[n * i + i];
        }
        return result;
    }

    /**
     * Inverts a lower-triangular matrix.
     *
     * @param inputMatrix row-major lower-triangular matrix with non-zero diagonal
     * @param dimension matrix dimension
     * @return row-major lower-triangular inverse
     */
    public static double[] invertLowerTriangular(double[] inputMatrix, int dimension) {
        double[] outputMatrix = new double[inputMatrix.length];
        for (int row = 0; row < dimension; row++) {
            for (int col = 0; col <= row; col++) {
                double x = col == row ? 1.0 : 0.0;
                for (int i = col; i < row; i++) {
                    x -= inputMatrix[dimension * row + i] * outputMatrix[dimension * i + col];
                }
                outputMatrix[dimension * row + col] = x / inputMatrix[dimension * row + row];
            }
        }
        return outputMatrix;
    }

    /**
     * Computes {@code M^T M} for a lower-triangular matrix {@code M}.
     *
     * @param inputMatrix row-major lower-triangular matrix {@code M}
     * @param dimension matrix dimension
     * @return row-major symmetric product {@code M^T M}
     */
    public static double[] lowerTriangularCrossproduct(double[] inputMatrix, int dimension) {
        double[] outputMatrix = new double[inputMatrix.length];
        for (int row = 0; row < dimension; row++) {
            for (int col = row; col < dimension; col++) {
                double innerProduct = 0.0;
                // M[i][row] and M[i][col] are both non-zero only when i >= col (since col >= row).
                for (int i = col; i < dimension; i++) {
                    innerProduct += inputMatrix[row + dimension * i] * inputMatrix[col + dimension * i];
                }
                outputMatrix[dimension * row + col] = innerProduct;
                outputMatrix[row + dimension * col] = innerProduct;
            }
        }
        return outputMatrix;
    }

    /**
     * Multiplies two lower-triangular matrices.
     *
     * @param leftMatrix row-major lower-triangular left factor
     * @param rightMatrix row-major lower-triangular right factor
     * @param dimension matrix dimension
     * @return row-major lower-triangular product {@code leftMatrix * rightMatrix}
     */
    public static double[] lowerTriangularProduct(double[] leftMatrix, double[] rightMatrix, int dimension) {
        double[] outputMatrix = new double[leftMatrix.length];
        for (int row = 0; row < dimension; row++) {
            for (int col = 0; col <= row; col++) {
                double innerProduct = 0.0;
                for (int i = col; i <= row; i++) {
                    innerProduct += leftMatrix[dimension * row + i] * rightMatrix[dimension * i + col];
                }
                outputMatrix[dimension * row + col] = innerProduct;
            }
        }
        return outputMatrix;
    }

    /**
     * Computes {@code M M^T} for a lower-triangular matrix {@code M}.
     *
     * @param inputMatrix row-major lower-triangular matrix {@code M}
     * @param dimension matrix dimension
     * @return row-major symmetric product {@code M M^T}
     */
    private static double[] lowerTriangularOuterProduct(double[] inputMatrix, int dimension) {
        double[] outputMatrix = new double[inputMatrix.length];
        for (int row = 0; row < dimension; row++) {
            for (int col = 0; col <= row; col++) {
                double innerProduct = 0.0;
                // M[row][k] and M[col][k] are both non-zero only when k <= col (since col <= row).
                for (int k = 0; k <= col; k++) {
                    innerProduct += inputMatrix[dimension * row + k] * inputMatrix[dimension * col + k];
                }
                outputMatrix[dimension * row + col] = innerProduct;
                outputMatrix[dimension * col + row] = innerProduct;
            }
        }
        return outputMatrix;
    }

    /**
     * Inverts a symmetric positive-definite matrix via its Cholesky factor {@code L}, using
     * {@code A^-1 = L^-T L^-1}.
     *
     * @param inputMatrix row-major symmetric positive-definite matrix
     * @param dimension matrix dimension
     * @return row-major inverse
     */
    public static double[] invertSPD(double[] inputMatrix, int dimension) {
        double[] lowerFactor = bandCholesky(inputMatrix, dimension);
        return lowerTriangularCrossproduct(invertLowerTriangular(lowerFactor, dimension), dimension);
    }

    /**
     * Draws a sample from a Wishart distribution using the Bartlett decomposition.
     *
     * <p>Let {@code L = sqrtScaleMatrix}. {@code A} is built as a lower-triangular matrix with
     * {@code A[i][i] = sqrt(chi^2(degreesOfFreedom - i))} (0-based {@code i}) and standard normal
     * draws below the diagonal. The result is {@code (L A)(L A)^T}, a Wishart sample with scale
     * matrix {@code L L^T} and {@code degreesOfFreedom} degrees of freedom.
     *
     * @param sqrtScaleMatrix row-major lower-triangular Cholesky factor of the scale matrix
     * @param dimension matrix dimension
     * @param degreesOfFreedom degrees of freedom; must be at least {@code dimension}
     * @param random source of randomness
     * @return row-major symmetric sampled matrix
     * @throws IllegalArgumentException if {@code degreesOfFreedom < dimension}
     */
    public static double[] nextWishart(double[] sqrtScaleMatrix, int dimension, int degreesOfFreedom, Randoms random) {
        if (degreesOfFreedom < dimension) {
            throw new IllegalArgumentException("degreesOfFreedom (" + degreesOfFreedom
                    + ") must be at least the dimension (" + dimension + ")");
        }
        double[] bartlett = new double[sqrtScaleMatrix.length];
        for (int row = 0; row < dimension; row++) {
            int rowOffset = row * dimension;
            for (int col = 0; col < row; col++) {
                bartlett[rowOffset + col] = random.nextGaussian(0.0, 1.0);
            }
            // Bartlett decomposition: diagonal entry i uses (df - i) degrees of freedom.
            bartlett[rowOffset + row] = Math.sqrt(random.nextChiSq(degreesOfFreedom - row));
        }
        double[] factor = lowerTriangularProduct(sqrtScaleMatrix, bartlett, dimension);
        return lowerTriangularOuterProduct(factor, dimension);
    }

    /**
     * Samples a precision matrix from a Wishart posterior.
     *
     * <p>The sample is drawn from a Wishart with scale
     * {@code (scatterMatrix + diag(1 / priorPrecisionDiagonal))^-1} and
     * {@code observations + priorDF} degrees of freedom.
     *
     * @param scatterMatrix row-major scatter matrix of the observations (not modified)
     * @param observations number of observations
     * @param priorPrecisionDiagonal values whose reciprocals are added to the scatter diagonal
     * @param priorDF prior degrees of freedom
     * @param dimension matrix dimension
     * @param random source of randomness
     * @return row-major sampled precision matrix
     */
    public static double[] nextWishartPosterior(double[] scatterMatrix, int observations, double[] priorPrecisionDiagonal, int priorDF, int dimension, Randoms random) {
        double[] scatterPlusPrior = Arrays.copyOf(scatterMatrix, scatterMatrix.length);
        for (int i = 0; i < dimension; i++) {
            scatterPlusPrior[dimension * i + i] += 1.0 / priorPrecisionDiagonal[i];
        }
        double[] sqrtScaleMatrix = cholesky(invertSPD(scatterPlusPrior, dimension), dimension);
        return nextWishart(sqrtScaleMatrix, dimension, observations + priorDF, random);
    }

    /**
     * Formats a matrix as tab-separated rows, one row per line, with at most 10 fraction digits.
     *
     * @param matrix row-major matrix
     * @param dimension matrix dimension
     * @return the formatted matrix
     */
    public static String doubleArrayToString(double[] matrix, int dimension) {
        NumberFormat formatter = NumberFormat.getInstance();
        formatter.setMaximumFractionDigits(10);
        StringBuilder output = new StringBuilder();
        for (int row = 0; row < dimension; row++) {
            for (int col = 0; col < dimension; col++) {
                output.append(formatter.format(matrix[dimension * row + col]));
                output.append("\t");
            }
            output.append("\n");
        }
        return output.toString();
    }

    /**
     * Formats the diagonal of a matrix as space-separated values with at most 4 fraction digits.
     *
     * @param matrix row-major matrix
     * @param dimension matrix dimension
     * @return the formatted diagonal
     */
    public static String diagonalToString(double[] matrix, int dimension) {
        NumberFormat formatter = NumberFormat.getInstance();
        formatter.setMaximumFractionDigits(4);
        StringBuilder output = new StringBuilder();
        for (int row = 0; row < dimension; row++) {
            output.append(formatter.format(matrix[dimension * row + row]));
            output.append(" ");
        }
        return output.toString();
    }

    /**
     * Computes the scatter matrix {@code sum_i (x_i - mean)(x_i - mean)^T} of a set of
     * observations, where {@code mean} is their sample mean.
     *
     * @param observationMatrix observations, one vector per row, all of the same length
     * @return row-major {@code d x d} scatter matrix, where {@code d} is the observation length
     * @throws IllegalArgumentException if there are no observations
     */
    public static double[] getScatterMatrix(double[][] observationMatrix) {
        int observations = observationMatrix.length;
        if (observations == 0) {
            throw new IllegalArgumentException("observationMatrix must contain at least one observation");
        }
        int dimension = observationMatrix[0].length;

        double[] means = new double[dimension];
        for (double[] observation : observationMatrix) {
            for (int d = 0; d < dimension; d++) {
                means[d] += observation[d];
            }
        }
        for (int d = 0; d < dimension; d++) {
            means[d] /= (double) observations;
        }

        double[] outputMatrix = new double[dimension * dimension];
        double[] centered = new double[dimension];
        for (double[] observation : observationMatrix) {
            for (int d = 0; d < dimension; d++) {
                centered[d] = observation[d] - means[d];
            }
            for (int d1 = 0; d1 < dimension; d1++) {
                for (int d2 = 0; d2 < dimension; d2++) {
                    outputMatrix[dimension * d1 + d2] += centered[d1] * centered[d2];
                }
            }
        }
        return outputMatrix;
    }

    /**
     * Manual check of the Wishart posterior sampler.
     *
     * <p>Draws 1000 samples from a 20-dimensional standard normal, samples a posterior precision
     * from {@link #nextWishartPosterior}, and prints its diagonal. The printed values should be
     * roughly 1, since the true precision is the identity.
     */
    public static void testCholesky() {
        int observations = 1000;
        int dimension = 20;
        double[] mean = new double[dimension];
        double[] precisionMatrix = new double[dimension * dimension];
        for (int i = 0; i < dimension; i++) {
            precisionMatrix[dimension * i + i] = 1.0;
        }
        Randoms random = new Randoms();
        double[] scatterMatrix = getScatterMatrix(nextMVNormal(observations, mean, precisionMatrix, random));
        double[] priorPrecision = new double[dimension];
        Arrays.fill(priorPrecision, 1.0);
        double[] posteriorPrecision = nextWishartPosterior(scatterMatrix, observations, priorPrecision, dimension + 1, dimension, random);
        System.out.println(diagonalToString(posteriorPrecision, dimension));
    }

    /**
     * Demo: prints 10 unconstrained and 10 zero-sum samples from a 3-dimensional normal, one
     * tab-separated sample per line.
     */
    public static void main(String[] args) {
        double[] spd = new double[]{3.0, 0.0, -1.0, 0.0, 3.0, 0.0, -1.0, 0.0, 3.0};
        Randoms random = new Randoms();
        double[] mean = new double[]{1.0, 1.0, 1.0};
        double[] lower = cholesky(spd, 3);
        for (int iter = 0; iter < 10; iter++) {
            printSample(nextMVNormalWithCholesky(mean, lower, random));
        }
        for (int iter = 0; iter < 10; iter++) {
            printSample(nextZeroSumMVNormalWithCholesky(mean, lower, random));
        }
    }

    /** Prints the values of {@code sample} on one line, each followed by a tab. */
    private static void printSample(double[] sample) {
        StringBuilder line = new StringBuilder();
        for (double value : sample) {
            line.append(value).append('\t');
        }
        System.out.println(line);
    }
}

