/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import org.apache.commons.math3.distribution.TDistribution;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/**
 * Holds the outcome of a weighted linear regression fit: feature coefficients, an optional
 * intercept, degrees of freedom, mean squared error, standard errors and p-values.
 *
 * <p>When the fit included an intercept, the supplied coefficient vector is expected to carry the
 * intercept as its <em>last</em> entry; it is split off and exposed separately via
 * {@link #getIntercept()}, while {@link #getCoefficients()} returns only the remaining entries.
 */
public class WeightedLinearRegressionResults {
    /** Feature coefficients, excluding the intercept entry. */
    private final RealVector coefficients;
    /** Intercept term, or {@code 0.0} when the model was constructed without one. */
    private final double intercept;
    /** Degrees of freedom of the t-distribution used by {@link #getConf(double)}. */
    private final int dof;
    /** Mean squared error, stored exactly as supplied. */
    private final double mse;
    /** Standard errors, stored exactly as supplied (not copied, not split). */
    private final RealVector stdErrors;
    /** P-values, stored exactly as supplied (not copied, not split). */
    private final RealVector pvalues;

    /**
     * Creates a result holder.
     *
     * <p>NOTE(review): {@code stdErrors} and {@code pvalues} are stored unchanged, whereas
     * {@code coefficients} has its last entry removed when {@code intercept} is true. Confirm
     * with the producer whether those vectors include an intercept entry.
     *
     * @param coefficients fitted coefficients; when {@code intercept} is true the last entry is
     *     taken as the intercept
     * @param intercept whether the last entry of {@code coefficients} is an intercept term
     * @param dof degrees of freedom used for confidence intervals
     * @param mse mean squared error of the fit
     * @param stdErrors standard errors of the fitted parameters
     * @param pvalues p-values of the fitted parameters
     * @throws IllegalArgumentException if {@code intercept} is true but {@code coefficients} is
     *     empty
     */
    public WeightedLinearRegressionResults(RealVector coefficients, boolean intercept, int dof, double mse, RealVector stdErrors, RealVector pvalues) {
        int dimension = coefficients.getDimension();
        if (intercept) {
            if (dimension < 1) {
                throw new IllegalArgumentException(
                        "Coefficient vector must contain at least the intercept entry when intercept is true");
            }
            // getSubVector builds a new vector, so the stored coefficients do not alias the caller's.
            this.coefficients = coefficients.getSubVector(0, dimension - 1);
            this.intercept = coefficients.getEntry(dimension - 1);
        } else {
            // Copy for consistency with the intercept branch: never alias the caller's mutable
            // vector, otherwise later external mutation would silently change predict().
            this.coefficients = coefficients.copy();
            this.intercept = 0.0;
        }
        this.dof = dof;
        this.mse = mse;
        this.stdErrors = stdErrors;
        this.pvalues = pvalues;
    }

    /**
     * Predicts one value per row of {@code x} as {@code x * coefficients + intercept}.
     *
     * @param x input matrix with one row per sample and one column per feature
     * @return vector of predictions, one per row of {@code x}
     * @throws IllegalArgumentException if the column count of {@code x} differs from the number
     *     of (non-intercept) coefficients
     */
    public RealVector predict(RealMatrix x) throws IllegalArgumentException {
        if (x.getColumnDimension() != this.coefficients.getDimension()) {
            throw new IllegalArgumentException(String.format(
                    "Num feature mismatch: Number of columns in x (%d) must match number of coefficients (%d)",
                    x.getColumnDimension(), this.coefficients.getDimension()));
        }
        return x.operate(this.coefficients).mapAdd(this.intercept);
    }

    /**
     * Returns the feature coefficients, excluding any intercept.
     *
     * @return the stored coefficient vector; callers should treat it as read-only
     */
    public RealVector getCoefficients() {
        return this.coefficients;
    }

    /**
     * Returns the intercept term.
     *
     * @return the intercept, or {@code 0.0} if the results were built without an intercept
     */
    public double getIntercept() {
        return this.intercept;
    }

    /**
     * Returns the mean squared error supplied at construction.
     *
     * @return the mean squared error
     */
    public double getMSE() {
        return this.mse;
    }

    /**
     * Returns the standard errors supplied at construction.
     *
     * @return the stored standard-error vector (same instance as supplied)
     */
    public RealVector getStdErrors() {
        return this.stdErrors;
    }

    /**
     * Returns the p-values supplied at construction.
     *
     * @return the stored p-value vector (same instance as supplied)
     */
    public RealVector getPValues() {
        return this.pvalues;
    }

    /**
     * Computes the half-widths of two-sided {@code (1 - alpha)} confidence intervals. Each
     * standard error is scaled by the Student-t quantile at {@code 1 - alpha / 2} with
     * {@link #dof} degrees of freedom.
     *
     * @param alpha significance level, in {@code [0, 1]}
     * @return {@code stdErrors * q}, where {@code q} is the t quantile described above
     * @throws IllegalArgumentException if {@code alpha} is outside {@code [0, 1]} or is NaN
     */
    public RealVector getConf(double alpha) {
        // Negated range check so that NaN is rejected too. Without it, alpha in (1, 2) would
        // yield a negative quantile and silently produce negative interval widths.
        if (!(alpha >= 0.0 && alpha <= 1.0)) {
            throw new IllegalArgumentException(
                    String.format("alpha must be in the interval [0, 1], got %s", alpha));
        }
        TDistribution tdist = new TDistribution((double) this.dof);
        double q = tdist.inverseCumulativeProbability(1.0 - alpha / 2.0);
        return this.stdErrors.mapMultiply(q);
    }
}

