package cool.doudou.doudada.cipher.algorithm.sm;

import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.crypto.digests.SM3Digest;
import org.bouncycastle.crypto.params.ECPrivateKeyParameters;
import org.bouncycastle.crypto.params.ECPublicKeyParameters;
import org.bouncycastle.math.ec.ECPoint;

import java.math.BigInteger;

/**
 * Streaming state for SM2 public-key encryption and decryption.
 *
 * <p>Once {@link #initEnc} or {@link #initDec} has established the shared point (x2, y2):
 * <ul>
 *   <li>the C2 keystream is produced in 32-byte chunks as {@code SM3(x2 || y2 || ct)}, where
 *       {@code ct} is a 32-bit big-endian counter starting at 1;</li>
 *   <li>the C3 check value is {@code SM3(x2 || M || y2)}, where {@code M} is the plaintext fed
 *       through {@link #encrypt} or recovered by {@link #decrypt}.</li>
 * </ul>
 *
 * <p>Instances are mutable and this class performs no synchronization.
 *
 * <p>NOTE(review): the SM2 specification requires rejecting a derived key that is entirely zero
 * bytes. This streaming implementation does not perform that check.
 *
 * @author jiangcs
 * @since 2022/08/31
 */
public class Cipher {
    /** Length in bytes of an SM3 digest and of each fixed-width encoded coordinate. */
    private static final int BLOCK_LEN = 32;

    /** Big-endian counter mixed into each keystream chunk. */
    private int ct;
    /** 32-byte big-endian x coordinate of the shared point; null until initialized. */
    private byte[] x2;
    /** 32-byte big-endian y coordinate of the shared point; null until initialized. */
    private byte[] y2;
    /** SM3 state pre-loaded with x2 || y2; cloned for every keystream chunk. */
    private SM3Digest sm3DigestP2;
    /** Running SM3 over x2 || M (y2 is appended in doFinal). */
    private SM3Digest sm3DigestC3;
    /** Current keystream chunk. */
    private final byte[] key;
    /** Index of the next unused byte in {@link #key}. */
    private int keyOff;

    public Cipher() {
        this.ct = 1;
        this.key = new byte[BLOCK_LEN];
        this.keyOff = 0;
    }

    /**
     * Records the shared point's coordinates and starts a fresh keystream/C3 computation.
     *
     * @param point the shared point (x2, y2)
     * @throws IllegalArgumentException if {@code point} is the point at infinity
     */
    private void setSharedPoint(ECPoint point) {
        if (point.isInfinity()) {
            throw new IllegalArgumentException("SM2 shared point is the point at infinity");
        }
        // Encode once here; the coordinates are reused on every reset() and doFinal().
        this.x2 = byteConvert32Bytes(point.getX().toBigInteger());
        this.y2 = byteConvert32Bytes(point.getY().toBigInteger());
        reset();
    }

    /** Re-seeds both digests from (x2, y2) and restarts the keystream at ct = 1. */
    private void reset() {
        this.sm3DigestP2 = new SM3Digest();
        this.sm3DigestP2.update(x2, 0, x2.length);
        this.sm3DigestP2.update(y2, 0, y2.length);

        this.sm3DigestC3 = new SM3Digest();
        this.sm3DigestC3.update(x2, 0, x2.length);

        this.ct = 1;
        nextKey();
    }

    /** Fills {@link #key} with SM3(x2 || y2 || ct) and advances the counter. */
    private void nextKey() {
        // Copy the pre-loaded state so the x2 || y2 prefix is hashed only once.
        SM3Digest sm3Digest = new SM3Digest(this.sm3DigestP2);
        sm3Digest.update((byte) (ct >> 24 & 0xff));
        sm3Digest.update((byte) (ct >> 16 & 0xff));
        sm3Digest.update((byte) (ct >> 8 & 0xff));
        sm3Digest.update((byte) (ct & 0xff));
        sm3Digest.doFinal(key, 0);
        this.keyOff = 0;
        this.ct++;
    }

    /** XORs {@code data} in place with the next {@code data.length} keystream bytes. */
    private void xorKeyStream(byte[] data) {
        for (int i = 0; i < data.length; i++) {
            if (keyOff == key.length) {
                nextKey();
            }
            data[i] ^= key[keyOff++];
        }
    }

    /** Throws unless initEnc or initDec has been called. */
    private void ensureInitialized() {
        if (sm3DigestC3 == null) {
            throw new IllegalStateException("Cipher not initialized: call initEnc or initDec first");
        }
    }

