/*
 * Decompiled with CFR 0.152.
 */
package org.cryptimeleon.math.structures.groups.exp;

import java.math.BigInteger;
import java.util.Arrays;
import java.util.List;
import org.cryptimeleon.math.structures.groups.GroupElementImpl;
import org.cryptimeleon.math.structures.groups.GroupImpl;
import org.cryptimeleon.math.structures.groups.exp.MultiExpAlgorithm;
import org.cryptimeleon.math.structures.groups.exp.MultiExpTerm;
import org.cryptimeleon.math.structures.groups.exp.Multiexponentiation;
import org.cryptimeleon.math.structures.groups.exp.SmallExponentPrecomputation;

/**
 * Static routines for raising group elements to (possibly negative) integer powers and for
 * evaluating products of several such powers ("multi-exponentiations").
 *
 * <p>Two families are provided, each in a single-base and an interleaved multi-base variant:
 * sliding-window methods that scan the binary expansion of the exponent, and WNAF methods that
 * scan the signed digit array produced by {@link #precomputeExponentDigitsForWnaf}.
 * Small powers of the bases are looked up in {@link SmallExponentPrecomputation} tables.
 */
public class ExponentiationAlgorithms {
    /**
     * Not referenced inside this class.
     * NOTE(review): presumably a relative-cost threshold that callers consult when deciding
     * whether WNAF (which needs inverses) is worthwhile; confirm against the callers.
     */
    public static final double WNAF_INVERSION_COST_THRESHOLD = 1.5;

    /**
     * Evaluates a multi-exponentiation with an interleaved sliding-window scan.
     *
     * <p>All terms share one accumulator: it is squared once per bit position (from the most
     * significant bit of the longest exponent downwards), and every term contributes its
     * precomputed small power at the position where its current window ends. Negative exponents
     * are handled by scanning the magnitude and looking up the negated window value.
     *
     * @param multiexp   the multi-exponentiation to evaluate; its precomputation is ensured for
     *                   {@code windowSize} and {@link MultiExpAlgorithm#SLIDING} before evaluation
     * @param windowSize maximum number of exponent bits consumed per table lookup
     * @return the product of all terms, multiplied with the constant factor if one is present
     * @throws IllegalArgumentException if there are no terms and no constant factor
     */
    public static GroupElementImpl interleavingSlidingWindowMultiExp(Multiexponentiation multiexp, int windowSize) {
        List<MultiExpTerm> terms = multiexp.getTerms();
        multiexp.ensurePrecomputation(windowSize, MultiExpAlgorithm.SLIDING);
        if (terms.isEmpty()) {
            return multiexp.getConstantFactor().orElseThrow(() -> new IllegalArgumentException("Cannot compute an empty multiexp"));
        }
        int numTerms = terms.size();
        GroupElementImpl result = terms.get(0).getBase().getStructure().getNeutralElement();

        // Sign and magnitude do not change during the scan, so compute them once per term.
        // The scan length must come from the magnitude: BigInteger.bitLength() of a negative
        // power of two (e.g. -8) is one less than that of its magnitude, which would make the
        // scan skip the top bit of such an exponent.
        boolean[] exponentNegative = new boolean[numTerms];
        BigInteger[] magnitudes = new BigInteger[numTerms];
        int longestExponentBitLength = 0;
        for (int i = 0; i < numTerms; ++i) {
            BigInteger exponent = terms.get(i).getExponent();
            exponentNegative[i] = exponent.signum() < 0;
            magnitudes[i] = exponent.abs();
            longestExponentBitLength = Math.max(longestExponentBitLength, magnitudes[i].bitLength());
        }

        // windowPos[i] is the bit index at which term i's open window ends, or -1 if none is open.
        int[] windowPos = new int[numTerms];
        Arrays.fill(windowPos, -1);
        int[] windowVal = new int[numTerms];
        for (int j = longestExponentBitLength - 1; j >= 0; --j) {
            // Squaring is skipped on the very first position, where the accumulator is still neutral.
            if (j != longestExponentBitLength - 1) {
                result = result.square();
            }
            for (int i = 0; i < numTerms; ++i) {
                BigInteger magnitude = magnitudes[i];
                if (windowPos[i] == -1 && magnitude.testBit(j)) {
                    windowPos[i] = findWindowLowBit(magnitude, j, windowSize);
                    windowVal[i] = readWindowValue(magnitude, j, windowPos[i]);
                }
                if (windowPos[i] != j) {
                    continue;
                }
                int tableIndex = exponentNegative[i] ? -windowVal[i] : windowVal[i];
                result = result.op(terms.get(i).getPrecomputation().get(tableIndex));
                windowPos[i] = -1;
            }
        }
        result = multiexp.getConstantFactor().map(result::op).orElse(result);
        return result;
    }

