package com.github.chen0040.glm.solvers;

import Jama.Matrix;
import com.github.chen0040.glm.enums.GlmDistributionFamily;
import com.github.chen0040.glm.links.LinkFunction;
import com.github.chen0040.glm.maths.Mean;
import com.github.chen0040.glm.maths.StdDev;
import com.github.chen0040.glm.maths.Variance;
import com.github.chen0040.glm.metrics.GlmStatistics;
import com.github.chen0040.glm.search.LocalSearch;

/* loaded from: input_file:com/github/chen0040/glm/solvers/GlmAlgorithmIrls.class */
/**
 * Fits a generalized linear model by iteratively reweighted least squares (IRLS).
 *
 * <p>Each iteration evaluates the linear predictor {@code eta = A * beta}, maps it through the
 * inverse link to obtain {@code mu}, forms the working response
 * {@code z = eta + (y - mu) / (dmu/deta)} and the diagonal weights
 * {@code w = (dmu/deta)^2 / V(mu)}, then solves the weighted normal equations
 * {@code beta = (A^T W A)^-1 A^T W z}. Iteration stops after {@code maxIters} steps or as soon as
 * the 2-norm of the change in {@code beta} drops below {@code mTol}.
 *
 * <p>{@code V(mu)} is obtained from the inherited {@code getVariance}; presumably the variance
 * function of the configured distribution family (confirm in {@link GlmAlgorithm}).
 */
public class GlmAlgorithmIrls extends GlmAlgorithm {
    /** Floor substituted for a zero variance or zero link derivative to avoid dividing by zero. */
    private static final double EPSILON = 1.0E-20d;

    /** Design matrix: one row per observation, one column per coefficient. */
    private Matrix A;
    /** Response values as an (observations x 1) column vector. */
    private Matrix b;
    /** Cached transpose of {@link #A}. */
    private Matrix At;

    @Override // com.github.chen0040.glm.solvers.GlmAlgorithm
    public void copy(GlmAlgorithm glmAlgorithm) {
        super.copy(glmAlgorithm);
        GlmAlgorithmIrls other = (GlmAlgorithmIrls) glmAlgorithm;
        this.A = cloneOrNull(other.A);
        this.b = cloneOrNull(other.b);
        this.At = cloneOrNull(other.At);
    }

    @Override // com.github.chen0040.glm.solvers.GlmAlgorithm
    public GlmAlgorithm makeCopy() {
        GlmAlgorithmIrls duplicate = new GlmAlgorithmIrls();
        duplicate.copy(this);
        return duplicate;
    }

    public GlmAlgorithmIrls() {
    }

    /**
     * Creates a solver with an explicit link function.
     *
     * @param glmDistributionFamily distribution family of the response
     * @param linkFunction link function relating the linear predictor to the mean
     * @param dArr design matrix, one row per observation (must be rectangular and non-empty)
     * @param dArr2 response values, one per design-matrix row
     * @throws IllegalArgumentException if the inputs are empty, jagged, or of mismatched length
     */
    public GlmAlgorithmIrls(GlmDistributionFamily glmDistributionFamily, LinkFunction linkFunction, double[][] dArr, double[] dArr2) {
        super(glmDistributionFamily, linkFunction, (double[][]) null, (double[]) null, (LocalSearch) null);
        initData(dArr, dArr2);
    }

    /**
     * Creates a solver whose link function is chosen by the superclass from the family.
     *
     * @param glmDistributionFamily distribution family of the response
     * @param dArr design matrix, one row per observation (must be rectangular and non-empty)
     * @param dArr2 response values, one per design-matrix row
     * @throws IllegalArgumentException if the inputs are empty, jagged, or of mismatched length
     */
    public GlmAlgorithmIrls(GlmDistributionFamily glmDistributionFamily, double[][] dArr, double[] dArr2) {
        super(glmDistributionFamily);
        initData(dArr, dArr2);
    }

    /** Validates the inputs and initializes the matrices and statistics container. */
    private void initData(double[][] predictors, double[] response) {
        requireConsistentShape(predictors, response);
        this.A = toMatrix(predictors);
        this.b = columnVector(response);
        this.At = this.A.transpose();
        this.mStats = new GlmStatistics(predictors[0].length, response.length);
    }

    /** Rejects empty, jagged, or length-mismatched input (a null argument still raises an NPE). */
    private static void requireConsistentShape(double[][] predictors, double[] response) {
        if (predictors.length == 0 || predictors[0].length == 0) {
            throw new IllegalArgumentException("design matrix must have at least one row and one column");
        }
        if (response.length != predictors.length) {
            throw new IllegalArgumentException("response length " + response.length
                    + " does not match number of design-matrix rows " + predictors.length);
        }
        int columns = predictors[0].length;
        for (int row = 0; row < predictors.length; row++) {
            if (predictors[row].length != columns) {
                throw new IllegalArgumentException("design-matrix row " + row + " has "
                        + predictors[row].length + " columns, expected " + columns);
            }
        }
    }

    /** Returns a deep copy of {@code m}, or {@code null} when {@code m} is null. */
    private static Matrix cloneOrNull(Matrix m) {
        return m == null ? null : (Matrix) m.clone();
    }

