/*
 * Decompiled with CFR 0.152.
 */
package org.cryptimeleon.craco.sig.sps.kpw15;

import java.util.Arrays;
import java.util.Objects;
import java.util.stream.IntStream;
import org.cryptimeleon.craco.common.plaintexts.GroupElementPlainText;
import org.cryptimeleon.craco.common.plaintexts.MessageBlock;
import org.cryptimeleon.craco.common.plaintexts.PlainText;
import org.cryptimeleon.craco.sig.MultiMessageStructurePreservingSignatureScheme;
import org.cryptimeleon.craco.sig.Signature;
import org.cryptimeleon.craco.sig.SignatureKeyPair;
import org.cryptimeleon.craco.sig.SigningKey;
import org.cryptimeleon.craco.sig.VerificationKey;
import org.cryptimeleon.craco.sig.sps.kpw15.SPSKPW15PublicParameters;
import org.cryptimeleon.craco.sig.sps.kpw15.SPSKPW15Signature;
import org.cryptimeleon.craco.sig.sps.kpw15.SPSKPW15SigningKey;
import org.cryptimeleon.craco.sig.sps.kpw15.SPSKPW15VerificationKey;
import org.cryptimeleon.math.serialization.Representation;
import org.cryptimeleon.math.serialization.annotations.ReprUtil;
import org.cryptimeleon.math.serialization.annotations.Represented;
import org.cryptimeleon.math.structures.Element;
import org.cryptimeleon.math.structures.cartesian.Vector;
import org.cryptimeleon.math.structures.groups.GroupElement;
import org.cryptimeleon.math.structures.groups.cartesian.GroupElementVector;
import org.cryptimeleon.math.structures.groups.elliptic.BilinearMap;
import org.cryptimeleon.math.structures.rings.zn.Zn;
import org.cryptimeleon.math.structures.rings.zn.Zp;

/**
 * The KPW15 structure-preserving signature scheme over a bilinear group with source groups G_1 and G_2.
 * <p>
 * Messages are {@link MessageBlock}s of exactly {@code pp.messageLength} elements of G_1. Before signing and
 * verifying, a message {@code (m_1, ..., m_n)} is padded to {@code (g_1, m_1, ..., m_n)}. A signature
 * consists of three 2-element vectors in G_1 ({@code sigma1, sigma2, sigma3}) and one element of G_2
 * ({@code sigma4}).
 * <p>
 * Notation in comments: {@code [x]_i} denotes {@code g_i^x} (applied component-wise to vectors/matrices).
 * Matrices over Z_p are stored as flat arrays in column-major order (see {@link MatrixUtility#getMatrixIndex}).
 */
public class SPSKPW15SignatureScheme implements MultiMessageStructurePreservingSignatureScheme {
    @Represented
    SPSKPW15PublicParameters pp;

    /** Creates a scheme without public parameters (needed for reflective construction / deserialization). */
    public SPSKPW15SignatureScheme() {
    }

    /**
     * Creates a scheme instance bound to the given public parameters.
     *
     * @param pp the public parameters (groups, generators, bilinear map, message length)
     */
    public SPSKPW15SignatureScheme(SPSKPW15PublicParameters pp) {
        this.pp = pp;
    }

    /**
     * Restores a scheme instance from its representation.
     *
     * @param repr a representation as produced by {@link #getRepresentation()}
     */
    public SPSKPW15SignatureScheme(Representation repr) {
        new ReprUtil(this).deserialize(repr);
    }

