/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.utils.WeightedLinearRegressionResults;

/**
 * Unit tests for {@link WeightedLinearRegressionResults}.
 *
 * <p>The tests check the values exposed by the result accessors when the object is built with
 * and without the intercept flag, and the output of {@code predict} for a feature matrix with a
 * matching and a mismatching number of columns.
 */
class WeightedLinearRegressionResultsTest {
    // Shared four-entry fixtures. Tests that need exactly these values use the fields directly
    // instead of re-declaring same-named locals that would hide them.
    RealVector coefficients = MatrixUtils.createRealVector(new double[] {5.0, 1.0, -1.0, 3.0});
    RealVector flatCoef = MatrixUtils.createRealVector(new double[] {5.0, 1.0, -1.0, 3.0});
    RealVector stdErrs = MatrixUtils.createRealVector(new double[] {5.0, 1.0, -1.0, 3.0});

    /**
     * With the intercept flag off, the getters are expected to return every input vector entry
     * unchanged, an intercept of {@code 0.0}, and the supplied MSE.
     */
    @Test
    void testWLRResultsNoIntercept() {
        RealVector pvalues = MatrixUtils.createRealVector(new double[4]);
        // NOTE(review): the third argument (1) is never read back by this test; its meaning is
        // defined by the WeightedLinearRegressionResults constructor - confirm there.
        WeightedLinearRegressionResults results =
                new WeightedLinearRegressionResults(coefficients, false, 1, 0.01, stdErrs, pvalues);

        Assertions.assertArrayEquals(flatCoef.toArray(), results.getCoefficients().toArray());
        Assertions.assertArrayEquals(stdErrs.toArray(), results.getStdErrors().toArray());
        Assertions.assertArrayEquals(pvalues.toArray(), results.getPValues().toArray());
        Assertions.assertEquals(0.0, results.getIntercept());
        Assertions.assertEquals(0.01, results.getMSE());
    }

    /**
     * With the intercept flag on, the last input entry ({@code 3.0}) is expected back from
     * {@code getIntercept()} and only the first three entries from {@code getCoefficients()};
     * the standard errors are expected back in full.
     */
    @Test
    void testWLRResultWithIntercept() {
        // Expected coefficients once the trailing entry is reported as the intercept.
        RealVector expectedCoefficients = MatrixUtils.createRealVector(new double[] {5.0, 1.0, -1.0});
        RealVector pvalues = MatrixUtils.createRealVector(new double[4]);
        WeightedLinearRegressionResults results =
                new WeightedLinearRegressionResults(coefficients, true, 1, 0.01, stdErrs, pvalues);

        Assertions.assertArrayEquals(expectedCoefficients.toArray(), results.getCoefficients().toArray());
        Assertions.assertArrayEquals(stdErrs.toArray(), results.getStdErrors().toArray());
        Assertions.assertEquals(3.0, results.getIntercept());
        Assertions.assertEquals(0.01, results.getMSE());
    }

    /**
     * Checks {@code predict} on a 3x4 feature matrix against hand-computed targets.
     *
     * <p>Each expected value is the row's dot product with {@code {5, 1, -1, 3}} plus the
     * trailing entry {@code 5}, e.g. {@code 5*1 + 1*5 - 1*3 + 3*(-2) + 5 = 6}.
     */
    @Test
    void testPredictions() {
        // Five entries: four feature weights followed by the value used as the intercept.
        RealVector modelCoefficients = MatrixUtils.createRealVector(new double[] {5.0, 1.0, -1.0, 3.0, 5.0});
        RealVector modelStdErrs = MatrixUtils.createRealVector(new double[] {5.0, 1.0, -1.0, 3.0, 5.0});
        RealVector pvalues = MatrixUtils.createRealVector(new double[5]);
        RealMatrix features = MatrixUtils.createRealMatrix(new double[][] {
                {1.0, 5.0, 3.0, -2.0},
                {10.0, -1.0, 0.0, 4.0},
                {-2.0, 7.5, 6.0, -3.3}});
        RealVector expected = MatrixUtils.createRealVector(new double[] {6.0, 66.0, -13.4});

        WeightedLinearRegressionResults results =
                new WeightedLinearRegressionResults(modelCoefficients, true, 1, 0.01, modelStdErrs, pvalues);

        // A tolerance is needed here because the predictions come from floating-point arithmetic.
        Assertions.assertArrayEquals(expected.toArray(), results.predict(features).toArray(), 1.0E-6);
    }

    /**
     * {@code predict} must throw {@link IllegalArgumentException} when the matrix has two
     * columns but the results were built from five entries with the intercept flag on.
     */
    @Test
    void testPredictionsWrongNumFeatures() {
        RealVector modelCoefficients = MatrixUtils.createRealVector(new double[] {5.0, 1.0, -1.0, 3.0, 5.0});
        RealVector modelStdErrs = MatrixUtils.createRealVector(new double[] {5.0, 1.0, -1.0, 3.0, 5.0});
        RealVector pvalues = MatrixUtils.createRealVector(new double[5]);
        RealMatrix tooFewColumns = MatrixUtils.createRealMatrix(new double[][] {
                {1.0, 5.0},
                {10.0, -1.0},
                {-2.0, 7.5}});

        WeightedLinearRegressionResults results =
                new WeightedLinearRegressionResults(modelCoefficients, true, 1, 0.01, modelStdErrs, pvalues);

        Assertions.assertThrows(IllegalArgumentException.class, () -> results.predict(tooFewColumns));
    }
}

