package org.dronda.lib.jvm.encryption

import java.io.InputStream
import java.util.Base64
import javax.crypto.Cipher
import javax.crypto.CipherInputStream
import javax.crypto.SecretKey
import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.SecretKeySpec

/**
 * [Encryptor] backed by AES in Galois/Counter Mode (GCM).
 *
 * Every encryption call obtains a new 12-byte IV from `createInitializationVector` and uses a 128-bit
 * authentication tag. The produced ciphertext is therefore the plaintext length plus 16 tag bytes
 * (see [expectedOutputSize]).
 * NOTE(review): GCM security depends on never reusing an IV with the same key. This assumes
 * `createInitializationVector` returns fresh random bytes on every call; confirm.
 *
 * A new [Cipher] instance is created for every call, so no cipher state is shared between calls.
 *
 * @property secretKey the AES key used for both encryption and decryption.
 */
public class AesGcmEncryptor(
    private val secretKey: SecretKey
) : Encryptor {

    /**
     * Encrypts [payload] in one shot.
     *
     * @return the ciphertext (with the GCM tag included by the cipher) together with the IV that must be
     *   supplied again for decryption.
     */
    override fun encryptPayload(payload: ByteArray): AESEncryptionData {
        val iv = createInitializationVector(GCM_IV_SIZE)
        return AESEncryptionData(
            value = newCipher(Cipher.ENCRYPT_MODE, iv).doFinal(payload),
            iv = iv,
        )
    }

    /**
     * Wraps [payload] in a stream that yields its encryption lazily as it is read.
     *
     * Closing the returned stream also closes [payload].
     *
     * @return the encrypting stream together with the IV that must be supplied again for decryption.
     */
    override fun encryptPayload(payload: InputStream): AESInputStreamData {
        val iv = createInitializationVector(GCM_IV_SIZE)
        return AESInputStreamData(
            value = CipherInputStream(payload, newCipher(Cipher.ENCRYPT_MODE, iv)),
            iv = iv,
        )
    }

    /**
     * Decrypts and authenticates [encryption] in one shot, using the IV stored alongside the ciphertext.
     *
     * @throws javax.crypto.AEADBadTagException if the ciphertext or tag has been tampered with, or the
     *   wrong key/IV is used.
     */
    override fun decryptPayload(encryption: AESEncryptionData): ByteArray =
        newCipher(Cipher.DECRYPT_MODE, encryption.iv).doFinal(encryption.value)

    /**
     * Wraps the ciphertext stream in [encryption] in a stream that yields the decrypted plaintext.
     *
     * This method only initialises the cipher. The GCM tag is checked when the returned stream reaches
     * the end of the ciphertext, so authentication failures are reported by the returned stream, not
     * here. NOTE(review): on current JDKs this surfaces as an IOException wrapping AEADBadTagException;
     * callers must consume the stream fully before trusting the data.
     */
    override fun decryptPayload(encryption: AESInputStreamData): InputStream =
        CipherInputStream(encryption.value, newCipher(Cipher.DECRYPT_MODE, encryption.iv))

    /**
     * Creates a fresh GCM cipher initialised for [mode] (encrypt or decrypt) with [secretKey], the given
     * [iv] and the class-wide tag length. Centralised so all four entry points use identical parameters.
     */
    private fun newCipher(mode: Int, iv: ByteArray): Cipher {
        val cipher = Cipher.getInstance(AlgorithmUtil.AES_GCM_ALGORITHM)
        cipher.init(mode, secretKey, GCMParameterSpec(GCM_TAG_LENGTH_BITS, iv))
        return cipher
    }

    public companion object {
        private const val GCM_TAG_LENGTH_BITS = 128
        private const val GCM_TAG_LENGTH_BYTES = GCM_TAG_LENGTH_BITS / 8
        private const val GCM_IV_SIZE = 12

        /** Raw key sizes in bytes that AES accepts (128-, 192- and 256-bit keys). */
        private val AES_KEY_SIZES_BYTES = setOf(16, 24, 32)

        /**
         * Creates an encryptor from a Base64-encoded (basic alphabet) AES key.
         *
         * @param secretKey the key bytes encoded with [Base64.getDecoder]'s basic alphabet.
         * @throws IllegalArgumentException if [secretKey] is not valid Base64, or does not decode to a
         *   16-, 24- or 32-byte key.
         */
        public fun fromBase64(secretKey: String): AesGcmEncryptor =
            from(Base64.getDecoder().decode(secretKey))

        /**
         * Creates an encryptor from raw AES key bytes.
         *
         * @param secretKey the raw key; must be 16, 24 or 32 bytes long.
         * @throws IllegalArgumentException if [secretKey] has a length AES does not support. This is
         *   checked here so a bad key fails at construction rather than on first use.
         */
        public fun from(secretKey: ByteArray): AesGcmEncryptor {
            require(secretKey.size in AES_KEY_SIZES_BYTES) {
                "AES key must be 16, 24 or 32 bytes long, but was ${secretKey.size} bytes"
            }
            return AesGcmEncryptor(
                SecretKeySpec(secretKey, AlgorithmUtil.AES)
            )
        }

        /**
         * Returns the expected output size based off the parameters used in encryption.
         * @param plaintextLength - the size in bytes of the plaintext (unencrypted bytes); must be >= 0.
         * @return [plaintextLength] plus the GCM authentication tag length in bytes.
         */
        public fun expectedOutputSize(plaintextLength: Int): Long = expectedOutputSize(plaintextLength.toLong())

        /**
         * Returns the expected output size based off the parameters used in encryption.
         * @param plaintextLength - the size in bytes of the plaintext (unencrypted bytes); must be >= 0.
         * @return [plaintextLength] plus the GCM authentication tag length in bytes.
         * @throws IllegalArgumentException if [plaintextLength] is negative.
         */
        public fun expectedOutputSize(plaintextLength: Long): Long {
            require(plaintextLength >= 0) { "plaintextLength must be non-negative, but was $plaintextLength" }
            return plaintextLength + GCM_TAG_LENGTH_BYTES
        }
    }
}