package org.cryptimeleon.craco.protocols.arguments.sigma.schnorr.setmembership;

import org.cryptimeleon.craco.protocols.arguments.sigma.schnorr.DelegateFragment;
import org.cryptimeleon.craco.protocols.arguments.sigma.schnorr.variables.SchnorrVariableAssignment;
import org.cryptimeleon.craco.protocols.arguments.sigma.schnorr.variables.SchnorrZnVariable;
import org.cryptimeleon.math.expressions.exponent.ExponentExpr;
import org.cryptimeleon.math.structures.groups.elliptic.BilinearGroup;
import org.cryptimeleon.math.structures.rings.zn.Zn;

import java.math.BigInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * A fragment for the statement {@code lowerBound <= member <= upperBound}.
 * <p>
 * The statement is split into two {@link SmallerThanPowerFragment} subprotocols, one for
 * {@code member - lowerBound} and one for {@code upperBound - member}, both using the same
 * {@code base} (the size of the public parameters' set) and the same {@code power}. {@code power}
 * is the smallest exponent such that {@code base^power > upperBound - lowerBound}.
 */
public class TwoSidedRangeProof extends DelegateFragment {
    private final ExponentExpr member;
    private final Zn.ZnElement lowerBound, upperBound;
    private final SetMembershipPublicParameters pp;
    /** Digit base: the number of signatures contained in {@link #pp}. */
    private final int base;
    /** Smallest exponent with {@code base^power > upperBound - lowerBound}. */
    private final int power;

    /**
     * Instantiates the proof for member in [lowerBound, upperBound] (inclusive).
     * @param member an expression whose value shall be in the given interval. In the easiest case, this is a {@link SchnorrZnVariable},
     *               but it can be any affine linear combination of {@link SchnorrZnVariable}s
     * @param lowerBound lower bound (inclusive)
     * @param upperBound upper bound (inclusive)
     * @param pp honestly generated public parameters for a set membership proof for {0, ..., base - 1} for an arbitrary integer base (large base means faster protocol).
     *           Can be generated by {@link TwoSidedRangeProof#generatePublicParameters(BilinearGroup, int)}
     * @throws IllegalArgumentException if {@code upperBound < lowerBound}, if the parameters' base is smaller than 2
     *                                  while the interval contains more than one element, or if
     *                                  {@code lowerBound + base^power} exceeds the size of the underlying {@link Zn}
     */
    public TwoSidedRangeProof(ExponentExpr member, Zn.ZnElement lowerBound, Zn.ZnElement upperBound, SetMembershipPublicParameters pp) {
        this.member = member;
        this.lowerBound = lowerBound;
        this.upperBound = upperBound;
        this.pp = pp;

        this.base = pp.signatures.size();
        BigInteger intervalSize = upperBound.asInteger().subtract(lowerBound.asInteger());
        if (intervalSize.signum() < 0)
            throw new IllegalArgumentException("upper bound must be larger than lower bound");

        this.power = minimalPower(base, intervalSize);

        // Both subproofs only show their value lies in [0, base^power); this requires that
        // [lowerBound, lowerBound + base^power) does not wrap around mod p.
        if (lowerBound.asInteger().add(BigInteger.valueOf(base).pow(power)).compareTo(pp.getZn().size()) > 0)
            throw new IllegalArgumentException("Interval is too close to the mod p overflow boundary (i.e. numbers in the interval are too large - choose smaller numbers)");
    }

    /**
     * Instantiates the proof for member in [lowerBound, upperBound] (inclusive).
     * @param member an expression whose value shall be in the given interval. In the easiest case, this is a {@link SchnorrZnVariable},
     *               but it can be any affine linear combination of {@link SchnorrZnVariable}s
     * @param lowerBound lower bound (inclusive)
     * @param upperBound upper bound (inclusive)
     * @param pp honestly generated public parameters for a set membership proof for {0, ..., base - 1} for an arbitrary integer base (large base means faster protocol).
     *           Can be generated by {@link TwoSidedRangeProof#generatePublicParameters(BilinearGroup, int)}
     * @throws IllegalArgumentException see {@link #TwoSidedRangeProof(ExponentExpr, Zn.ZnElement, Zn.ZnElement, SetMembershipPublicParameters)}
     */
    public TwoSidedRangeProof(ExponentExpr member, int lowerBound, int upperBound, SetMembershipPublicParameters pp) {
        this(member, BigInteger.valueOf(lowerBound), BigInteger.valueOf(upperBound), pp);
    }

    /**
     * Instantiates the proof for member in [lowerBound, upperBound] (inclusive).
     * @param member an expression whose value shall be in the given interval. In the easiest case, this is a {@link SchnorrZnVariable},
     *               but it can be any affine linear combination of {@link SchnorrZnVariable}s
     * @param lowerBound lower bound (inclusive); mapped into {@code pp.getZn()}
     * @param upperBound upper bound (inclusive); mapped into {@code pp.getZn()}
     * @param pp honestly generated public parameters for a set membership proof for {0, ..., base - 1} for an arbitrary integer base (large base means faster protocol).
     *           Can be generated by {@link TwoSidedRangeProof#generatePublicParameters(BilinearGroup, int)}
     * @throws IllegalArgumentException see {@link #TwoSidedRangeProof(ExponentExpr, Zn.ZnElement, Zn.ZnElement, SetMembershipPublicParameters)}
     */
    public TwoSidedRangeProof(ExponentExpr member, BigInteger lowerBound, BigInteger upperBound, SetMembershipPublicParameters pp) {
        this(member, pp.getZn().valueOf(lowerBound), pp.getZn().valueOf(upperBound), pp);
    }

    /**
     * Computes the smallest {@code power >= 0} with {@code base^power > intervalSize}.
     * @param base the digit base
     * @param intervalSize non-negative difference {@code upperBound - lowerBound}
     * @return the minimal number of base-{@code base} digits needed to cover {@code [0, intervalSize]}
     * @throws IllegalArgumentException if {@code intervalSize > 0} and {@code base < 2} (no finite power would suffice)
     */
    private static int minimalPower(int base, BigInteger intervalSize) {
        if (intervalSize.signum() == 0)
            return 0; // base^0 = 1 > 0 already covers the single-element interval
        if (base < 2)
            throw new IllegalArgumentException("public parameters must support a base of at least 2 for non-trivial intervals, but base is " + base);

        BigInteger bigBase = BigInteger.valueOf(base);
        BigInteger bound = BigInteger.ONE; // invariant: bound == base^power
        int power = 0;
        while (bound.compareTo(intervalSize) <= 0) {
            bound = bound.multiply(bigBase);
            power++;
        }
        return power;
    }

    @Override
    protected ProverSpec provideProverSpecWithNoSendFirst(SchnorrVariableAssignment externalWitnesses, ProverSpecBuilder builder) {
        //Nothing to do, no new variables set
        return builder.build();
    }

    @Override
    protected SubprotocolSpec provideSubprotocolSpec(SubprotocolSpecBuilder builder) {
        // member - lowerBound in [0, base^power) together with the overflow check gives member >= lowerBound;
        // upperBound - member in [0, base^power) gives member <= upperBound.
        builder.addSubprotocol("member-lowerBound >= 0", new SmallerThanPowerFragment(member.sub(lowerBound), base, power, pp));
        builder.addSubprotocol("upperBound-member >= 0", new SmallerThanPowerFragment(upperBound.asExponentExpression().sub(member), base, power, pp));

        return builder.build();
    }

    /**
     * Generates public parameters to use for this protocol (alternatively, these can be reused from {@link SmallerThanPowerFragment}).
     * The generated parameters support set membership for {@code {0, ..., base - 1}}.
     * @param group the group to use for this fragment
     * @param base the desired base (the bigger this is, the larger these parameters become storage-wise, but the faster and shorter the proof becomes)
     * @return set membership public parameters for the set {@code {0, ..., base - 1}}
     */
    public static SetMembershipPublicParameters generatePublicParameters(BilinearGroup group, int base) {
        return SetMembershipPublicParameters.generate(group, IntStream.range(0, base).mapToObj(BigInteger::valueOf).collect(Collectors.toSet()));
    }
}
