/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.Arrays;
import java.util.Objects;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.kie.kogito.explainability.utils.LarsPath;
import org.kie.kogito.explainability.utils.LarsPathResults;
import org.kie.kogito.explainability.utils.LassoLarsICResults;
import org.kie.kogito.explainability.utils.MatrixUtilsExtensions;

/**
 * Lasso-LARS model selection by information criterion.
 *
 * <p>The data are centered, a LARS path is computed on the centered data, and the path step that
 * minimizes the chosen information criterion (AIC or BIC) is returned together with its alpha and
 * the intercept recovered from the column means.
 *
 * <p>Static utility class; it cannot be instantiated.
 */
public final class LassoLarsIC {
    private LassoLarsIC() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Fits the model with a default iteration cap of {@code 200 * X.getColumnDimension()}.
     *
     * @param X feature matrix, one row per sample
     * @param y target values, one entry per row of {@code X}
     * @param c information criterion used to select the best step on the path
     * @return the selected coefficients, alpha and intercept
     * @throws NullPointerException if any argument is {@code null}
     * @throws IllegalArgumentException if {@code y.getDimension() != X.getRowDimension()}
     */
    public static LassoLarsICResults fit(RealMatrix X, RealVector y, Criterion c) {
        Objects.requireNonNull(X, "X");
        return fit(X, y, c, X.getColumnDimension() * 200);
    }

    /**
     * Fits the model and selects the path step minimizing
     * {@code nSamples * mse_k / (sigma2 + eps) + penalty * dof_k}.
     *
     * <p>In this formula:
     * <ul>
     *   <li>{@code mse_k} is the mean squared residual of step {@code k} on the centered data.</li>
     *   <li>{@code sigma2} is {@code MatrixUtilsExtensions.variance} of the centered targets.</li>
     *   <li>{@code eps} is double-precision machine epsilon.</li>
     *   <li>{@code dof_k} is the number of coefficients at step {@code k} whose magnitude exceeds
     *       single-precision machine epsilon.</li>
     *   <li>{@code penalty} is 2 for AIC and {@code ln(nSamples)} for BIC.</li>
     * </ul>
     *
     * @param X feature matrix, one row per sample
     * @param y target values, one entry per row of {@code X}
     * @param c information criterion used to select the best step on the path
     * @param maxIterations iteration cap forwarded to {@link LarsPath}
     * @return the selected coefficients, alpha and intercept
     * @throws NullPointerException if any object argument is {@code null}
     * @throws IllegalArgumentException if {@code y.getDimension() != X.getRowDimension()}
     */
    public static LassoLarsICResults fit(RealMatrix X, RealVector y, Criterion c, int maxIterations) {
        Objects.requireNonNull(X, "X");
        Objects.requireNonNull(y, "y");
        Objects.requireNonNull(c, "c");
        int nSamples = X.getRowDimension();
        if (y.getDimension() != nSamples) {
            throw new IllegalArgumentException(String.format(
                    "X has %d rows but y has %d entries", nSamples, y.getDimension()));
        }

        // Coefficients with |coef| <= zeroTolerance are not counted as degrees of freedom.
        // NOTE(review): this is single-precision machine epsilon although the coefficients are
        // doubles; presumably a deliberate noise tolerance - confirm before tightening.
        double zeroTolerance = Math.ulp(1.0f);
        // Keeps the criterion finite when the centered targets have zero variance.
        double varianceGuard = Math.ulp(1.0);

        // Center X (per column) and y so the path is fitted without an intercept term;
        // the intercept is reconstructed from the means at the end.
        RealVector xMean = MatrixUtilsExtensions.rowSum(X).mapDivide((double) nSamples);
        double yMean = Arrays.stream(y.toArray()).sum() / (double) nSamples;
        RealMatrix xCenter = MatrixUtilsExtensions.vectorDifference(X, xMean, MatrixUtilsExtensions.Axis.ROW);
        RealVector yCenter = y.mapSubtract(yMean);

        // The final flag is passed straight to LarsPath (presumably selecting the lasso variant).
        LarsPathResults lpResults = LarsPath.fit(xCenter, yCenter, maxIterations, true);
        // Column k holds the coefficients paired with lpResults.getAlphas().getEntry(k).
        RealMatrix coefs = lpResults.getCoefs();

        RealMatrix residuals = MatrixUtilsExtensions.vectorDifference(
                MatrixUtilsExtensions.matrixDot(xCenter, coefs), yCenter, MatrixUtilsExtensions.Axis.COLUMN);
        RealVector mse = MatrixUtilsExtensions.rowSquareSum(residuals)
                .mapDivide((double) residuals.getRowDimension());
        double sigma2 = MatrixUtilsExtensions.variance(yCenter);
        RealVector dof = degreesOfFreedom(coefs, zeroTolerance);

        RealVector criterion = mse.mapMultiply((double) nSamples)
                .mapDivide(sigma2 + varianceGuard)
                .add(dof.mapMultiply(penalty(c, nSamples)));
        int best = criterion.getMinIndex();
        RealVector bestCoef = coefs.getColumnVector(best);
        double bestAlpha = lpResults.getAlphas().getEntry(best);
        // Undo the centering: the intercept is yMean minus the contribution of the feature means.
        double intercept = yMean - xMean.dotProduct(bestCoef);
        return new LassoLarsICResults(bestCoef, bestAlpha, intercept);
    }

    /** Returns the per-degree-of-freedom penalty for the given criterion. */
    private static double penalty(Criterion c, int nSamples) {
        switch (c) {
            case AIC:
                return 2.0;
            case BIC:
                return Math.log(nSamples);
            default:
                throw new IllegalArgumentException("Unsupported criterion: " + c);
        }
    }

    /** Counts, for each column of {@code coefs}, the entries whose magnitude exceeds {@code tolerance}. */
    private static RealVector degreesOfFreedom(RealMatrix coefs, double tolerance) {
        int nSteps = coefs.getColumnDimension();
        RealVector dof = MatrixUtils.createRealVector(new double[nSteps]);
        for (int k = 0; k < nSteps; k++) {
            long active = Arrays.stream(coefs.getColumn(k))
                    .filter(v -> Math.abs(v) > tolerance)
                    .count();
            dof.setEntry(k, (double) active);
        }
        return dof;
    }

    /** Information criterion used to choose the best step on the LARS path. */
    public enum Criterion {
        /** Akaike information criterion: penalty of 2 per degree of freedom. */
        AIC,
        /** Bayesian information criterion: penalty of ln(number of samples) per degree of freedom. */
        BIC
    }
}