    /**
     * Generates a key pair for signing {@code numberOfMessages} elements of G_1.
     * <p>
     * NOTE(review): {@code numberOfMessages} is not checked against {@code pp.messageLength}, which
     * {@link #sign} and {@link #verify} rely on; the two are presumably expected to be equal.
     *
     * @param numberOfMessages number of G_1 elements per message; must be at least 1
     * @return the verification key {@code ([C0]_2, [C1]_2, [C]_2, [a]_2)} and signing key
     *         {@code (K, [P0]_1, [P1]_1, [b]_1)}
     * @throws IllegalArgumentException if {@code numberOfMessages < 1}
     */
    public SignatureKeyPair<SPSKPW15VerificationKey, SPSKPW15SigningKey> generateKeyPair(int numberOfMessages) {
        if (numberOfMessages < 1) {
            throw new IllegalArgumentException("The signature scheme KPW15 expects to sign at least 1 element");
        }
        Zp zp = this.pp.getZp();
        // A = (1, a)^T and B = (1, b)^T with a, b uniformly random.
        Zp.ZpElement[] A = {zp.getOneElement(), zp.getUniformlyRandomElement()};
        Zp.ZpElement[] B = {zp.getOneElement(), zp.getUniformlyRandomElement()};
        // K is an (n+1) x 2 matrix; K0 and K1 are 2 x 2 matrices.
        Zp.ZpElement[] K = randomVector(zp, (numberOfMessages + 1) * 2);
        Zp.ZpElement[] K0 = randomVector(zp, 4);
        Zp.ZpElement[] K1 = randomVector(zp, 4);
        // C = K A, C0 = K0 A, C1 = K1 A  (column vectors)
        Zp.ZpElement[] C = MatrixUtility.matrixMul(K, numberOfMessages + 1, 2, A, 2, 1);
        Zp.ZpElement[] C0 = MatrixUtility.matrixMul(K0, 2, 2, A, 2, 1);
        Zp.ZpElement[] C1 = MatrixUtility.matrixMul(K1, 2, 2, A, 2, 1);
        // P0 = B^T K0, P1 = B^T K1  (row vectors)
        Zp.ZpElement[] P0 = MatrixUtility.matrixMul(B, 1, 2, K0, 2, 2);
        Zp.ZpElement[] P1 = MatrixUtility.matrixMul(B, 1, 2, K1, 2, 2);

        GroupElement g1 = this.pp.getG1GroupGenerator();
        GroupElement g2 = this.pp.getG2GroupGenerator();
        SPSKPW15SigningKey sk = new SPSKPW15SigningKey(
                K, powAll(g1, P0), powAll(g1, P1), g1.pow(B[1]).compute());
        SPSKPW15VerificationKey vk = new SPSKPW15VerificationKey(
                powAll(g2, C0), powAll(g2, C1), powAll(g2, C), g2.pow(A[1]).compute());
        return new SignatureKeyPair<>(vk, sk);
    }

    /** Returns {@code count} independent, uniformly random elements of {@code zp}. */
    private static Zp.ZpElement[] randomVector(Zp zp, int count) {
        return IntStream.range(0, count)
                .mapToObj(x -> zp.getUniformlyRandomElement())
                .toArray(Zp.ZpElement[]::new);
    }

    /** Returns {@code [base^e_0, ..., base^e_(k-1)]} for the given exponents, each eagerly computed. */
    private static GroupElement[] powAll(GroupElement base, Zp.ZpElement[] exponents) {
        return Arrays.stream(exponents)
                .map(e -> base.pow(e).compute())
                .toArray(GroupElement[]::new);
    }

