/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.List;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.utils.LarsPath;
import org.kie.kogito.explainability.utils.LarsPathResults;
import org.kie.kogito.explainability.utils.WeightedLinearRegression;

/**
 * Tests for {@link LarsPath#fit}: each test fits a small fixture and compares the resulting active
 * variable order ({@code getActive()}), alpha sequence ({@code getAlphas()}) or final fit quality
 * against expected values.
 *
 * <p>NOTE(review): the expected actives/alphas look like reference values produced by another
 * LARS implementation — provenance is not recorded here; TODO confirm and document.
 */
class LarsPathTest {
    /** Absolute tolerance used when comparing alpha sequences. */
    private static final double TOLERANCE = 1.0E-6;

    // 10x10 dense fixture. The expected path activates all 10 columns; the expected alpha
    // sequence has 11 entries and ends at 0.0.
    RealMatrix X = MatrixUtils.createRealMatrix(new double[][] {
            {0.92966881, 0.17435502, 0.86274567, 0.02096693, 0.61729408, 0.27663037, 0.07324771, 0.86299396, 0.20387837, 0.2678897},
            {0.46124402, 0.21212798, 0.54547663, 0.85310364, 0.23584478, 0.89939373, 0.90052444, 0.48947526, 0.97695481, 0.31682039},
            {0.66084177, 0.54153099, 0.76965712, 0.08213559, 0.9262654, 0.68282777, 0.500637, 0.76781516, 0.14606141, 0.53844816},
            {0.44602165, 0.72739983, 0.66221962, 0.20234917, 0.80836334, 0.37038587, 0.67539221, 0.77099063, 0.92992129, 0.56789747},
            {0.67568569, 0.37884472, 0.18745406, 0.04757457, 0.09661771, 0.50471931, 0.35367252, 0.75794935, 0.6424804, 0.55250168},
            {0.19722479, 0.32117211, 0.70339706, 0.53906674, 0.76903061, 0.32923893, 0.50025901, 0.20776133, 0.1088789, 0.79303772},
            {0.31128645, 0.05883037, 0.64210569, 0.88726458, 0.19756748, 0.02448866, 0.2172705, 0.27894779, 0.55028519, 0.70483099},
            {0.47339132, 0.14034869, 0.0816702, 0.06699631, 0.06823621, 0.03639515, 0.07545303, 0.1208853, 0.72845905, 0.74802801},
            {0.99628077, 0.83760513, 0.63542635, 0.07380346, 0.79007766, 0.55288944, 0.44548098, 0.4055312, 0.70605767, 0.83153303},
            {0.47161946, 0.97424448, 0.91217761, 0.6264732, 0.43486423, 0.39281956, 0.66218207, 0.01484187, 0.75595905, 0.04462323}});
    RealVector y = MatrixUtils.createRealVector(new double[] {
            6.38923853, -2.16396995, 7.37162403, 1.79236199, 4.21888433,
            0.41875855, -3.69136276, -0.50760573, 4.89875242, -4.03316984});
    List<Integer> correctActives = List.of(7, 3, 0, 4, 8, 5, 9, 2, 1, 6);
    RealVector correctAlphas = MatrixUtils.createRealVector(new double[] {
            1.56169836, 0.83963397, 0.68991047, 0.6664122, 0.29931992, 0.14315316,
            0.1200302, 0.00776273, 0.00389666, 0.00187156, 0.0});

    // 15x5 binary (0/1) design matrix; rows 8, 9 and 14 are all zeros with target 0.0.
    RealMatrix X2 = MatrixUtils.createRealMatrix(new double[][] {
            {0.0, 1.0, 0.0, 1.0, 0.0},
            {0.0, 1.0, 0.0, 1.0, 0.0},
            {0.0, 1.0, 1.0, 0.0, 1.0},
            {1.0, 1.0, 0.0, 0.0, 1.0},
            {0.0, 1.0, 1.0, 0.0, 1.0},
            {0.0, 0.0, 1.0, 0.0, 1.0},
            {0.0, 0.0, 1.0, 1.0, 1.0},
            {0.0, 0.0, 0.0, 0.0, 0.0},
            {0.0, 0.0, 0.0, 0.0, 0.0},
            {1.0, 1.0, 0.0, 1.0, 0.0},
            {0.0, 0.0, 1.0, 0.0, 1.0},
            {1.0, 0.0, 0.0, 0.0, 0.0},
            {0.0, 1.0, 1.0, 0.0, 0.0},
            {0.0, 0.0, 0.0, 0.0, 0.0},
            {0.0, 1.0, 1.0, 1.0, 1.0}});
    RealVector y2 = MatrixUtils.createRealVector(new double[] {
            1.09622926, 1.09622926, 1.07290478, 0.5044599, 1.07290478, 0.88775703, 1.79883855, 0.0,
            0.0, 1.39511415, 0.88775703, 0.29888489, 1.0524775, 0.0, 1.98398629});
    List<Integer> correctActives2 = List.of(1, 2, 3, 4, 0);
    RealVector correctAlphas2 = MatrixUtils.createRealVector(new double[] {
            0.61828706, 0.54926307, 0.36443261, 0.183918, 0.04809642, 0.0});