    /**
     * Prepares this instance for encryption to {@code userKey}.
     *
     * <p>Generates an ephemeral key pair (k, kG) with {@code sm2}'s generator. The shared point
     * is {@code userKey * k}.
     *
     * @param sm2     supplies the ephemeral key-pair generator
     * @param userKey the recipient's public key point
     * @return the ephemeral public point (C1), to be sent alongside the ciphertext
     * @throws IllegalArgumentException if the shared point is the point at infinity
     */
    public ECPoint initEnc(Sm2 sm2, ECPoint userKey) {
        AsymmetricCipherKeyPair asymmetricCipherKeyPair = sm2.eccKeyPairGenerator.generateKeyPair();
        ECPrivateKeyParameters ecPrivateKeyParameters = (ECPrivateKeyParameters) asymmetricCipherKeyPair.getPrivate();
        ECPublicKeyParameters ecPublicKeyParameters = (ECPublicKeyParameters) asymmetricCipherKeyPair.getPublic();
        BigInteger d = ecPrivateKeyParameters.getD();
        ECPoint q = ecPublicKeyParameters.getQ();
        setSharedPoint(userKey.multiply(d));
        return q;
    }

    /**
     * Encrypts {@code data} in place and feeds the plaintext into the C3 digest.
     *
     * @param data plaintext on entry, ciphertext (C2 bytes) on return
     * @throws IllegalStateException if neither initEnc nor initDec has been called
     */
    public void encrypt(byte[] data) {
        ensureInitialized();
        // C3 covers the plaintext, so hash before XOR-ing.
        this.sm3DigestC3.update(data, 0, data.length);
        xorKeyStream(data);
    }

    /**
     * Prepares this instance for decryption with the recipient's private key.
     *
     * <p>This method does not check that {@code ecPointC1} lies on the curve.
     *
     * @param userD     the recipient's private scalar
     * @param ecPointC1 the ephemeral point C1 taken from the ciphertext
     * @throws IllegalArgumentException if the shared point is the point at infinity
     */
    public void initDec(BigInteger userD, ECPoint ecPointC1) {
        setSharedPoint(ecPointC1.multiply(userD));
    }

    /**
     * Decrypts {@code data} in place and feeds the recovered plaintext into the C3 digest.
     *
     * @param data ciphertext (C2 bytes) on entry, plaintext on return
     * @throws IllegalStateException if neither initEnc nor initDec has been called
     */
    public void decrypt(byte[] data) {
        ensureInitialized();
        xorKeyStream(data);
        // C3 covers the plaintext, so hash after XOR-ing.
        this.sm3DigestC3.update(data, 0, data.length);
    }

    /**
     * Completes C3 = SM3(x2 || M || y2), writes it to {@code outBytes}, and resets the state.
     *
     * <p>After the reset, the keystream restarts at ct = 1 for the same shared point. Encrypting
     * another message on this instance without calling initEnc again would reuse that keystream.
     *
     * @param outBytes receives the 32-byte C3 value starting at index 0
     * @throws IllegalStateException if neither initEnc nor initDec has been called
     */
    public void doFinal(byte[] outBytes) {
        ensureInitialized();
        this.sm3DigestC3.update(y2, 0, y2.length);
        this.sm3DigestC3.doFinal(outBytes, 0);
        reset();
    }

    /**
     * Encodes a non-negative big integer as a fixed-width 32-byte big-endian array.
     *
     * <p>Short values are left-padded with zeros. The extra leading sign byte that
     * {@link BigInteger#toByteArray()} may emit is dropped.
     *
     * @param n the value to encode; must be non-negative and below 2^256
     * @return a new 32-byte array
     * @throws IllegalArgumentException if {@code n} is null, negative, or wider than 256 bits
     */
    private static byte[] byteConvert32Bytes(BigInteger n) {
        // The value is deliberately kept out of the message: callers pass secret coordinates.
        if (n == null || n.signum() < 0 || n.bitLength() > BLOCK_LEN * 8) {
            throw new IllegalArgumentException("value must be a non-negative integer of at most 256 bits");
        }
        byte[] raw = n.toByteArray();
        // At most one leading 0x00 sign byte can exceed BLOCK_LEN here; skip it if present.
        int srcOff = Math.max(0, raw.length - BLOCK_LEN);
        int len = raw.length - srcOff;
        byte[] out = new byte[BLOCK_LEN];
        System.arraycopy(raw, srcOff, out, BLOCK_LEN - len, len);
        return out;
    }
}