    /**
     * Signs a message of {@code pp.messageLength} elements of G_1.
     * <p>
     * With fresh random {@code r0, r1} and padded message {@code m' = (g_1, m_1, ..., m_n)}:
     * {@code sigma1 = m'^T K + r0 (P0 + r1 P1)}, {@code sigma2 = r0 B^T}, {@code sigma3 = r1 sigma2}
     * (all in G_1) and {@code sigma4 = [r1]_2}.
     *
     * @param plainText a {@link MessageBlock} of {@link GroupElementPlainText}s, or a single
     *                  {@link GroupElementPlainText} (treated as a block of length 1)
     * @param secretKey an {@link SPSKPW15SigningKey}
     * @return the signature
     * @throws IllegalArgumentException if the message or key is not valid for this scheme
     */
    @Override
    public Signature sign(PlainText plainText, SigningKey secretKey) {
        if (plainText instanceof GroupElementPlainText) {
            plainText = new MessageBlock(plainText);
        }
        this.doMessageChecks(plainText);
        MessageBlock messageBlock = (MessageBlock) plainText;
        if (!(secretKey instanceof SPSKPW15SigningKey)) {
            throw new IllegalArgumentException("Not a valid signing key for this scheme");
        }
        SPSKPW15SigningKey sk = (SPSKPW15SigningKey) secretKey;
        GroupElement g1 = this.pp.getG1GroupGenerator();
        Zp.ZpElement r0 = this.pp.getZp().getUniformlyRandomElement();
        Zp.ZpElement r1 = this.pp.getZp().getUniformlyRandomElement();

        // Padded message (g_1, m_1, ..., m_n); doMessageChecks guarantees every entry is a GroupElementPlainText.
        GroupElement[] message = new GroupElement[messageBlock.length() + 1];
        message[0] = g1;
        for (int i = 1; i < message.length; ++i) {
            message[i] = ((GroupElementPlainText) messageBlock.get(i - 1)).get();
        }

        // sigma1 = [m'^T K]_1 + r0 * ([P0]_1 + r1 * [P1]_1)
        GroupElement[] messageTimesK = MatrixUtility.calculateSigma1MatrixMxK(message, sk.getK());
        GroupElement[] P0 = sk.getP0();
        GroupElement[] P1 = sk.getP1();
        GroupElement[] sigma1 = new GroupElement[messageTimesK.length];
        for (int i = 0; i < sigma1.length; ++i) {
            GroupElement randomizer = P0[i].op(P1[i].pow(r1)).pow(r0);
            sigma1[i] = messageTimesK[i].op(randomizer).compute();
        }

        // sigma2 = r0 * [B^T]_1 = ([r0]_1, [r0 * b]_1); sigma3 = r1 * sigma2; sigma4 = [r1]_2
        GroupElement[] sigma2 = Arrays.stream(new GroupElement[]{g1, sk.getB()})
                .map(x -> x.pow(r0).compute())
                .toArray(GroupElement[]::new);
        GroupElement[] sigma3 = Arrays.stream(sigma2)
                .map(x -> x.pow(r1).compute())
                .toArray(GroupElement[]::new);
        GroupElement sigma4 = this.pp.getG2GroupGenerator().pow(r1).compute();
        return new SPSKPW15Signature(sigma1, sigma2, sigma3, sigma4);
    }

    /**
     * Verifies a signature by checking both pairing-product equations of the scheme.
     *
     * @param plainText a {@link MessageBlock} of {@link GroupElementPlainText}s, or a single
     *                  {@link GroupElementPlainText} (treated as a block of length 1)
     * @param signature an {@link SPSKPW15Signature}
     * @param publicKey an {@link SPSKPW15VerificationKey}
     * @return {@code true} iff both verification equations hold
     * @throws IllegalArgumentException if the message, signature or key is not valid for this scheme
     */
    @Override
    public Boolean verify(PlainText plainText, Signature signature, VerificationKey publicKey) {
        if (plainText instanceof GroupElementPlainText) {
            plainText = new MessageBlock(plainText);
        }
        this.doMessageChecks(plainText);
        if (!(signature instanceof SPSKPW15Signature)) {
            throw new IllegalArgumentException("Not a valid signature for this scheme");
        }
        if (!(publicKey instanceof SPSKPW15VerificationKey)) {
            throw new IllegalArgumentException("Not a valid verification key for this scheme");
        }
        MessageBlock paddedMessage = this.padMessage((MessageBlock) plainText);
        SPSKPW15Signature sigma = (SPSKPW15Signature) signature;
        SPSKPW15VerificationKey pk = (SPSKPW15VerificationKey) publicKey;
        GroupElementVector C0 = new GroupElementVector(pk.getC0());
        GroupElementVector C1 = new GroupElementVector(pk.getC1());
        GroupElementVector C = new GroupElementVector(pk.getC());
        GroupElementVector sigma1 = new GroupElementVector(sigma.getGroup1ElementSigma1R());
        GroupElementVector sigma2 = new GroupElementVector(sigma.getGroup1ElementSigma2S());
        GroupElementVector sigma3 = new GroupElementVector(sigma.getGroup1ElementSigma3T());
        return this.evaluateFirstPPE(sigma1, sigma2, sigma3, paddedMessage, C, C0, C1, pk.getA())
                && this.evaluateSecondPPE(sigma2, sigma.getGroup2ElementSigma4U(), sigma3);
    }