    /**
     * Evaluates a multi-exponentiation with an interleaved WNAF scan.
     *
     * <p>Each exponent is first converted to a signed digit array; the shared accumulator is then
     * squared once per digit position and combined with the precomputed power selected by every
     * non-zero digit at that position.
     *
     * @param multiexp   the multi-exponentiation to evaluate; its precomputation is ensured for
     *                   {@code windowSize} and {@link MultiExpAlgorithm#WNAF} before evaluation
     * @param windowSize window size forwarded to {@link #precomputeExponentDigitsForWnaf}
     * @return the product of all terms, multiplied with the constant factor if one is present
     * @throws IllegalArgumentException if there are no terms and no constant factor, or if
     *                                  {@code windowSize} exceeds 30
     */
    public static GroupElementImpl interleavingWnafMultiExp(Multiexponentiation multiexp, int windowSize) {
        multiexp.ensurePrecomputation(windowSize, MultiExpAlgorithm.WNAF);
        List<MultiExpTerm> terms = multiexp.getTerms();
        if (terms.isEmpty()) {
            return multiexp.getConstantFactor().orElseThrow(() -> new IllegalArgumentException("Cannot compute an empty multiexp"));
        }
        int longestExponentDigitLength = 0;
        int[][] exponentDigits = new int[terms.size()][];
        for (int i = 0; i < terms.size(); ++i) {
            exponentDigits[i] = precomputeExponentDigitsForWnaf(terms.get(i).exponent, windowSize);
            longestExponentDigitLength = Math.max(longestExponentDigitLength, exponentDigits[i].length);
        }
        GroupImpl group = terms.get(0).base.getStructure();
        GroupElementImpl neutral = group.getNeutralElement();
        GroupElementImpl result = neutral;
        for (int j = longestExponentDigitLength - 1; j >= 0; --j) {
            // Identity comparison: squaring is only skipped while the accumulator is still the
            // very object obtained from getNeutralElement().
            if (result != neutral) {
                result = result.square();
            }
            for (int i = 0; i < exponentDigits.length; ++i) {
                if (exponentDigits[i].length <= j) {
                    continue;
                }
                int exponentDigit = exponentDigits[i][j];
                if (exponentDigit == 0) {
                    continue;
                }
                result = result.op(terms.get(i).getPrecomputation().get(exponentDigit));
            }
        }
        result = multiexp.getConstantFactor().map(result::op).orElse(result);
        return result;
    }

    /**
     * Like {@link BigInteger#testBit(int)}, but reports {@code false} for negative indices
     * instead of throwing.
     */
    private static boolean testBit(BigInteger n, int index) {
        if (index < 0) {
            return false;
        }
        return n.testBit(index);
    }

    /**
     * Finds the lowest bit index of the sliding window whose highest bit is {@code highBit}.
     * The window initially spans {@code windowSize} bits and is shrunk from below until its
     * lowest bit is set. Callers invoke this only when bit {@code highBit} is set.
     */
    private static int findWindowLowBit(BigInteger magnitude, int highBit, int windowSize) {
        int lowBit = highBit - windowSize + 1;
        while (!testBit(magnitude, lowBit)) {
            ++lowBit;
        }
        return lowBit;
    }

    /**
     * Reads bits {@code highBit} down to {@code lowBit} (both inclusive) of {@code magnitude}
     * as an unsigned integer, with {@code highBit} as the most significant bit.
     */
    private static int readWindowValue(BigInteger magnitude, int highBit, int lowBit) {
        int value = 0;
        for (int k = highBit; k >= lowBit; --k) {
            value <<= 1;
            if (testBit(magnitude, k)) {
                ++value;
            }
        }
        return value;
    }

