/*
 * Decompiled with CFR 0.152.
 */
package org.cryptimeleon.craco.sig.sps.agho11;

import java.util.Arrays;
import java.util.Objects;
import java.util.stream.IntStream;
import org.cryptimeleon.craco.common.plaintexts.GroupElementPlainText;
import org.cryptimeleon.craco.common.plaintexts.MessageBlock;
import org.cryptimeleon.craco.common.plaintexts.PlainText;
import org.cryptimeleon.craco.sig.MultiMessageStructurePreservingSignatureScheme;
import org.cryptimeleon.craco.sig.Signature;
import org.cryptimeleon.craco.sig.SignatureKeyPair;
import org.cryptimeleon.craco.sig.SigningKey;
import org.cryptimeleon.craco.sig.VerificationKey;
import org.cryptimeleon.craco.sig.sps.agho11.SPSAGHO11PublicParameters;
import org.cryptimeleon.craco.sig.sps.agho11.SPSAGHO11Signature;
import org.cryptimeleon.craco.sig.sps.agho11.SPSAGHO11SigningKey;
import org.cryptimeleon.craco.sig.sps.agho11.SPSAGHO11VerificationKey;
import org.cryptimeleon.math.serialization.ListRepresentation;
import org.cryptimeleon.math.serialization.Representation;
import org.cryptimeleon.math.serialization.annotations.ReprUtil;
import org.cryptimeleon.math.serialization.annotations.Represented;
import org.cryptimeleon.math.structures.Element;
import org.cryptimeleon.math.structures.cartesian.Vector;
import org.cryptimeleon.math.structures.groups.Group;
import org.cryptimeleon.math.structures.groups.GroupElement;
import org.cryptimeleon.math.structures.groups.elliptic.BilinearMap;
import org.cryptimeleon.math.structures.rings.zn.Zn;
import org.cryptimeleon.math.structures.rings.zn.Zp;

/**
 * Structure-preserving signature scheme AGHO11, signing a pair of group-element vectors under one key.
 *
 * <p>A message is a {@link MessageBlock} holding exactly two inner {@link MessageBlock}s of
 * {@link GroupElementPlainText}s:
 * <ul>
 *   <li>the first vector (M_1..M_k) holds elements of the group of {@code pp.getG1GroupGenerator()}, called g below;</li>
 *   <li>the second vector (N_1..N_l) holds elements of the group of {@code pp.getG2GroupGenerator()}, called h below.</li>
 * </ul>
 * Their lengths must equal {@code pp.getMessageLengths()}. Before signing or verifying, the first vector is padded
 * to at least 1 element and the second to at least 2 elements, using neutral elements. Key generation sizes the
 * key vectors to these padded lengths.
 *
 * <p>Keys: sk = (u_j, v, w_i, z); pk = (U_j = g^{u_j}, V = h^v, W_i = h^{w_i}, Z = h^z).
 * A signature (R, S, T) is computed as
 * <pre>
 *   R = g^r,   S = g^{z - r*v} * prod_i M_i^{-w_i},   T = (h * prod_j N_j^{-u_j})^{1/r}
 * </pre>
 * and is accepted iff both of the following hold:
 * <pre>
 *   e(R, V) * e(S, h) * prod_i e(M_i, W_i) == e(g, Z)
 *   e(R, T) * prod_j e(U_j, N_j)           == e(g, h)
 * </pre>
 */