    /**
     * Checks {@code e(sigma1, [(1, a)^T]_2) = e(m', [C]_2) * e(sigma2, [C0]_2) * e(sigma3, [C1]_2)}.
     *
     * @param paddedMessage the message with {@code g_1} prepended
     * @param A             the verification-key element {@code [a]_2}
     * @throws IllegalArgumentException (from {@link MatrixUtility#matrixApplyMap}) if a vector has the wrong length
     */
    private boolean evaluateFirstPPE(GroupElementVector sigma1, GroupElementVector sigma2, GroupElementVector sigma3,
                                     MessageBlock paddedMessage, GroupElementVector C, GroupElementVector C0,
                                     GroupElementVector C1, GroupElement A) {
        BilinearMap bMap = this.pp.getBilinearMap();
        GroupElement[] message = paddedMessage.stream()
                .map(x -> ((GroupElementPlainText) x).get())
                .toArray(GroupElement[]::new);
        GroupElementVector aInG2 = new GroupElementVector(new GroupElement[]{this.pp.getG2GroupGenerator(), A});
        GroupElementVector lhs = MatrixUtility.matrixApplyMap(bMap, sigma1, 1, 2, aInG2, 2, 1).compute();
        GroupElementVector rhsMessage = MatrixUtility.matrixApplyMap(
                bMap, new GroupElementVector(message), 1, message.length, C, message.length, 1);
        GroupElementVector rhsSigma2 = MatrixUtility.matrixApplyMap(bMap, sigma2, 1, 2, C0, 2, 1);
        GroupElementVector rhsSigma3 = MatrixUtility.matrixApplyMap(bMap, sigma3, 1, 2, C1, 2, 1);
        GroupElementVector rhs = rhsMessage.op(rhsSigma2).op(rhsSigma3).compute();
        return lhs.equals(rhs);
    }

    /**
     * Checks {@code e(sigma2_i, sigma4) = e(sigma3_i, g_2)} for every component {@code i}.
     * <p>
     * The equation must hold per component. Checking only the product over all components is a strictly
     * weaker condition. An honest signature has {@code sigma3 = r1 * sigma2} and {@code sigma4 = [r1]_2},
     * so it satisfies each component equation.
     */
    private boolean evaluateSecondPPE(GroupElementVector sigma2, GroupElement sigma4, GroupElementVector sigma3) {
        if (sigma2.length() != sigma3.length()) {
            return false;
        }
        BilinearMap bMap = this.pp.getBilinearMap();
        GroupElement g2 = this.pp.getG2GroupGenerator();
        for (int i = 0; i < sigma2.length(); ++i) {
            GroupElement lhs = bMap.apply(sigma2.get(i), sigma4).compute();
            GroupElement rhs = bMap.apply(sigma3.get(i), g2).compute();
            if (!lhs.equals(rhs)) {
                return false;
            }
        }
        return true;
    }

    /** Returns a new block {@code (g_1, m_1, ..., m_n)}; the argument block is used only as input. */
    private MessageBlock padMessage(MessageBlock messageBlock) {
        return new MessageBlock((Vector<? extends PlainText>) messageBlock.prepend(new GroupElementPlainText(this.pp.getG1GroupGenerator())));
    }

    /**
     * Ensures {@code plainText} is a {@link MessageBlock} of exactly {@code pp.messageLength}
     * {@link GroupElementPlainText}s whose elements live in the same structure as the G_1 generator.
     *
     * @throws IllegalArgumentException if any of these conditions is violated
     */
    private void doMessageChecks(PlainText plainText) {
        if (!(plainText instanceof MessageBlock)) {
            throw new IllegalArgumentException("The scheme requires its messages to be GroupElements");
        }
        MessageBlock messageBlock = (MessageBlock) plainText;
        if (messageBlock.length() != this.pp.messageLength.intValue()) {
            throw new IllegalArgumentException(String.format("The scheme expected a message of length %d, but the size was: %d", this.pp.messageLength, messageBlock.length()));
        }
        for (int i = 0; i < messageBlock.length(); ++i) {
            PlainText entry = messageBlock.get(i);
            if (!(entry instanceof GroupElementPlainText)) {
                throw new IllegalArgumentException(String.format("The scheme requires its messages to be GroupElements, but element %d was of type: %s", i, entry.getClass().toString()));
            }
            GroupElementPlainText groupElementPT = (GroupElementPlainText) entry;
            if (!groupElementPT.get().getStructure().equals(this.pp.getG1GroupGenerator().getStructure())) {
                throw new IllegalArgumentException(String.format("Expected message to be in G_1, but element %d was in: %s", i, groupElementPT.get().getStructure().toString()));
            }
        }
    }