    /**
     * Computes {@code base^k} with the binary square-and-multiply method, scanning the bits of
     * {@code k} from most to least significant.
     *
     * @param base the element to exponentiate
     * @param k    the exponent; a negative value is handled by exponentiating with {@code -k}
     *             and inverting the result
     * @return {@code base} raised to the power {@code k}
     */
    public static GroupElementImpl binSquareMultiplyExp(GroupElementImpl base, BigInteger k) {
        if (k.signum() < 0) {
            return binSquareMultiplyExp(base, k.negate()).inv();
        }
        GroupElementImpl result = base.getStructure().getNeutralElement();
        for (int i = k.bitLength() - 1; i >= 0; --i) {
            result = result.op(result);
            if (k.testBit(i)) {
                result = result.op(base);
            }
        }
        return result;
    }

    /**
     * Computes {@code base^exponent} with a sliding-window scan over the exponent's magnitude.
     *
     * <p>The window size actually used is the larger of {@code windowSize} and the size the
     * precomputation already supports for the sign of {@code exponent}; the precomputation is
     * extended to that size before the scan.
     *
     * @param base           the element to exponentiate
     * @param exponent       the exponent, may be negative
     * @param precomputation table of small powers of {@code base}; if {@code null}, a new one is
     *                       created for this call
     * @param windowSize     requested maximum number of exponent bits consumed per table lookup
     * @return {@code base} raised to the power {@code exponent}
     */
    public static GroupElementImpl slidingWindowExp(GroupElementImpl base, BigInteger exponent, SmallExponentPrecomputation precomputation, int windowSize) {
        if (precomputation == null) {
            precomputation = new SmallExponentPrecomputation(base);
        }
        // Forwarded to the precomputation; its exact meaning is defined by SmallExponentPrecomputation.
        boolean invertExisting = base.getStructure().estimateCostInvPerOp() > 1.0;
        boolean exponentNegative = exponent.signum() < 0;
        if (exponentNegative) {
            windowSize = Math.max(precomputation.getCurrentlySupportedNegativeWindowSize(), windowSize);
            precomputation.computeNegativePowers(windowSize, invertExisting);
        } else {
            windowSize = Math.max(precomputation.getCurrentlySupportedPositiveWindowSize(), windowSize);
            precomputation.compute(windowSize, invertExisting);
        }
        GroupElementImpl result = base.getStructure().getNeutralElement();
        BigInteger posExponent = exponentNegative ? exponent.negate() : exponent;
        int exponentBitlen = posExponent.bitLength();
        // Bit index at which the currently open window ends, or -1 if no window is open.
        int windowPos = -1;
        int windowVal = 0;
        for (int j = exponentBitlen - 1; j >= 0; --j) {
            // Squaring is skipped on the very first position, where the accumulator is still neutral.
            if (j != exponentBitlen - 1) {
                result = result.square();
            }
            if (windowPos == -1 && posExponent.testBit(j)) {
                windowPos = findWindowLowBit(posExponent, j, windowSize);
                windowVal = readWindowValue(posExponent, j, windowPos);
            }
            if (windowPos != j) {
                continue;
            }
            result = result.op(precomputation.get(exponentNegative ? -windowVal : windowVal));
            windowPos = -1;
        }
        return result;
    }

    /**
     * Computes {@code base^exponent} by scanning the signed digit array produced by
     * {@link #precomputeExponentDigitsForWnaf}.
     *
     * <p>If a precomputation is supplied, the window size used is the larger of
     * {@code windowSize} and the size it already supports. The table is extended via
     * {@code computeNegativePowers} when it currently supports a larger negative than positive
     * window, and via {@code compute} otherwise.
     *
     * @param base           the element to exponentiate
     * @param exponent       the exponent, may be negative
     * @param precomputation table of small powers of {@code base}; if {@code null}, a new one is
     *                       created for this call
     * @param windowSize     requested window size
     * @return {@code base} raised to the power {@code exponent}
     * @throws IllegalArgumentException if the window size used exceeds 30
     */
    public static GroupElementImpl wnafExp(GroupElementImpl base, BigInteger exponent, SmallExponentPrecomputation precomputation, int windowSize) {
        if (precomputation == null) {
            precomputation = new SmallExponentPrecomputation(base);
        } else {
            windowSize = Math.max(precomputation.getCurrentlySupportedWindowSize(), windowSize);
        }
        if (precomputation.getCurrentlySupportedNegativeWindowSize() > precomputation.getCurrentlySupportedPositiveWindowSize()) {
            precomputation.computeNegativePowers(windowSize, false);
        } else {
            precomputation.compute(windowSize, false);
        }
        int[] exponentDigits = precomputeExponentDigitsForWnaf(exponent, windowSize);
        GroupImpl group = base.getStructure();
        GroupElementImpl neutral = group.getNeutralElement();
        GroupElementImpl result = neutral;
        for (int j = exponentDigits.length - 1; j >= 0; --j) {
            // Identity comparison: squaring is only skipped while the accumulator is still the
            // very object obtained from getNeutralElement().
            if (result != neutral) {
                result = result.square();
            }
            int exponentDigit = exponentDigits[j];
            if (exponentDigit == 0) {
                continue;
            }
            result = result.op(precomputation.get(exponentDigit));
        }
        return result;
    }

