/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.Random;
import org.kie.kogito.explainability.utils.MatrixUtils;
import org.kie.kogito.explainability.utils.WeightedLinearRegressionResults;

/**
 * Weighted least-squares linear regression, solved through the normal equations
 * {@code (X^T W X) beta = X^T W y}.
 *
 * <p>Static utility class; it cannot be instantiated.
 */
public final class WeightedLinearRegression {
    private WeightedLinearRegression() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Fits a weighted linear regression model.
     *
     * @param features      sample-by-feature matrix: one row per sample, every row the same length
     * @param observations  observed target value for each sample
     * @param sampleWeights weight of each sample; must contain exactly one entry per sample
     * @param intercept     if {@code true}, a constant column of ones is appended to the features, so the
     *                      last coefficient acts as the intercept term
     * @param random        random source handed to {@code MatrixUtils.jitterInvert}
     * @return the fitted coefficients (one row per adjusted feature), the intercept flag, the weighted
     *         coefficient of determination (1 - RSS/TSS) and the weighted mean squared error
     * @throws IllegalArgumentException if there are no samples, the feature rows are ragged, or the number of
     *                                  feature rows / sample weights does not match the number of observations
     * @throws ArithmeticException      if the normal matrix cannot be inverted, the weights sum to zero, or
     *                                  the observations have zero weighted variance
     */
    public static WeightedLinearRegressionResults fit(double[][] features, double[] observations, double[] sampleWeights, boolean intercept, Random random) throws IllegalArgumentException, ArithmeticException {
        validateInputs(features, observations, sampleWeights);
        int nsamples = observations.length;
        double[][] adjustedFeatures = adjustFeatureMatrix(features, intercept);
        // Includes the appended intercept column, if any.
        int nfeatures = adjustedFeatures[0].length;

        double[][] xtwx = new double[nfeatures][nfeatures];
        double[][] xtwy = new double[nfeatures][1];
        accumulateNormalEquations(adjustedFeatures, observations, sampleWeights, xtwx, xtwy);

        double[][] inverse;
        try {
            inverse = MatrixUtils.jitterInvert(xtwx, 10, 1.0E-9, random);
        }
        catch (ArithmeticException e) {
            ArithmeticException failure = new ArithmeticException("Weighted Linear Regression: Matrix cannot be inverted! This can be caused by a very under-specified model, where the ratio of samples to features is roughly less than 0.10. This model has a ratio of " + (double)nsamples / (double)nfeatures + ".");
            // ArithmeticException has no (message, cause) constructor; keep the root cause for diagnostics.
            failure.initCause(e);
            throw failure;
        }
        double[][] coefficients = MatrixUtils.matrixMultiply(inverse, xtwy);
        double[] predictions = predict(adjustedFeatures, coefficients);
        double gof = getGoodnessOfFit(observations, sampleWeights, predictions);
        double mse = getMSE(observations, sampleWeights, predictions);
        return new WeightedLinearRegressionResults(coefficients, intercept, gof, mse);
    }

    /** Rejects empty, ragged, or length-mismatched inputs before any indexing happens. */
    private static void validateInputs(double[][] features, double[] observations, double[] sampleWeights) {
        int nsamples = observations.length;
        if (features.length != nsamples) {
            throw new IllegalArgumentException(String.format("Num sample mismatch: Number of rows in the features (%d)", features.length) + String.format(" must match number of observations (%d)", nsamples));
        }
        if (sampleWeights.length != nsamples) {
            throw new IllegalArgumentException(String.format("Num sample mismatch: Number of sample weights (%d) must match number of observations (%d)", sampleWeights.length, nsamples));
        }
        if (nsamples == 0) {
            throw new IllegalArgumentException("Cannot fit a regression with zero samples");
        }
        int rowLength = features[0].length;
        for (int i = 1; i < nsamples; ++i) {
            if (features[i].length != rowLength) {
                throw new IllegalArgumentException(String.format("Feature row %d has %d columns, but row 0 has %d", i, features[i].length, rowLength));
            }
        }
    }

