package org.dronda.lib.jvm.crypto

import org.bouncycastle.crypto.engines.AESEngine
import org.bouncycastle.crypto.io.CipherInputStream
import org.bouncycastle.crypto.modes.GCMBlockCipher
import org.bouncycastle.crypto.params.AEADParameters
import org.bouncycastle.crypto.params.KeyParameter
import java.io.InputStream
import java.util.Base64
import javax.crypto.SecretKey
import javax.crypto.spec.SecretKeySpec

/**
 * [Encryptor] implementation that encrypts and decrypts with AES in GCM mode using
 * BouncyCastle's [GCMBlockCipher].
 *
 * Every encryption call obtains a fresh [GCM_IV_SIZE]-byte initialization vector from
 * `createInitializationVector` and returns it next to the ciphertext; the same IV must be
 * handed back for decryption. The authentication tag is [GCM_TAG_LENGTH_BITS] bits long.
 *
 * @param secretKey the AES key. Its raw bytes are read through [SecretKey.getEncoded] each
 * time a cipher is initialised, so the key must support encoding.
 */
public class AesGcmEncryptor(
    private val secretKey: SecretKey
) : Encryptor {

    /**
     * Encrypts [payload] in memory.
     *
     * @return the cipher output together with the IV generated for this call.
     */
    override fun encryptPayload(payload: ByteArray): AESEncryptionData {
        val iv = createInitializationVector(GCM_IV_SIZE)
        val cipher = newCipher(forEncryption = true, iv = iv)
        // `use` closes the cipher stream once it has been fully drained.
        val encrypted = CipherInputStream(payload.inputStream(), cipher).use { it.readBytes() }

        return AESEncryptionData(
            value = encrypted,
            iv = iv,
        )
    }

    /**
     * Wraps [payload] in a stream that encrypts lazily as it is read.
     *
     * Nothing is consumed from [payload] by this call itself; the caller owns the returned
     * stream and is responsible for reading and closing it.
     *
     * @return the encrypting stream together with the IV generated for this call.
     */
    override fun encryptPayload(payload: InputStream): AESInputStreamData {
        val iv = createInitializationVector(GCM_IV_SIZE)
        val cipher = newCipher(forEncryption = true, iv = iv)

        return AESInputStreamData(
            value = CipherInputStream(payload, cipher),
            iv = iv,
        )
    }

    /**
     * Decrypts an in-memory payload using the IV stored in [encryption].
     *
     * NOTE(review): a payload that fails GCM authentication is expected to surface as an
     * `IOException` raised by BouncyCastle's [CipherInputStream] while reading - confirm
     * against the BouncyCastle version in use.
     *
     * @return the decrypted bytes.
     */
    override fun decryptPayload(encryption: AESEncryptionData): ByteArray {
        val cipher = newCipher(forEncryption = false, iv = encryption.iv)
        // `use` closes the cipher stream once it has been fully drained.
        return CipherInputStream(encryption.value.inputStream(), cipher).use { it.readBytes() }
    }

    /**
     * Wraps the encrypted stream in [encryption] in a stream that decrypts lazily as it is read.
     *
     * The caller owns the returned stream and is responsible for reading and closing it.
     *
     * NOTE(review): with a streaming GCM decrypt the authentication tag can only be checked
     * once the end of the input has been reached, so bytes read before that point should be
     * treated as unauthenticated until the stream has been read to the end without error.
     */
    override fun decryptPayload(encryption: AESInputStreamData): InputStream {
        val cipher = newCipher(forEncryption = false, iv = encryption.iv)
        return CipherInputStream(encryption.value, cipher)
    }

    /**
     * Builds an AES/GCM cipher initialised with this encryptor's key, the fixed tag length
     * and the given [iv].
     *
     * @param forEncryption `true` to initialise for encryption, `false` for decryption.
     * @throws IllegalStateException if the key does not expose its raw bytes.
     */
    private fun newCipher(forEncryption: Boolean, iv: ByteArray): GCMBlockCipher {
        // SecretKey.getEncoded() may return null for keys that cannot be exported; fail with
        // a clear message here instead of a NullPointerException inside BouncyCastle.
        val keyBytes = checkNotNull(secretKey.encoded) {
            "Secret key does not expose its encoded form and cannot be used for AES/GCM"
        }
        return GCMBlockCipher(AESEngine()).apply {
            init(
                forEncryption,
                AEADParameters(
                    /* key = */ KeyParameter(keyBytes),
                    /* macSize = */ GCM_TAG_LENGTH_BITS,
                    /* nonce = */ iv,
                ),
            )
        }
    }

    public companion object {
        /** Length of the GCM authentication tag, in bits. */
        private const val GCM_TAG_LENGTH_BITS = 128

        /** Length of the GCM authentication tag, in bytes. */
        private const val GCM_TAG_LENGTH_BYTES = GCM_TAG_LENGTH_BITS / 8

        /** Size, in bytes, of the initialization vector requested for each encryption. */
        private const val GCM_IV_SIZE = 12

        /**
         * Creates an encryptor from a Base64-encoded key.
         *
         * @param secretKey the raw key bytes encoded with the basic Base64 alphabet.
         * @throws IllegalArgumentException if [secretKey] is not valid Base64.
         */
        public fun fromBase64(secretKey: String): AesGcmEncryptor =
            from(Base64.getDecoder().decode(secretKey))

        /**
         * Creates an encryptor from raw key bytes.
         *
         * @param secretKey the raw key bytes.
         */
        public fun from(secretKey: ByteArray): AesGcmEncryptor =
            AesGcmEncryptor(SecretKeySpec(secretKey, AlgorithmUtil.AES))

        /**
         * Returns the expected output size based off the parameters used in encryption.
         *
         * The result is widened to [Long] before the tag length is added, so it cannot
         * overflow for any [Int] input.
         *
         * @param plaintextLength the size in bytes of the plaintext (unencrypted bytes).
         * @return [plaintextLength] plus the authentication tag length in bytes. The IV is
         * returned separately by the encrypt functions and is not counted.
         */
        public fun expectedOutputSize(plaintextLength: Int): Long =
            expectedOutputSize(plaintextLength.toLong())

        /**
         * Returns the expected output size based off the parameters used in encryption.
         *
         * @param plaintextLength the size in bytes of the plaintext (unencrypted bytes).
         * @return [plaintextLength] plus the authentication tag length in bytes. The IV is
         * returned separately by the encrypt functions and is not counted.
         */
        public fun expectedOutputSize(plaintextLength: Long): Long =
            plaintextLength + GCM_TAG_LENGTH_BYTES
    }
}