/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.Precision;
import org.kie.kogito.explainability.utils.LarsPathDataCarrier;
import org.kie.kogito.explainability.utils.LarsPathResults;
import org.kie.kogito.explainability.utils.MatrixUtilsExtensions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Computes a Least Angle Regression (LAR) or LASSO-LARS coefficient path.
 *
 * <p>The step structure closely mirrors scikit-learn's {@code lars_path} with no precomputed
 * Gram matrix. The steps are:
 * <ul>
 *   <li>covariance tracking;</li>
 *   <li>Cholesky factorisation of the active-set Gram matrix;</li>
 *   <li>the equiangular direction;</li>
 *   <li>sign-change drops for LASSO.</li>
 * </ul>
 * All mutable solver state lives in a {@link LarsPathDataCarrier}, which the private helpers
 * below read and update in a fixed order driven by {@link #fit}.
 *
 * <p>Static utility class; it cannot be instantiated.
 */
public class LarsPath {
    private static final Logger LOGGER = LoggerFactory.getLogger(LarsPath.class);

    private LarsPath() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Picks the regressor with the largest absolute covariance and records the current alpha.
     *
     * <p>This sets three fields:
     * <ul>
     *   <li>{@code cIdx}: the argmax of |cov|, taken after rounding to 16 decimal places;</li>
     *   <li>{@code C_}: the signed covariance at that index;</li>
     *   <li>{@code C}: its absolute value.</li>
     * </ul>
     * It then stores {@code C / nSamples} as the alpha of the current iteration. An empty
     * covariance vector yields {@code C = 0} and {@code cIdx = 0}.
     *
     * @param lpdc shared solver state
     */
    private static void updateCovarianceTrackers(LarsPathDataCarrier lpdc) {
        if (lpdc.getCov().getDimension() > 0) {
            lpdc.setcIdx(lpdc.getCov().map(Math::abs).map(x -> Precision.round(x, 16)).getMaxIndex());
            lpdc.setC_(lpdc.getCov().getEntry(lpdc.getcIdx()));
        } else {
            lpdc.setC_(0.0);
            lpdc.setcIdx(0);
        }
        lpdc.setC(Math.abs(lpdc.getC_()));
        lpdc.getAlphas().setEntry(lpdc.getnIter(), lpdc.getC() / (double) lpdc.getnSamples());
    }

    /**
     * Flags the candidate regressor as degenerate when the last diagonal entry of the new
     * Cholesky factor is below {@code 1e-7}.
     *
     * <p>On degeneracy:
     * <ul>
     *   <li>the covariance vector is restored from its pre-shortening copy;</li>
     *   <li>the entry at position 0 (where the candidate was swapped) is zeroed and swapped
     *       back to {@code cIdx};</li>
     *   <li>{@code degenerateRegressor} is set so that callers skip the rest of the
     *       iteration.</li>
     * </ul>
     *
     * @param diag last diagonal element of the lower Cholesky factor
     * @param lpdc shared solver state
     */
    private static void checkRegressorDegeneracy(double diag, LarsPathDataCarrier lpdc) {
        if (diag < 1.0E-7) {
            LOGGER.warn("Regressors in active set degenerate. Dropping a regressor, after {} iterations. "
                    + "Reduce max_iter or increase eps parameters.", lpdc.getnIter());
            lpdc.setCov(lpdc.getCovNotShortened());
            lpdc.getCov().setEntry(0, 0.0);
            MatrixUtilsExtensions.swap(lpdc.getCov(), lpdc.getcIdx(), 0);
            lpdc.setDegenerateRegressor(true);
        } else {
            lpdc.setDegenerateRegressor(false);
        }
    }

    /**
     * Computes the Gram matrix of the active regressors from the first rows of {@code XT}.
     *
     * @param lpdc shared solver state
     * @param simulateNActiveIncrement when {@code true}, rows {@code 0..nActive} are used
     *     (including the candidate about to enter the active set); otherwise rows
     *     {@code 0..nActive-1}
     * @return {@code XT_sub * XT_sub^T}
     */
    private static RealMatrix computeGram(LarsPathDataCarrier lpdc, boolean simulateNActiveIncrement) {
        int lastRow = simulateNActiveIncrement ? lpdc.getnActive() : lpdc.getnActive() - 1;
        RealMatrix xtSubset = lpdc.getXT().getSubMatrix(0, lastRow, 0, lpdc.getnSamples() - 1);
        return MatrixUtilsExtensions.matrixDot(xtSubset, xtSubset.transpose());
    }

    /**
     * Updates the Cholesky decomposition of the active-set Gram matrix.
     *
     * <p>When the previous iteration did not flag a drop, the candidate at {@code cIdx} is
     * added to the active set. The steps are:
     * <ol>
     *   <li>record its sign;</li>
     *   <li>move it to position {@code nActive} in {@code indices} and {@code XT};</li>
     *   <li>refresh {@code X};</li>
     *   <li>remove it from {@code cov};</li>
     *   <li>factor the enlarged Gram matrix;</li>
     *   <li>if the new pivot is not degenerate, append it to {@code active} and increment
     *       {@code nActive}.</li>
     * </ol>
     *
     * <p>When a drop was flagged, the current active set is re-factored with explicit
     * relative-symmetry ({@code 1e-16}) and absolute-positivity ({@code -1e-12}) thresholds.
     *
     * @param lpdc shared solver state
     */
    private static void getCholeskyDecomposition(LarsPathDataCarrier lpdc) {
        if (lpdc.isDrop()) {
            lpdc.setDecomp(new CholeskyDecomposition(computeGram(lpdc, false), 1.0E-16, -1.0E-12));
            return;
        }
        int nActive = lpdc.getnActive();
        int cIdx = lpdc.getcIdx();
        lpdc.getSignActive().setEntry(nActive, Math.signum(lpdc.getC_()));
        // cov only covers the inactive regressors, so its index cIdx maps to column cIdx + nActive.
        int n = cIdx + nActive;
        MatrixUtilsExtensions.swap(lpdc.getCov(), cIdx, 0);
        MatrixUtilsExtensions.swap(lpdc.getIndices(), n, nActive);
        MatrixUtilsExtensions.swap(lpdc.getXT(), n, nActive);
        lpdc.setX(lpdc.getXT().transpose());
        // Keep the un-shortened vector so a degenerate candidate can be restored.
        lpdc.setCovNotShortened(lpdc.getCov().copy());
        lpdc.setCov(lpdc.getCov().getSubVector(1, lpdc.getCov().getDimension() - 1));

        CholeskyDecomposition decomp = new CholeskyDecomposition(computeGram(lpdc, true));
        lpdc.setDecomp(decomp);
        RealMatrix lower = decomp.getL();
        double diag = lower.getEntry(lower.getRowDimension() - 1, lower.getColumnDimension() - 1);
        checkRegressorDegeneracy(diag, lpdc);
        if (lpdc.isDegenerateRegressor()) {
            return;
        }
        lpdc.getActive().add(lpdc.getIndices()[nActive]);
        lpdc.setnActive(nActive + 1);
    }

    /**
     * Checks whether the current alpha has reached zero, within the equality tolerance.
     *
     * <p>If alpha overshot below {@code -tolerance}, the current coefficients are linearly
     * interpolated back to {@code alpha = 0}, using
     * {@code prev + ss * (coef - prev)} with {@code ss = prevAlpha / (prevAlpha - alpha)}.
     * Alpha is then clamped to 0. In either break case the current coefficient row is
     * written back.
     *
     * <p>NOTE(review): alphas are stored as {@code |C| / nSamples}, so the interpolation branch
     * is only reachable if a negative alpha is ever recorded.
     *
     * @param lpdc shared solver state
     * @return {@code true} if the path should stop
     */
    private static boolean minimumAlphaBreakCondition(LarsPathDataCarrier lpdc) {
        int nIter = lpdc.getnIter();
        RealVector alphas = lpdc.getAlphas();
        RealMatrix coefs = lpdc.getCoefs();
        double alpha = alphas.getEntry(nIter);
        double tolerance = lpdc.getEqualityTolerance();
        if (alpha <= tolerance) {
            RealVector coef = coefs.getRowVector(nIter);
            if (Math.abs(alpha) > tolerance) {
                if (nIter > 0) {
                    RealVector prevCoef = coefs.getRowVector(nIter - 1);
                    double prevAlpha = alphas.getEntry(nIter - 1);
                    // Fraction of the last step needed to land exactly on alpha = 0.
                    double ss = prevAlpha / (prevAlpha - alpha);
                    coef = prevCoef.add(coef.subtract(prevCoef).mapMultiply(ss));
                }
                lpdc.getAlphas().setEntry(nIter, 0.0);
            }
            lpdc.getCoefs().setRowVector(nIter, coef);
            return true;
        }
        return false;
    }

    /**
     * Checks the iteration budget and whether every feature is already active.
     *
     * @param lpdc shared solver state
     * @return {@code true} if {@code nIter >= maxIterations} or {@code nActive >= nFeatures}
     */
    private static boolean maximumIterationBreakCondition(LarsPathDataCarrier lpdc) {
        return lpdc.getnIter() >= lpdc.getMaxIterations() || lpdc.getnActive() >= lpdc.getnFeatures();
    }

    /**
     * Detects an increase of alpha compared with the previous iteration, which stops the path.
     * This check is only used for LASSO by {@link #fit}.
     *
     * @param lpdc shared solver state
     * @return {@code true} (after logging a warning) if {@code alpha[nIter - 1] < alpha[nIter]}
     */
    private static boolean earlyStoppingBreakCondition(LarsPathDataCarrier lpdc) {
        RealVector alphas = lpdc.getAlphas();
        int nIter = lpdc.getnIter();
        if (nIter > 0 && alphas.getEntry(nIter - 1) < alphas.getEntry(nIter)) {
            if (LOGGER.isWarnEnabled()) {
                LOGGER.warn(String.format("Early stopping the lars path, as the residues are small and the current value of alpha is no longer well controlled. %d iterations, alpha=%.3f, previous alpha=%.3f, with an active set of %d regressors.", nIter, alphas.getEntry(nIter), alphas.getEntry(nIter - 1), lpdc.getnActive()));
            }
            return true;
        }
        return false;
    }

    /**
     * Solves {@code G_A w = s_A} with the current decomposition and normalises the result.
     *
     * <p>Here {@code s_A} holds the active signs. The normalisation factor is
     * {@code 1 / sqrt(sum(w * s_A))}, and {@code w} is scaled by it. A single zero solution
     * is replaced by {@code [1]} with a factor of 1.
     *
     * @param lpdc shared solver state
     */
    private static void getNormalizedLeastSquares(LarsPathDataCarrier lpdc) {
        RealVector signActiveSubset = lpdc.getSignActive().getSubVector(0, lpdc.getnActive());
        RealVector leastSquares = lpdc.getDecomp().getSolver().solve(signActiveSubset);
        double normalizationFactor;
        if (leastSquares.getDimension() == 1 && leastSquares.getEntry(0) == 0.0) {
            leastSquares.setEntry(0, 1.0);
            normalizationFactor = 1.0;
        } else {
            normalizationFactor = 1.0 / Math.sqrt(Arrays.stream(leastSquares.ebeMultiply(signActiveSubset).toArray()).sum());
            leastSquares.mapMultiplyToSelf(normalizationFactor);
        }
        lpdc.setLeastSquares(leastSquares);
        lpdc.setNormalizationFactor(normalizationFactor);
    }

    /**
     * Computes the equiangular direction, its correlation with the inactive regressors, and
     * the step length {@code gamma}.
     *
     * <p>{@code gamma} is the smallest positive value among three candidates:
     * <ul>
     *   <li>{@code (C - cov) / (A - corr + tiny)};</li>
     *   <li>{@code (C + cov) / (A + corr + tiny)};</li>
     *   <li>{@code C / A}, where {@code A} is the normalisation factor.</li>
     * </ul>
     * When every feature is active, the correlation vector is empty and
     * {@code gamma = C / A}.
     *
     * @param lpdc shared solver state
     */
    private static void getCorrelationDirection(LarsPathDataCarrier lpdc) {
        int nActive = lpdc.getnActive();
        int nFeatures = lpdc.getnFeatures();
        int nSamples = lpdc.getnSamples();
        double c = lpdc.getC();
        double normalizationFactor = lpdc.getNormalizationFactor();
        RealVector eqDir = lpdc.getXT().getSubMatrix(0, nActive - 1, 0, nSamples - 1).transpose().operate(lpdc.getLeastSquares());
        RealVector corrEqDir;
        double gamma;
        if (nActive < nFeatures) {
            corrEqDir = lpdc.getXT().getSubMatrix(nActive, nFeatures - 1, 0, nSamples - 1).operate(eqDir);
            corrEqDir.mapToSelf(x -> Precision.round(x, 16));
            RealVector cov = lpdc.getCov();
            double tiny = lpdc.getTiny();
            double g1 = MatrixUtilsExtensions.minPos(cov.map(x -> c - x).ebeDivide(corrEqDir.map(x -> normalizationFactor - x + tiny)));
            double g2 = MatrixUtilsExtensions.minPos(cov.mapAdd(c).ebeDivide(corrEqDir.mapAdd(normalizationFactor + tiny)));
            gamma = Math.min(Math.min(g1, g2), c / normalizationFactor);
        } else {
            corrEqDir = MatrixUtils.createRealVector(new double[0]);
            gamma = c / normalizationFactor;
        }
        lpdc.setCorrEqDir(corrEqDir);
        lpdc.setGamma(gamma);
    }

    /**
     * Computes, for each active regressor, the step length at which its coefficient reaches
     * zero, and records the smallest positive such step.
     *
     * <p>Each entry is {@code z[i] = -coef[active[i]] / (w[i] + equalityTolerance)}.
     * {@code zPos} is set to the smallest positive entry of {@code z}.
     *
     * @param lpdc shared solver state
     */
    private static void setZInformation(LarsPathDataCarrier lpdc) {
        RealVector coef = lpdc.getCoefs().getRowVector(lpdc.getnIter());
        double[] negatedActiveCoefs = lpdc.getActive().stream().mapToDouble(a -> -coef.getEntry(a)).toArray();
        lpdc.setZ(MatrixUtils.createRealVector(negatedActiveCoefs).ebeDivide(lpdc.getLeastSquares().mapAdd(lpdc.getEqualityTolerance())));
        lpdc.setzPos(MatrixUtilsExtensions.minPos(lpdc.getZ()));
    }

    /**
     * Detects active coefficients that would change sign before the step {@code gamma}
     * completes.
     *
     * <p>If {@code zPos < gamma}, the following happens:
     * <ul>
     *   <li>the positions in {@code z} equal to {@code zPos} (within tolerance) are stored in
     *       {@code idx};</li>
     *   <li>their signs are flipped;</li>
     *   <li>for LASSO, {@code gamma} is shortened to {@code zPos};</li>
     *   <li>{@code drop} is set.</li>
     * </ul>
     *
     * @param lpdc shared solver state
     */
    private static void getActiveIndices(LarsPathDataCarrier lpdc) {
        lpdc.setDrop(false);
        double zPos = lpdc.getzPos();
        RealVector z = lpdc.getZ();
        if (zPos < lpdc.getGamma()) {
            lpdc.setIdx(IntStream.range(0, z.getDimension())
                    .filter(i -> Math.abs(z.getEntry(i) - zPos) < lpdc.getEqualityTolerance())
                    .boxed()
                    .collect(Collectors.toSet()));
            for (int i : lpdc.getIdx()) {
                lpdc.getSignActive().setEntry(i, -lpdc.getSignActive().getEntry(i));
            }
            if (lpdc.isLasso()) {
                lpdc.setGamma(zPos);
            }
            lpdc.setDrop(true);
        }
    }

    /**
     * Writes the coefficients for iteration {@code nIter} as
     * {@code prevCoef[active] + gamma * w}.
     *
     * <p>The coefficient matrix and alpha vector are grown first if the row is out of range.
     * They gain {@code 2 * max(1, maxFeatures - nActive)} additional zero rows/entries.
     *
     * @param lpdc shared solver state
     */
    private static void trackCoefficientsAndAlphas(LarsPathDataCarrier lpdc) {
        int nIter = lpdc.getnIter();
        int nActive = lpdc.getnActive();
        int nFeatures = lpdc.getnFeatures();
        int maxFeatures = lpdc.getMaxFeatures();
        List<Integer> active = lpdc.getActive();
        if (nIter >= lpdc.getCoefs().getRowDimension()) {
            int addFeatures = 2 * Math.max(1, maxFeatures - nActive);
            RealMatrix newCoefs = MatrixUtils.createRealMatrix(nIter + addFeatures, nFeatures);
            newCoefs.setSubMatrix(lpdc.getCoefs().getData(), 0, 0);
            lpdc.setCoefs(newCoefs);
            lpdc.setAlphas(lpdc.getAlphas().append(MatrixUtils.createRealVector(new double[addFeatures])));
        }
        for (int i = 0; i < active.size(); ++i) {
            int feature = active.get(i);
            double newVal = lpdc.getCoefs().getEntry(nIter - 1, feature) + lpdc.getGamma() * lpdc.getLeastSquares().getEntry(i);
            lpdc.getCoefs().setEntry(nIter, feature, newVal);
        }
    }

    /**
     * Updates the inactive covariances after a step: {@code cov -= gamma * corrEqDir}.
     * Nothing happens when {@code corrEqDir} is empty.
     *
     * @param lpdc shared solver state
     */
    private static void adjustCovarianceByCorrelationDirection(LarsPathDataCarrier lpdc) {
        if (lpdc.getCorrEqDir().getDimension() > 0) {
            lpdc.setCov(lpdc.getCov().subtract(lpdc.getCorrEqDir().mapMultiply(lpdc.getGamma())));
        }
    }

    /**
     * Removes the regressors at positions {@code idx} from the active set (LASSO only).
     *
     * <p>For each dropped position, the steps are:
     * <ol>
     *   <li>bubble its row of {@code XT} and its entry in {@code indices} to the end of the
     *       active block;</li>
     *   <li>refresh {@code X};</li>
     *   <li>prepend its covariance with the current residual {@code y - X_A coef_A} to
     *       {@code cov}.</li>
     * </ol>
     * The matching sign entries are removed, and a trailing 0 is appended to keep the sign
     * vector's length.
     *
     * @param lpdc shared solver state
     */
    private static void dropFeature(LarsPathDataCarrier lpdc) {
        lpdc.setnActive(lpdc.getnActive() - 1);
        int nActive = lpdc.getnActive();
        lpdc.setActive(IntStream.range(0, lpdc.getActive().size())
                .filter(i -> !lpdc.getIdx().contains(i))
                .map(i -> lpdc.getActive().get(i))
                .boxed()
                .collect(Collectors.toList()));
        RealMatrix coefs = lpdc.getCoefs();
        RealVector y = lpdc.getY();
        for (int ii : lpdc.getIdx()) {
            for (int i = ii; i < nActive; ++i) {
                MatrixUtilsExtensions.swap(lpdc.getXT(), i, i + 1);
                MatrixUtilsExtensions.swap(lpdc.getIndices(), i, i + 1);
            }
            if (ii < nActive) {
                // X is only read below, so one transpose after all swaps suffices.
                lpdc.setX(lpdc.getXT().transpose());
            }
            RealVector activeCoef = MatrixUtils.createRealVector(lpdc.getActive().stream()
                    .mapToDouble(a -> coefs.getEntry(lpdc.getnIter(), a))
                    .toArray());
            RealMatrix x = lpdc.getX();
            RealVector residual = y.subtract(x.getSubMatrix(0, x.getRowDimension() - 1, 0, nActive - 1).operate(activeCoef));
            double droppedCov = lpdc.getXT().getRowVector(nActive).dotProduct(residual);
            lpdc.setCov(MatrixUtils.createRealVector(new double[] {droppedCov}).append(lpdc.getCov()));
        }
        RealVector signActive = lpdc.getSignActive();
        double[] keptSigns = IntStream.range(0, signActive.getDimension())
                .filter(i -> !lpdc.getIdx().contains(i))
                .mapToDouble(i -> signActive.getEntry(i))
                .toArray();
        lpdc.setSignActive(MatrixUtils.createRealVector(keptSigns));
        lpdc.setSignActive(lpdc.getSignActive().append(0.0));
    }

    /**
     * Trims the preallocated alpha vector and coefficient matrix to the {@code nIter + 1}
     * computed iterations and packages the results.
     *
     * @param lpdc shared solver state
     * @return results with the transposed coefficient path, alphas, active set and
     *     iteration count
     */
    private static LarsPathResults truncatedAndFormattedResults(LarsPathDataCarrier lpdc) {
        int nIter = lpdc.getnIter();
        RealVector rawAlphas = lpdc.getAlphas();
        RealMatrix rawCoefs = lpdc.getCoefs();
        if (nIter + 1 < rawAlphas.getDimension()) {
            lpdc.setAlphas(rawAlphas.getSubVector(0, Math.min(nIter + 1, rawAlphas.getDimension())));
            lpdc.setCoefs(rawCoefs.getSubMatrix(0, Math.min(nIter, rawCoefs.getRowDimension() - 1), 0, lpdc.getnFeatures() - 1));
        }
        return new LarsPathResults(lpdc.getCoefs().transpose(), lpdc.getAlphas(), lpdc.getActive(), nIter);
    }

    /**
     * Computes the LARS (or LASSO-LARS) path of {@code y} regressed on {@code X}.
     *
     * @param X design matrix whose rows are samples (row count must equal {@code y}'s
     *     dimension); presumably one column per feature
     * @param y target vector
     * @param maxIterations upper bound on the number of path iterations
     * @param lasso when {@code true}, apply the LASSO modification. Regressors whose
     *     coefficients cross zero are dropped, and the path stops early if alpha increases.
     * @return the coefficient path (one row per feature, one column per iteration), the
     *     alphas, the final active set and the iteration count
     * @throws IllegalArgumentException if the row count of {@code X} differs from the
     *     dimension of {@code y}
     */
    public static LarsPathResults fit(RealMatrix X, RealVector y, int maxIterations, boolean lasso) {
        if (X.getRowDimension() != y.getDimension()) {
            throw new IllegalArgumentException(String.format("Number of rows of X (%d) must match number of entries in y (%d)!", X.getRowDimension(), y.getDimension()));
        }
        LarsPathDataCarrier lpdc = new LarsPathDataCarrier(X, y, maxIterations, lasso);
        while (true) {
            updateCovarianceTrackers(lpdc);
            if (minimumAlphaBreakCondition(lpdc) || maximumIterationBreakCondition(lpdc)) {
                break;
            }
            getCholeskyDecomposition(lpdc);
            if (lpdc.isDegenerateRegressor()) {
                // The candidate was rejected; pick a new one without advancing the iteration.
                continue;
            }
            if (lpdc.isLasso() && earlyStoppingBreakCondition(lpdc)) {
                break;
            }
            getNormalizedLeastSquares(lpdc);
            getCorrelationDirection(lpdc);
            setZInformation(lpdc);
            getActiveIndices(lpdc);
            lpdc.incrementnIter();
            trackCoefficientsAndAlphas(lpdc);
            adjustCovarianceByCorrelationDirection(lpdc);
            if (lpdc.isLasso() && lpdc.isDrop()) {
                dropFeature(lpdc);
            }
        }
        return truncatedAndFormattedResults(lpdc);
    }
}