    /** Copies the feature matrix, appending a constant 1.0 column when {@code intercept} is set. */
    private static double[][] adjustFeatureMatrix(double[][] features, boolean intercept) {
        int nsamples = features.length;
        int ncols = features[0].length;
        int nfeatures = intercept ? ncols + 1 : ncols;
        double[][] adjustedFeatures = new double[nsamples][nfeatures];
        for (int i = 0; i < nsamples; ++i) {
            System.arraycopy(features[i], 0, adjustedFeatures[i], 0, ncols);
            if (intercept) {
                adjustedFeatures[i][ncols] = 1.0;
            }
        }
        return adjustedFeatures;
    }

    /**
     * Adds {@code X^T W X} into {@code xtwx} and {@code X^T W y} into {@code xtwy} (an n-by-1 column).
     * Both outputs must be zero-initialised by the caller.
     */
    private static void accumulateNormalEquations(double[][] features, double[] observations, double[] sampleWeights, double[][] xtwx, double[][] xtwy) {
        int nfeatures = xtwx.length;
        for (int j = 0; j < features.length; ++j) {
            double[] row = features[j];
            for (int i = 0; i < nfeatures; ++i) {
                // Same operand order as w * x_i * x_k, so results match a per-cell accumulation.
                double weighted = sampleWeights[j] * row[i];
                xtwy[i][0] += weighted * observations[j];
                for (int k = 0; k < nfeatures; ++k) {
                    xtwx[i][k] += weighted * row[k];
                }
            }
        }
    }

    /** Computes the linear prediction {@code features * coefficients} for every sample. */
    private static double[] predict(double[][] features, double[][] coefficients) {
        double[] predictions = new double[features.length];
        for (int i = 0; i < features.length; ++i) {
            double fI = 0.0;
            for (int j = 0; j < features[i].length; ++j) {
                fI += features[i][j] * coefficients[j][0];
            }
            predictions[i] = fI;
        }
        return predictions;
    }

    /** Sums the sample weights, rejecting a zero total since it is used as a divisor. */
    private static double sumWeights(double[] sampleWeights) {
        double weightSum = 0.0;
        for (double weight : sampleWeights) {
            weightSum += weight;
        }
        if (weightSum == 0.0) {
            throw new ArithmeticException("Weights cannot sum to zero!");
        }
        return weightSum;
    }

    /** Weighted coefficient of determination: {@code 1 - sum(w*(y - f)^2) / sum(w*(y - yBar)^2)}. */
    private static double getGoodnessOfFit(double[] observations, double[] sampleWeights, double[] predictions) {
        int nsamples = observations.length;
        double weightSum = sumWeights(sampleWeights);
        double yBar = 0.0;
        for (int i = 0; i < nsamples; ++i) {
            yBar += sampleWeights[i] * observations[i];
        }
        yBar /= weightSum;
        double totalSquareSum = 0.0;
        double residualSquareSum = 0.0;
        for (int i = 0; i < nsamples; ++i) {
            double residual = observations[i] - predictions[i];
            double deviation = observations[i] - yBar;
            totalSquareSum += sampleWeights[i] * (deviation * deviation);
            residualSquareSum += sampleWeights[i] * (residual * residual);
        }
        if (totalSquareSum == 0.0) {
            throw new ArithmeticException("Total variance of observations is zero. Use more samples to correct this error");
        }
        return 1.0 - residualSquareSum / totalSquareSum;
    }

    /** Weighted mean squared error: {@code sum(w*(y - f)^2) / sum(w)}. */
    private static double getMSE(double[] observations, double[] sampleWeights, double[] predictions) {
        int nsamples = observations.length;
        double weightSum = sumWeights(sampleWeights);
        double totalResidual = 0.0;
        for (int i = 0; i < nsamples; ++i) {
            double residual = observations[i] - predictions[i];
            totalResidual += sampleWeights[i] * (residual * residual);
        }
        return totalResidual / weightSum;
    }
}