    /**
     * Returns the {@code numberOfLowBits} least significant bits of {@code i}.
     *
     * <p>The mask is built with an {@code int} shift, so the result is only meaningful for
     * {@code numberOfLowBits} between 0 and 31 (inclusive).
     *
     * @param i               the value to take bits from
     * @param numberOfLowBits how many low-order bits to keep
     * @return the selected bits as a non-negative {@code int}
     */
    public static int getNLeastSignificantBits(long i, int numberOfLowBits) {
        return (int)(i & (long)((1 << numberOfLowBits) - 1));
    }

    /**
     * Converts {@code exponent} into the signed digit array scanned by the WNAF routines.
     *
     * <p>The returned array is indexed by bit position (index 0 is the least significant
     * position) and has length {@code |exponent|.bitLength() + 1}. Every non-zero digit is odd,
     * is taken from the lowest {@code windowSize + 1} bits of the remaining value, and is moved
     * into the range {@code (-2^windowSize, 2^windowSize)} by subtracting
     * {@code 2^(windowSize + 1)} when necessary. The digit is then subtracted from the remaining
     * value, so that {@code sum(result[i] * 2^i)} equals {@code exponent}. For a negative
     * exponent the digits of its magnitude are computed and each is negated.
     *
     * @param exponent   the exponent to convert, may be negative
     * @param windowSize the window size, at most 30
     * @return the digit array described above
     * @throws IllegalArgumentException if {@code windowSize > 30}
     */
    public static int[] precomputeExponentDigitsForWnaf(BigInteger exponent, int windowSize) {
        if (windowSize > 30) {
            throw new IllegalArgumentException("Cannot handle window sizes > 30");
        }
        boolean invertEverything = false;
        if (exponent.signum() < 0) {
            invertEverything = true;
            exponent = exponent.negate();
        }
        // Big-endian two's-complement bytes; consumed from the last (least significant) byte.
        byte[] c = exponent.toByteArray();
        int bitsCurrentlyLoaded = 0;
        int byteArrayIndexToLoadNext = 0;
        // Sliding buffer holding the not-yet-processed low-order part of the exponent.
        long currentValue = 0L;
        int[] result = new int[exponent.bitLength() + 1];
        int i = 0;
        while (currentValue != 0L || byteArrayIndexToLoadNext < c.length) {
            // Refill the buffer. Addition (rather than OR) lets a carry left over from
            // subtracting a negative digit propagate into the newly loaded byte.
            while (bitsCurrentlyLoaded < 32 && byteArrayIndexToLoadNext < c.length) {
                currentValue += Byte.toUnsignedLong(c[c.length - byteArrayIndexToLoadNext - 1]) << bitsCurrentlyLoaded;
                bitsCurrentlyLoaded += 8;
                ++byteArrayIndexToLoadNext;
            }
            int shiftAmount;
            if ((currentValue & 1L) == 1L) {
                int digit = getNLeastSignificantBits(currentValue, windowSize + 1);
                if (digit >= 1 << windowSize) {
                    // Note: '+' binds tighter than '<<', so this subtracts 2^(windowSize + 1).
                    digit -= 1 << windowSize + 1;
                }
                result[i] = invertEverything ? -digit : digit;
                currentValue -= (long)digit;
                shiftAmount = windowSize;
            } else {
                // Skip the run of zero bits, but never past what has actually been loaded.
                shiftAmount = Math.min(Long.numberOfTrailingZeros(currentValue), bitsCurrentlyLoaded);
            }
            i += shiftAmount;
            currentValue >>= shiftAmount;
            bitsCurrentlyLoaded -= shiftAmount;
        }
        return result;
    }
}