public class SPSAGHO11SignatureScheme
implements MultiMessageStructurePreservingSignatureScheme {

    /** Public parameters: groups, generators, bilinear map, Zp and the expected message vector lengths. */
    @Represented
    protected SPSAGHO11PublicParameters pp;

    /** Creates an instance without public parameters; {@code pp} must be set before use. */
    protected SPSAGHO11SignatureScheme() {
    }

    /**
     * Creates the scheme for the given public parameters.
     *
     * @param pp the public parameters
     */
    public SPSAGHO11SignatureScheme(SPSAGHO11PublicParameters pp) {
        this.pp = pp;
    }

    /**
     * Restores the scheme from its representation.
     *
     * @param repr representation as produced by {@link #getRepresentation()}
     */
    public SPSAGHO11SignatureScheme(Representation repr) {
        new ReprUtil(this).deserialize(repr);
    }

    /**
     * Generates a key pair for messages with {@code numberOfMessages} elements in the first vector
     * and an empty second vector. Equivalent to {@code generateKeyPair(numberOfMessages, 0)}.
     *
     * @param numberOfMessages length of the first (G1) message vector
     * @return the generated key pair
     * @throws IllegalArgumentException if the public parameters do not expect the lengths {@code {numberOfMessages, 0}}
     */
    public SignatureKeyPair<SPSAGHO11VerificationKey, SPSAGHO11SigningKey> generateKeyPair(int numberOfMessages) {
        return generateKeyPair(numberOfMessages, 0);
    }

    /**
     * Generates a key pair for messages made of two vectors with the given lengths.
     *
     * @param messageBlockLengths exactly two lengths; the first (G1) vector's length, then the second (G2) vector's
     * @return the generated key pair
     * @throws IllegalArgumentException if not exactly two lengths are given, or if they differ from
     *         {@code pp.getMessageLengths()}
     */
    public SignatureKeyPair<SPSAGHO11VerificationKey, SPSAGHO11SigningKey> generateKeyPair(int... messageBlockLengths) {
        Zp zp = pp.getZp();
        if (messageBlockLengths.length != 2) {
            throw new IllegalArgumentException(String.format(
                    "The signature scheme AGHO11 expects to sign elements on two vectors G^M, H^N, but received: %d vectors",
                    messageBlockLengths.length));
        }
        for (int i = 0; i < messageBlockLengths.length; ++i) {
            if (messageBlockLengths[i] != pp.getMessageLengths()[i]) {
                throw new IllegalArgumentException(String.format(
                        "The given message length of the %s vector does not match the public parameters expected: %d, but was: %d",
                        i == 0 ? "first" : "second", pp.getMessageLengths()[i], messageBlockLengths[i]));
            }
        }
        // Key vectors are sized to the padded message lengths; this must agree with padMessageIfShort.
        int firstMsgVectorLength = Math.max(1, messageBlockLengths[0]);
        int secondMsgVectorLength = Math.max(2, messageBlockLengths[1]);

        // u_j pair with the second (G2) vector, w_i with the first (G1) vector.
        Zp.ZpElement[] exponentsU = IntStream.range(0, secondMsgVectorLength)
                .mapToObj(x -> zp.getUniformlyRandomNonzeroElement())
                .toArray(Zp.ZpElement[]::new);
        Zp.ZpElement[] exponentsW = IntStream.range(0, firstMsgVectorLength)
                .mapToObj(x -> zp.getUniformlyRandomNonzeroElement())
                .toArray(Zp.ZpElement[]::new);
        Zp.ZpElement exponentV = zp.getUniformlyRandomNonzeroElement();
        Zp.ZpElement exponentZ = zp.getUniformlyRandomNonzeroElement();

        GroupElement g = pp.getG1GroupGenerator();
        GroupElement h = pp.getG2GroupGenerator();
        GroupElement[] groupElementsU = Arrays.stream(exponentsU)
                .map(x -> g.pow(x).compute())
                .toArray(GroupElement[]::new);
        GroupElement[] groupElementsW = Arrays.stream(exponentsW)
                .map(x -> h.pow(x).compute())
                .toArray(GroupElement[]::new);

        SPSAGHO11VerificationKey pk = new SPSAGHO11VerificationKey(
                groupElementsU, h.pow(exponentV).compute(), groupElementsW, h.pow(exponentZ).compute());
        SPSAGHO11SigningKey sk = new SPSAGHO11SigningKey(exponentsU, exponentV, exponentsW, exponentZ);
        return new SignatureKeyPair<>(pk, sk);
    }

    /**
     * Signs a two-vector message.
     *
     * @param plainText a {@link MessageBlock} of two inner blocks, as described on the class
     * @param secretKey an {@link SPSAGHO11SigningKey}
     * @return the signature (R, S, T)
     * @throws IllegalArgumentException if the message is malformed or the key has the wrong type
     */
    @Override
    public Signature sign(PlainText plainText, SigningKey secretKey) {
        doMessageChecks(plainText);
        if (secretKey.getClass() != SPSAGHO11SigningKey.class) {
            throw new IllegalArgumentException("Not a valid signing key for this scheme");
        }
        MessageBlock containerBlock = padMessageIfShort((MessageBlock) plainText);
        MessageBlock messageGElements = (MessageBlock) containerBlock.get(0);
        MessageBlock messageHElements = (MessageBlock) containerBlock.get(1);
        SPSAGHO11SigningKey sk = (SPSAGHO11SigningKey) secretKey;
        GroupElement g = pp.getG1GroupGenerator();
        GroupElement h = pp.getG2GroupGenerator();

        Zp.ZpElement r = pp.getZp().getUniformlyRandomNonzeroElement();

        // R = g^r
        GroupElement sigma1R = g.pow(r).compute();

        // S = g^(z - r*v) * prod_i M_i^(-w_i)
        GroupElement sigma2S = g.pow(sk.getExponentZ().sub(r.mul(sk.getExponentV())));
        for (int i = 0; i < messageGElements.length(); ++i) {
            GroupElement m = ((GroupElementPlainText) messageGElements.get(i)).get();
            sigma2S = sigma2S.op(m.pow(sk.getExponentsW()[i].neg()));
        }
        sigma2S = sigma2S.compute();

        // T = (h * prod_j N_j^(-u_j))^(1/r)
        GroupElement sigma3T = h;
        for (int j = 0; j < messageHElements.length(); ++j) {
            GroupElement n = ((GroupElementPlainText) messageHElements.get(j)).get();
            sigma3T = sigma3T.op(n.pow(sk.getExponentsU()[j].neg()));
        }
        sigma3T = sigma3T.pow(r.inv()).compute();

        return new SPSAGHO11Signature(sigma1R, sigma2S, sigma3T);
    }

    /**
     * Verifies a signature on a two-vector message by checking both pairing-product equations.
     *
     * @param plainText a {@link MessageBlock} of two inner blocks, as described on the class
     * @param signature an {@link SPSAGHO11Signature}
     * @param publicKey an {@link SPSAGHO11VerificationKey}
     * @return whether both equations hold
     * @throws IllegalArgumentException if the message is malformed, or the signature or key has the wrong type
     */
    @Override
    public Boolean verify(PlainText plainText, Signature signature, VerificationKey publicKey) {
        doMessageChecks(plainText);
        if (signature.getClass() != SPSAGHO11Signature.class) {
            throw new IllegalArgumentException("Not a valid signature for this scheme");
        }
        if (publicKey.getClass() != SPSAGHO11VerificationKey.class) {
            throw new IllegalArgumentException("Not a valid verification key for this scheme");
        }
        MessageBlock containerBlock = padMessageIfShort((MessageBlock) plainText);
        MessageBlock messageGElements = (MessageBlock) containerBlock.get(0);
        MessageBlock messageHElements = (MessageBlock) containerBlock.get(1);
        SPSAGHO11Signature sigma = (SPSAGHO11Signature) signature;
        SPSAGHO11VerificationKey pk = (SPSAGHO11VerificationKey) publicKey;
        return evaluateFirstPPE(messageGElements, sigma, pk) && evaluateSecondPPE(messageHElements, sigma, pk);
    }

    /** Checks e(R, V) * e(S, h) * prod_i e(M_i, W_i) == e(g, Z) for the (padded) first message vector. */
    private boolean evaluateFirstPPE(MessageBlock messageBlock, SPSAGHO11Signature sigma, SPSAGHO11VerificationKey pk) {
        BilinearMap bMap = pp.getBilinearMap();
        GroupElement lhs = bMap.apply(sigma.getGroup1ElementSigma1R(), pk.getGroup2ElementV())
                .op(bMap.apply(sigma.getGroup1ElementSigma2S(), pp.getG2GroupGenerator()));
        for (int i = 0; i < messageBlock.length(); ++i) {
            GroupElement m = ((GroupElementPlainText) messageBlock.get(i)).get();
            lhs = lhs.op(bMap.apply(m, pk.getGroup2ElementsW()[i]));
        }
        lhs = lhs.compute();
        GroupElement rhs = bMap.apply(pp.getG1GroupGenerator(), pk.getGroup2ElementZ()).compute();
        return lhs.equals(rhs);
    }

    /** Checks e(R, T) * prod_j e(U_j, N_j) == e(g, h) for the (padded) second message vector. */
    private boolean evaluateSecondPPE(MessageBlock messageBlock, SPSAGHO11Signature sigma, SPSAGHO11VerificationKey pk) {
        BilinearMap bMap = pp.getBilinearMap();
        GroupElement lhs = bMap.apply(sigma.getGroup1ElementSigma1R(), sigma.getGroup2ElementSigma3T());
        for (int j = 0; j < messageBlock.length(); ++j) {
            GroupElement n = ((GroupElementPlainText) messageBlock.get(j)).get();
            lhs = lhs.op(bMap.apply(pk.getGroup1ElementsU()[j], n));
        }
        lhs = lhs.compute();
        GroupElement rhs = bMap.apply(pp.getG1GroupGenerator(), pp.getG2GroupGenerator()).compute();
        return lhs.equals(rhs);
    }

    /** Restores a two-vector message: inner block 0 from the G1 group, inner block 1 from the G2 group. */
    @Override
    public PlainText restorePlainText(Representation repr) {
        ListRepresentation list = (ListRepresentation) repr;
        Group g1Group = pp.getG1GroupGenerator().getStructure();
        Group g2Group = pp.getG2GroupGenerator().getStructure();
        MessageBlock g1 = new MessageBlock(list.get(0), r -> new GroupElementPlainText(r, g1Group));
        MessageBlock g2 = new MessageBlock(list.get(1), r -> new GroupElementPlainText(r, g2Group));
        return new MessageBlock(g1, g2);
    }

    @Override
    public Signature restoreSignature(Representation repr) {
        return new SPSAGHO11Signature(repr, pp.getG1GroupGenerator().getStructure(), pp.getG2GroupGenerator().getStructure());
    }

    @Override
    public SigningKey restoreSigningKey(Representation repr) {
        return new SPSAGHO11SigningKey(repr, pp.getZp());
    }

    @Override
    public VerificationKey restoreVerificationKey(Representation repr) {
        return new SPSAGHO11VerificationKey(pp.getG1GroupGenerator().getStructure(), pp.getG2GroupGenerator().getStructure(), repr);
    }

    /**
     * Injectively maps {@code bytes} to a message accepted by this scheme's public parameters.
     *
     * @throws NullPointerException if the public parameters are not set
     * @throws IllegalStateException if the public parameters expect two empty message vectors
     */
    @Override
    public MessageBlock mapToPlaintext(byte[] bytes, VerificationKey pk) {
        if (pp == null) {
            throw new NullPointerException("Number of messages is stored in public parameters but they are not set");
        }
        return mapToPlaintext(bytes, pp.getMessageLengths()[0], pp.getMessageLengths()[1]);
    }

    /**
     * Injectively maps {@code bytes} to a message accepted by this scheme's public parameters.
     *
     * @throws NullPointerException if the public parameters are not set
     * @throws IllegalStateException if the public parameters expect two empty message vectors
     */
    @Override
    public MessageBlock mapToPlaintext(byte[] bytes, SigningKey sk) {
        if (pp == null) {
            throw new NullPointerException("Number of messages is stored in public parameters but they are not set");
        }
        return mapToPlaintext(bytes, pp.getMessageLengths()[0], pp.getMessageLengths()[1]);
    }

    /**
     * Encodes {@code bytes} as a message whose vector lengths are exactly {@code firstLength} and
     * {@code secondLength}, so that it passes {@link #doMessageChecks}.
     *
     * <p>The bytes are mapped to an exponent via {@code Zp.injectiveValueOf}. The generator raised to that exponent
     * is placed in the first slot of the first (G1) vector. If the first vector has no slots, it is placed in the
     * first slot of the second (G2) vector instead. All other slots hold the respective group generator.
     */
    private MessageBlock mapToPlaintext(byte[] bytes, int firstLength, int secondLength) {
        if (firstLength <= 0 && secondLength <= 0) {
            throw new IllegalStateException(
                    "The public parameters expect two empty message vectors, so no bytes can be encoded");
        }
        GroupElement g = pp.getG1GroupGenerator();
        GroupElement h = pp.getG2GroupGenerator();
        Zn.ZnElement exponent = pp.getZp().injectiveValueOf(bytes);

        PlainText[] firstBlock = new PlainText[Math.max(0, firstLength)];
        for (int i = 0; i < firstBlock.length; ++i) {
            firstBlock[i] = new GroupElementPlainText(g);
        }
        PlainText[] secondBlock = new PlainText[Math.max(0, secondLength)];
        for (int j = 0; j < secondBlock.length; ++j) {
            secondBlock[j] = new GroupElementPlainText(h);
        }

        if (firstBlock.length > 0) {
            firstBlock[0] = new GroupElementPlainText(g.pow(exponent));
        } else {
            secondBlock[0] = new GroupElementPlainText(h.pow(exponent));
        }
        return new MessageBlock(new MessageBlock(firstBlock), new MessageBlock(secondBlock));
    }

    /**
     * Returns the maximum number of bytes {@code mapToPlaintext} accepts: (bit length of the G1 group size - 1) / 8.
     *
     * @throws NullPointerException if the public parameters are not set
     */
    @Override
    public int getMaxNumberOfBytesForMapToPlaintext() {
        if (pp == null) {
            throw new NullPointerException("Number of messages is stored in public parameters but they are not set");
        }
        return (pp.getG1GroupGenerator().getStructure().size().bitLength() - 1) / 8;
    }

    public Representation getRepresentation() {
        return ReprUtil.serialize(this);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        SPSAGHO11SignatureScheme that = (SPSAGHO11SignatureScheme) o;
        return Objects.equals(pp, that.pp);
    }

    @Override
    public int hashCode() {
        return Objects.hash(pp);
    }

    /**
     * Validates the shape of a message.
     * <ul>
     *   <li>It must be a {@link MessageBlock} of exactly two inner {@link MessageBlock}s.</li>
     *   <li>Their lengths must equal {@code pp.getMessageLengths()}.</li>
     *   <li>Every element must be a {@link GroupElementPlainText} from the G1 group (first block) or the G2 group
     *       (second block).</li>
     * </ul>
     *
     * @throws IllegalArgumentException on the first violation found
     */
    private void doMessageChecks(PlainText plainText) throws IllegalArgumentException {
        if (!(plainText instanceof MessageBlock)) {
            throw new IllegalArgumentException(String.format(
                    "The plainText provided must be a MessageBlock, but was: %s", plainText.getClass().toString()));
        }
        MessageBlock msgBlock = (MessageBlock) plainText;
        if (msgBlock.length() != 2) {
            throw new IllegalArgumentException(String.format(
                    "The message provided must contain 2 inner MessageBlocks, but had: %d", msgBlock.length()));
        }
        for (int i = 0; i < 2; ++i) {
            if (!(msgBlock.get(i) instanceof MessageBlock)) {
                throw new IllegalArgumentException(String.format(
                        "The message provided must contain 2 inner MessageBlocks, but element %d was not an instance of MessageBlock", i));
            }
        }
        for (int blockID = 0; blockID < 2; ++blockID) {
            MessageBlock innerBlock = (MessageBlock) msgBlock.get(blockID);
            int expectedLength = pp.getMessageLengths()[blockID];
            Group expectedGroup = blockID == 0
                    ? pp.getG1GroupGenerator().getStructure()
                    : pp.getG2GroupGenerator().getStructure();
            if (innerBlock.length() != expectedLength) {
                throw new IllegalArgumentException(String.format(
                        "length of %s message vector does not match public parameters expected %d, but was: %d",
                        blockID == 0 ? "first" : "second", expectedLength, innerBlock.length()));
            }
            for (int i = 0; i < innerBlock.length(); ++i) {
                PlainText element = innerBlock.get(i);
                if (!(element instanceof GroupElementPlainText)) {
                    throw new IllegalArgumentException(String.format(
                            "The inner message blocks may only contain GroupElementPlainTexts, but element %d of inner block %d was of type: %s",
                            i, blockID, element.getClass().toString()));
                }
                GroupElement groupElement = ((GroupElementPlainText) element).get();
                if (!groupElement.getStructure().equals(expectedGroup)) {
                    throw new IllegalArgumentException(String.format(
                            "Element %d of inner message block %d does not match the expected group.  expected: %s, but was: %s",
                            i, blockID, expectedGroup.toString(), groupElement.getStructure().toString()));
                }
            }
        }
    }

    /**
     * Pads the first inner block to length 1 and the second to length 2 with neutral elements.
     * Padding applies only when the public parameters configure fewer slots than that; the key vectors
     * are generated with these padded lengths.
     */
    private MessageBlock padMessageIfShort(MessageBlock messageBlock) {
        MessageBlock firstInnerBlock = (MessageBlock) messageBlock.get(0);
        MessageBlock secondInnerBlock = (MessageBlock) messageBlock.get(1);
        if (pp.getMessageLengths()[0] < 1) {
            GroupElement g1Neutral = pp.getG1GroupGenerator().getStructure().getNeutralElement();
            firstInnerBlock = new MessageBlock(firstInnerBlock.pad(new GroupElementPlainText(g1Neutral), 1));
        }
        if (pp.getMessageLengths()[1] < 2) {
            GroupElement g2Neutral = pp.getG2GroupGenerator().getStructure().getNeutralElement();
            secondInnerBlock = new MessageBlock(secondInnerBlock.pad(new GroupElementPlainText(g2Neutral), 2));
        }
        return new MessageBlock(firstInnerBlock, secondInnerBlock);
    }
}