    /** Restores a message block whose entries are G_1 elements. */
    @Override
    public MessageBlock restorePlainText(Representation repr) {
        return new MessageBlock(repr, r -> new GroupElementPlainText(r, this.pp.getG1GroupGenerator().getStructure()));
    }

    /** Restores a signature using the G_1 and G_2 structures of the public parameters. */
    @Override
    public Signature restoreSignature(Representation repr) {
        return new SPSKPW15Signature(repr, this.pp.getG1GroupGenerator().getStructure(), this.pp.getG2GroupGenerator().getStructure());
    }

    /** Restores a signing key using Z_p and the G_1 structure of the public parameters. */
    @Override
    public SigningKey restoreSigningKey(Representation repr) {
        return new SPSKPW15SigningKey(repr, this.pp.getZp(), this.pp.getG1GroupGenerator().getStructure());
    }

    /** Restores a verification key using the G_1 and G_2 structures of the public parameters. */
    @Override
    public VerificationKey restoreVerificationKey(Representation repr) {
        return new SPSKPW15VerificationKey(this.pp.getG1GroupGenerator().getStructure(), this.pp.getG2GroupGenerator().getStructure(), repr);
    }

    /**
     * Maps bytes to a message block of {@code pp.messageLength} elements (the key is not used).
     *
     * @throws NullPointerException if the public parameters are not set
     */
    @Override
    public PlainText mapToPlaintext(byte[] bytes, VerificationKey pk) {
        if (this.pp == null) {
            throw new NullPointerException("Number of messages is stored in public parameters but they are not set");
        }
        return this.mapToPlaintext(bytes, this.pp.messageLength);
    }

    /**
     * Maps bytes to a message block of {@code pp.messageLength} elements (the key is not used).
     *
     * @throws NullPointerException if the public parameters are not set
     */
    @Override
    public PlainText mapToPlaintext(byte[] bytes, SigningKey sk) {
        if (this.pp == null) {
            throw new NullPointerException("Number of messages is stored in public parameters but they are not set");
        }
        return this.mapToPlaintext(bytes, this.pp.messageLength);
    }

    /**
     * Builds {@code (g_1^x, g_1, ..., g_1)} of length {@code messageLength}, where {@code x} is the
     * injective Z_p encoding of {@code bytes}.
     *
     * @throws IllegalArgumentException if {@code messageLength < 1}
     */
    private MessageBlock mapToPlaintext(byte[] bytes, int messageLength) {
        if (messageLength < 1) {
            throw new IllegalArgumentException("Message length must be at least 1, but was: " + messageLength);
        }
        GroupElement g1 = this.pp.getG1GroupGenerator();
        PlainText[] msgBlock = new GroupElementPlainText[messageLength];
        msgBlock[0] = new GroupElementPlainText(g1.pow(this.pp.getZp().injectiveValueOf(bytes)));
        for (int i = 1; i < msgBlock.length; ++i) {
            // Remaining slots are filled with the generator so the block has the required length.
            msgBlock[i] = new GroupElementPlainText(g1);
        }
        return new MessageBlock(msgBlock);
    }

    /** Returns {@code floor((bitLength(|G_1|) - 1) / 8)}, the byte budget accepted by {@code mapToPlaintext}. */
    @Override
    public int getMaxNumberOfBytesForMapToPlaintext() {
        return (this.pp.getG1GroupGenerator().getStructure().size().bitLength() - 1) / 8;
    }

    /** Serializes this scheme (its {@code @Represented} public parameters). */
    public Representation getRepresentation() {
        return ReprUtil.serialize(this);
    }

    @Override
    public int hashCode() {
        final int prime = 41;
        int result = 1;
        result = prime * result + Objects.hashCode(this.pp);
        return result;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        SPSKPW15SignatureScheme that = (SPSKPW15SignatureScheme) o;
        return Objects.equals(this.pp, that.pp);
    }

    /**
     * Matrix helpers over flat arrays/vectors stored in column-major order with 1-based
     * {@code (row, column)} addressing (see {@link #getMatrixIndex}).
     */
    static class MatrixUtility {
        MatrixUtility() {
        }