    // Nearly collinear rows: row 2 = 2 * row 1, row 4 = -row 1, and row 3 = 3 * row 1 except its
    // third entry (8 instead of 9). The target is exactly XMinAlpha * (1, 2, 3, 4, 5). Only columns
    // 4 and 2 are expected to become active; the last expected alpha (~4e-13) is numerically zero.
    RealMatrix XMinAlpha = MatrixUtils.createRealMatrix(new double[][] {
            {1.0, 2.0, 3.0, 4.0, 5.0},
            {2.0, 4.0, 6.0, 8.0, 10.0},
            {3.0, 6.0, 8.0, 12.0, 15.0},
            {-1.0, -2.0, -3.0, -4.0, -5.0}});
    RealVector yMinAlpha = MatrixUtils.createRealVector(new double[] {55.0, 110.0, 162.0, -55.0});
    List<Integer> correctActivesMinAlpha = List.of(4, 2);
    RealVector correctAlphasMinAlpha = MatrixUtils.createRealVector(new double[] {
            1020.0, 0.681818182, 4.16888746E-13});

    // Rank-deficient design: column j equals column 0 plus the constant j, and the target is
    // exactly 10 * column 0, so a perfect fit exists despite the degenerate columns.
    RealMatrix XDGR = MatrixUtils.createRealMatrix(new double[][] {
            {0.0, 1.0, 2.0, 3.0, 4.0},
            {5.0, 6.0, 7.0, 8.0, 9.0},
            {10.0, 11.0, 12.0, 13.0, 14.0},
            {15.0, 16.0, 17.0, 18.0, 19.0},
            {20.0, 21.0, 22.0, 23.0, 24.0},
            {25.0, 26.0, 27.0, 28.0, 29.0},
            {30.0, 31.0, 32.0, 33.0, 34.0},
            {35.0, 36.0, 37.0, 38.0, 39.0},
            {40.0, 41.0, 42.0, 43.0, 44.0},
            {45.0, 46.0, 47.0, 48.0, 49.0}});
    RealVector yDGR = MatrixUtils.createRealVector(new double[] {
            0.0, 50.0, 100.0, 150.0, 200.0, 250.0, 300.0, 350.0, 400.0, 450.0});
    // Unit weights of the same length as yDGR; must stay declared after yDGR (initialiser order).
    RealVector dummyWeights = this.yDGR.map(x -> 1.0);

    // 6x6 fixture used with the boolean fit flag set to true. The expected path has 5 actives but
    // 8 alphas — presumably because variables leave and re-enter the active set; TODO confirm.
    RealMatrix XVarDrop = MatrixUtils.createRealMatrix(new double[][] {
            {0.18321534, -0.35029812, 0.07666221, -0.0143539, 0.07493564, 0.0264753},
            {-0.02852663, 0.35584493, 0.49064148, -0.25236788, 0.61892759, -0.31065164},
            {0.21982559, 0.49110281, -0.5332698, -0.68922198, -0.52132598, 0.56451554},
            {0.61981706, 0.26882512, 0.50779668, 0.36052515, 0.13083871, 0.41441906},
            {-0.63640374, -0.65128988, -0.45031639, 0.57505155, 0.22006143, -0.63980788},
            {-0.35792762, -0.11418486, -0.09151419, 0.02036707, -0.52343739, -0.05495039}});
    RealVector yVarDrop = MatrixUtils.createRealVector(new double[] {
            0.00357273, 0.008411, 0.33522509, 0.07329731, -0.24901509, -0.17149104});
    List<Integer> correctActivesVarDrop = List.of(1, 3, 0, 2, 4);
    RealVector correctAlphasVarDrop = MatrixUtils.createRealVector(new double[] {
            0.0643070982530952, 0.0545709061429076, 0.051526183371599, 0.0273483558266873,
            0.0086875666842567, 0.0056182106014212, 0.0037483228269213, 0.0});