    /** Copies a rectangular array into a new matrix at full double precision. */
    private static Matrix toMatrix(double[][] data) {
        int rows = data.length;
        int cols = data[0].length;
        Matrix result = new Matrix(rows, cols);
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                // No (float) narrowing here: keep every bit of the caller's data.
                result.set(r, c, data[r][c]);
            }
        }
        return result;
    }

    /** Copies {@code values} into a new (length x 1) column vector. */
    private static Matrix columnVector(double[] values) {
        Matrix result = new Matrix(values.length, 1);
        for (int r = 0; r < values.length; r++) {
            result.set(r, 0, values[r]);
        }
        return result;
    }

    /** Returns a zero-filled (size x 1) column vector. */
    private static Matrix columnVector(int size) {
        return new Matrix(size, 1);
    }

    /**
     * Runs IRLS from an all-zero coefficient vector and stores the fit statistics.
     *
     * @return the fitted coefficients, one per design-matrix column (also kept in
     *     {@code glmCoefficients})
     * @throws RuntimeException propagated from Jama's {@code inverse()} if {@code A^T W A} cannot
     *     be inverted (presumably when it is singular)
     */
    @Override // com.github.chen0040.glm.solvers.GlmAlgorithm
    public double[] solve() {
        int observations = this.A.getRowDimension();
        int coefficientCount = this.A.getColumnDimension();
        Matrix beta = columnVector(coefficientCount);
        // (A^T W A)^-1 from the most recent iteration; null if no iteration ran.
        Matrix covariance = null;

        for (int iter = 0; iter < this.maxIters; iter++) {
            Matrix eta = this.A.times(beta);
            Matrix workingResponse = columnVector(observations);
            // A^T W is formed by scaling column i of A^T by w_i. W is diagonal, so this yields the
            // same entries as At.times(W) without allocating a dense n x n weight matrix.
            Matrix weightedAt = new Matrix(coefficientCount, observations);

            for (int i = 0; i < observations; i++) {
                double etaI = eta.get(i, 0);
                double mu = this.linkFunc.GetInvLink(etaI);
                double dMuDEta = this.linkFunc.GetInvLinkDerivative(etaI);
                if (dMuDEta == 0.0d) {
                    // A flat link would otherwise make the working response infinite and the
                    // whole solve NaN; the floored weight makes this observation negligible.
                    dMuDEta = EPSILON;
                }
                workingResponse.set(i, 0, etaI + (this.b.get(i, 0) - mu) / dMuDEta);

                double variance = getVariance(mu);
                if (variance == 0.0d) {
                    variance = EPSILON;
                }
                double weight = (dMuDEta * dMuDEta) / variance;
                for (int j = 0; j < coefficientCount; j++) {
                    weightedAt.set(j, i, this.At.get(j, i) * weight);
                }
            }

            Matrix previous = beta;
            covariance = weightedAt.times(this.A).inverse();
            beta = covariance.times(weightedAt).times(workingResponse);
            if (beta.minus(previous).norm2() < this.mTol) {
                break;
            }
        }

        this.glmCoefficients = new double[coefficientCount];
        for (int j = 0; j < coefficientCount; j++) {
            this.glmCoefficients[j] = beta.get(j, 0);
        }
        updateStatistics(covariance);
        return this.glmCoefficients;
    }

    /**
     * Fills {@code mStats} from the final fit.
     *
     * <p>Standard errors and the variance-covariance matrix come from the unscaled
     * {@code (A^T W A)^-1} (no dispersion factor is applied here) and are left untouched when no
     * IRLS iteration ran. Residuals are response minus fitted mean on the response scale.
     *
     * @param covariance {@code (A^T W A)^-1} from the last iteration, or {@code null}
     */
    private void updateStatistics(Matrix covariance) {
        int coefficientCount = this.glmCoefficients.length;
        int observations = this.b.getRowDimension();

        if (covariance != null) {
            double[] standardErrors = this.mStats.getStandardErrors();
            double[][] vCovMatrix = this.mStats.getVCovMatrix();
            for (int i = 0; i < coefficientCount; i++) {
                standardErrors[i] = Math.sqrt(covariance.get(i, i));
                for (int j = 0; j < coefficientCount; j++) {
                    vCovMatrix[i][j] = covariance.get(i, j);
                }
            }
        }

        double[] residuals = this.mStats.getResiduals();
        double[] response = new double[observations];
        for (int i = 0; i < observations; i++) {
            double eta = 0.0d;
            for (int j = 0; j < coefficientCount; j++) {
                eta += this.A.get(i, j) * this.glmCoefficients[j];
            }
            response[i] = this.b.get(i, 0);
            residuals[i] = response[i] - this.linkFunc.GetInvLink(eta);
        }

        double residualStdDev = StdDev.apply(residuals, 0.0d);
        double responseMean = Mean.apply(response);
        double responseVariance = Variance.apply(response, responseMean);
        double unexplainedRatio = (residualStdDev * residualStdDev) / responseVariance;

        this.mStats.setResidualStdDev(residualStdDev);
        this.mStats.setResponseMean(responseMean);
        this.mStats.setResponseVariance(responseVariance);
        this.mStats.setR2(1.0d - unexplainedRatio);
        // Adjusted R^2 scales by (n - 1) / (n - p - 1) where n is the number of OBSERVATIONS;
        // the previous code used the coefficient count for n, which gave meaningless values.
        this.mStats.setAdjustedR2(1.0d - (unexplainedRatio * (observations - 1))
                / ((observations - coefficientCount) - 1));
    }
}