        /**
         * Multiplies the {@code rowsA x columnsA} matrix {@code A} with the {@code rowsB x columnsB} matrix
         * {@code B} over Z_p.
         *
         * @return the {@code rowsA x columnsB} product, column-major
         * @throws IllegalArgumentException if an array length does not match its dimensions, or
         *                                  {@code columnsA != rowsB}
         */
        public static Zp.ZpElement[] matrixMul(Zp.ZpElement[] A, int rowsA, int columnsA, Zp.ZpElement[] B, int rowsB, int columnsB) {
            if (A.length != rowsA * columnsA || B.length != rowsB * columnsB) {
                throw new IllegalArgumentException("The given vector's length does not match its matrix dimensions");
            }
            if (columnsA != rowsB) {
                throw new IllegalArgumentException(String.format("function is only defined for matrices where columns_A == rows_B : got %d vs. %d", columnsA, rowsB));
            }
            Zp.ZpElement[] multiplied = new Zp.ZpElement[rowsA * columnsB];
            for (int r = 1; r <= rowsA; ++r) {
                for (int c = 1; c <= columnsB; ++c) {
                    Zp.ZpElement value = B[0].getStructure().getZeroElement();
                    for (int i = 1; i <= columnsA; ++i) {
                        value = value.add(A[getMatrixIndex(rowsA, columnsA, r, i)].mul(B[getMatrixIndex(rowsB, columnsB, i, c)]));
                    }
                    multiplied[getMatrixIndex(rowsA, columnsB, r, c)] = value;
                }
            }
            return multiplied;
        }

        /**
         * "Multiplies" a G_1 matrix {@code A} with a G_2 matrix {@code B} via the pairing: entry {@code (r, c)}
         * of the result is {@code prod_i e(A[r,i], B[i,c])} in G_T.
         *
         * @return the {@code rowsA x columnsB} result, column-major
         * @throws IllegalArgumentException if a vector length does not match its dimensions, or
         *                                  {@code columnsA != rowsB}
         */
        public static GroupElementVector matrixApplyMap(BilinearMap bMap, GroupElementVector A, int rowsA, int columnsA, GroupElementVector B, int rowsB, int columnsB) {
            if (A.length() != rowsA * columnsA || B.length() != rowsB * columnsB) {
                throw new IllegalArgumentException("The given vectors length does not match its matrix dimensions");
            }
            if (columnsA != rowsB) {
                throw new IllegalArgumentException(String.format("function is only defined for matrices where columns_A == rows_B : got %d x %d", columnsA, rowsB));
            }
            GroupElement[] multiplied = new GroupElement[rowsA * columnsB];
            for (int r = 1; r <= rowsA; ++r) {
                for (int c = 1; c <= columnsB; ++c) {
                    GroupElement value = bMap.getGT().getNeutralElement();
                    for (int i = 1; i <= columnsA; ++i) {
                        value = value.op(bMap.apply(A.get(getMatrixIndex(rowsA, columnsA, r, i)), B.get(getMatrixIndex(rowsB, columnsB, i, c))));
                    }
                    value.compute();
                    multiplied[getMatrixIndex(rowsA, columnsB, r, c)] = value;
                }
            }
            return new GroupElementVector(multiplied);
        }

        /**
         * Returns the flat column-major index of the 1-based entry {@code (row, column)} of a matrix with
         * {@code rows} rows. {@code columns} is unused but kept for call-site symmetry.
         */
        public static int getMatrixIndex(int rows, int columns, int row, int column) {
            return rows * (column - 1) + (row - 1);
        }

        /**
         * Computes the row vector {@code [m^T K]_1} where {@code message = [m]_1} has length {@code l} and
         * {@code K} is an {@code l x 2} column-major matrix over Z_p.
         *
         * @return the two G_1 elements of {@code [m^T K]_1}
         */
        public static GroupElement[] calculateSigma1MatrixMxK(GroupElement[] message, Zp.ZpElement[] K) {
            final int columns = 2;
            GroupElement[] multiplied = new GroupElement[columns];
            for (int c = 1; c <= columns; ++c) {
                GroupElement value = message[0].getStructure().getNeutralElement();
                for (int i = 1; i <= message.length; ++i) {
                    Zp.ZpElement exponentK = K[getMatrixIndex(message.length, columns, i, c)];
                    value = value.op(message[i - 1].pow(exponentK));
                }
                value.compute();
                multiplied[c - 1] = value;
            }
            return multiplied;
        }
    }
}