    /**
     * Asserts that the first {@code steps} active indices and alphas of {@code lpr} match the
     * corresponding prefixes of the expected values.
     *
     * @param expectedActives expected active-variable order (at least {@code steps} long)
     * @param expectedAlphas  expected alpha sequence (at least {@code steps} long)
     * @param lpr             fitted path to check
     * @param steps           number of leading entries to compare
     */
    private static void assertPathPrefix(List<Integer> expectedActives, RealVector expectedAlphas,
            LarsPathResults lpr, int steps) {
        List<?> actives = lpr.getActive();
        RealVector alphas = lpr.getAlphas();
        // Check lengths first so a path that stopped early fails with a readable message instead of
        // an IndexOutOfBoundsException / OutOfRangeException from subList / getSubVector.
        Assertions.assertTrue(actives.size() >= steps,
                "expected at least " + steps + " active variables but got " + actives);
        Assertions.assertTrue(alphas.getDimension() >= steps,
                "expected at least " + steps + " alphas but got " + alphas.getDimension());
        Assertions.assertEquals(expectedActives.subList(0, steps), actives.subList(0, steps));
        Assertions.assertArrayEquals(expectedAlphas.getSubVector(0, steps).toArray(),
                alphas.getSubVector(0, steps).toArray(), TOLERANCE);
    }

    /** Fits the 10x10 fixture with several iteration caps and checks the leading path entries. */
    @ParameterizedTest
    @ValueSource(ints = {2, 3, 4, 5, 6, 7, 8, 9, 10})
    void testLars10(int maxIter) {
        LarsPathResults lpr = LarsPath.fit(this.X, this.y, maxIter, false);
        assertPathPrefix(this.correctActives, this.correctAlphas, lpr, maxIter);
    }

    /** Fits the 15x5 binary fixture and checks the leading path entries. */
    @ParameterizedTest
    @ValueSource(ints = {5})
    void testLars5(int maxIter) {
        LarsPathResults lpr = LarsPath.fit(this.X2, this.y2, maxIter, false);
        assertPathPrefix(this.correctActives2, this.correctAlphas2, lpr, maxIter);
    }

    /**
     * Fits the nearly collinear fixture with a generous iteration cap and checks the complete
     * active list and alpha sequence (not just a prefix).
     */
    @Test
    void testLarsMinAlpha() {
        LarsPathResults lpr = LarsPath.fit(this.XMinAlpha, this.yMinAlpha, 500, false);
        Assertions.assertEquals(this.correctActivesMinAlpha, lpr.getActive());
        Assertions.assertArrayEquals(this.correctAlphasMinAlpha.toArray(), lpr.getAlphas().toArray(), TOLERANCE);
    }

    /**
     * Fits the rank-deficient fixture and checks that the coefficients in the last column of
     * {@code getCoefs()} reproduce the target almost exactly (unit-weighted MSE below 1e-16).
     */
    @Test
    void testLarsDGR() {
        LarsPathResults lpr = LarsPath.fit(this.XDGR, this.yDGR, 500, false);
        RealMatrix coefs = lpr.getCoefs();
        // Last column — presumably the coefficients at the final step of the path.
        RealVector finalCoefs = coefs.getColumnVector(coefs.getColumnDimension() - 1);
        double mse = WeightedLinearRegression.getMSE(this.XDGR, this.yDGR, this.dummyWeights, finalCoefs);
        Assertions.assertTrue(mse < 1.0E-16, "expected a near-exact fit but MSE was " + mse);
    }

    /**
     * Fits the 6x6 fixture with the boolean flag set to {@code true} and checks the complete
     * active list and alpha sequence.
     */
    @Test
    void testLarsVarDrop() {
        LarsPathResults lpr = LarsPath.fit(this.XVarDrop, this.yVarDrop, 500, true);
        Assertions.assertEquals(this.correctActivesVarDrop, lpr.getActive());
        Assertions.assertArrayEquals(this.correctAlphasVarDrop.toArray(), lpr.getAlphas().toArray(), TOLERANCE);
    }

    /** A 6-row design matrix paired with a 10-element target must be rejected. */
    @Test
    void testLarsMismatchedInputs() {
        Assertions.assertThrows(IllegalArgumentException.class,
                () -> LarsPath.fit(this.XVarDrop, this.yDGR, 500, true));
    }
}

