/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.Arrays;
import org.apache.commons.math3.distribution.TDistribution;
import org.kie.kogito.explainability.utils.MatrixUtilsExtensions;

/**
 * Holds the results of a weighted linear regression fit: the fitted feature coefficients, the
 * intercept, the degrees of freedom, the mean squared error, and the standard errors and p-values
 * supplied by the caller.
 *
 * <p>Array-valued state is defensively copied on construction and on every getter call, so an
 * instance cannot be modified through arrays passed in or handed out.
 */
public class WeightedLinearRegressionResults {
    private final double[] coefficients;
    private final double intercept;
    private final int dof;
    private final double mse;
    private final double[] stdErrors;
    private final double[] pvalues;

    /**
     * Creates a results object from raw regression output.
     *
     * @param coefficients coefficient matrix; its column 0 (extracted with
     *     {@code MatrixUtilsExtensions.getCol(coefficients, 0)}) holds the fitted parameters
     * @param intercept if {@code true}, the LAST entry of that column is taken as the intercept and
     *     the preceding entries as the feature coefficients; if {@code false}, the whole column is
     *     used as feature coefficients and the intercept is {@code 0.0}
     * @param dof degrees of freedom used by {@link #getConf(double)} for the Student t quantile
     * @param mse mean squared error, stored as given
     * @param stdErrors standard errors, stored (copied) as given; NOTE(review): presumably one per
     *     fitted parameter — ordering is defined by the caller, confirm there
     * @param pvalues p-values, stored (copied) as given; same ordering caveat as {@code stdErrors}
     */
    public WeightedLinearRegressionResults(double[][] coefficients, boolean intercept, int dof, double mse, double[] stdErrors, double[] pvalues) {
        double[] rawCoeffs = MatrixUtilsExtensions.getCol(coefficients, 0);
        if (intercept) {
            int last = rawCoeffs.length - 1;
            // Read the intercept first so an empty column fails with an index error, as before.
            this.intercept = rawCoeffs[last];
            this.coefficients = Arrays.copyOfRange(rawCoeffs, 0, last);
        } else {
            // Copy in case getCol returns an array shared with the caller's matrix.
            this.coefficients = copyOrNull(rawCoeffs);
            this.intercept = 0.0;
        }
        this.dof = dof;
        this.mse = mse;
        this.stdErrors = copyOrNull(stdErrors);
        this.pvalues = copyOrNull(pvalues);
    }

    /**
     * Predicts a response for each row of {@code x} as {@code intercept + sum_j x[i][j] * coef[j]}.
     *
     * @param x feature matrix, one observation per row; every row must have exactly as many
     *     columns as there are coefficients
     * @return one prediction per row of {@code x}; an empty array when {@code x} has no rows
     * @throws IllegalArgumentException if any row's length differs from the number of coefficients
     */
    public double[] predict(double[][] x) throws IllegalArgumentException {
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; ++i) {
            double[] row = x[i];
            // Validate every row, not just the first: a short row would otherwise fail with an
            // index error, and a long row would silently have its extra columns ignored.
            if (row.length != this.coefficients.length) {
                throw new IllegalArgumentException(String.format("Num feature mismatch: Number of columns in x (%d)", row.length)
                        + String.format(" must match number of coefficients (%d)", this.coefficients.length));
            }
            double prediction = this.intercept;
            for (int j = 0; j < row.length; ++j) {
                prediction += row[j] * this.coefficients[j];
            }
            y[i] = prediction;
        }
        return y;
    }

    /**
     * @return a copy of the feature coefficients (the intercept is excluded)
     */
    public double[] getCoefficients() {
        return copyOrNull(this.coefficients);
    }

    /**
     * @return the intercept, or {@code 0.0} if the model was built without one
     */
    public double getIntercept() {
        return this.intercept;
    }

    /**
     * @return the mean squared error passed to the constructor
     */
    public double getMSE() {
        return this.mse;
    }

    /**
     * @return a copy of the standard errors passed to the constructor (or {@code null} if none were)
     */
    public double[] getStdErrors() {
        return copyOrNull(this.stdErrors);
    }

    /**
     * @return a copy of the p-values passed to the constructor (or {@code null} if none were)
     */
    public double[] getPValues() {
        return copyOrNull(this.pvalues);
    }

    /**
     * Computes {@code t * stdError} for every standard error, where {@code t} is the
     * {@code 1 - alpha/2} quantile of a Student t distribution with {@code dof} degrees of freedom,
     * i.e. the half-width of a two-sided {@code (1 - alpha)} interval around each parameter.
     *
     * <p>NOTE(review): {@code alpha} is presumably expected in (0, 1) and {@code dof} must be
     * accepted by {@code TDistribution} (presumably strictly positive); neither is checked here.
     *
     * @param alpha significance level, e.g. {@code 0.05} for 95% intervals
     * @return one half-width per entry of the standard errors
     */
    public double[] getConf(double alpha) {
        TDistribution tdist = new TDistribution((double) this.dof);
        double q = tdist.inverseCumulativeProbability(1.0 - alpha / 2.0);
        return Arrays.stream(this.stdErrors).map(se -> q * se).toArray();
    }

    /** Returns a copy of {@code values}, or {@code null} if {@code values} is {@code null}. */
    private static double[] copyOrNull(double[] values) {
        return values == null ? null : values.clone();
    }
}

