package in.juspay.security;

import org.apache.commons.codec.binary.Base64;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.RSAPrivateCrtKeySpec;
import java.security.spec.RSAPublicKeySpec;
import java.security.spec.X509EncodedKeySpec;

/**
 * Helpers for loading RSA keys from PEM-style text.
 *
 * <p>Public keys may be X.509 {@code SubjectPublicKeyInfo} ({@code BEGIN PUBLIC KEY}) or PKCS#1
 * {@code RSAPublicKey} ({@code BEGIN RSA PUBLIC KEY}). Private keys may be PKCS#8
 * ({@code BEGIN PRIVATE KEY}) or PKCS#1 {@code RSAPrivateKey} ({@code BEGIN RSA PRIVATE KEY}).
 * The PKCS#1 structures are decoded with the small DER reader at the bottom of this class and
 * turned into explicit RSA key specs, since {@link X509EncodedKeySpec} and
 * {@link PKCS8EncodedKeySpec} expect the wrapped formats.
 */
public class Keys {

    /**
     * Parses an RSA public key from PEM text.
     *
     * <p>PEM markers and all whitespace are stripped and the rest is Base64-decoded. The bytes are
     * first decoded as an X.509 {@code SubjectPublicKeyInfo}. If that fails and the text carries
     * the {@code BEGIN RSA PUBLIC KEY} marker, the bytes are decoded as a PKCS#1
     * {@code RSAPublicKey} (a SEQUENCE of modulus and public exponent).
     *
     * @param key PEM-encoded public key; must not be {@code null}
     * @return the decoded RSA public key
     * @throws JuspayCryptoException if the key is {@code null}, malformed, or cannot be decoded
     */
    public static PublicKey readPublicKey(String key) throws JuspayCryptoException {
        final String RSAPUBK_BEGIN = "-----BEGIN RSA PUBLIC KEY-----";
        final String RSAPUBK_END = "-----END RSA PUBLIC KEY-----";
        final String PUBK_BEGIN = "-----BEGIN PUBLIC KEY-----";
        final String PUBK_END = "-----END PUBLIC KEY-----";

        byte[] keyBytes = Base64.decodeBase64(
                stripPem(key, RSAPUBK_BEGIN, RSAPUBK_END, PUBK_BEGIN, PUBK_END));
        KeyFactory factory = rsaKeyFactory();

        try {
            return factory.generatePublic(new X509EncodedKeySpec(keyBytes));
        } catch (InvalidKeySpecException e) {
            if (!key.contains(RSAPUBK_BEGIN)) {
                throw new JuspayCryptoException("please check public key", e);
            }
            // Fall through: "BEGIN RSA PUBLIC KEY" denotes PKCS#1, which X509EncodedKeySpec cannot
            // decode. Trying X.509 first keeps keys that merely carry that label working as before.
        }

        RSAPublicKeySpec keySpec = getRSAPublicKeySpec(keyBytes);
        try {
            return factory.generatePublic(keySpec);
        } catch (InvalidKeySpecException e) {
            throw new JuspayCryptoException("please check pkcs1 public key: " + e.getMessage(), e);
        }
    }

    /**
     * Parses an RSA private key from PEM text in PKCS#1 or PKCS#8 form.
     *
     * <p>The format is chosen by the PEM marker present in {@code key}. Text with neither a PKCS#1
     * nor a PKCS#8 begin marker is rejected.
     *
     * @param key PEM-encoded private key; must not be {@code null}
     * @return the decoded RSA private key
     * @throws JuspayCryptoException if the key is {@code null}, has an unsupported format, or is
     *     malformed
     */
    public static PrivateKey readPrivateKey(String key) throws JuspayCryptoException {
        final String PKCS1_BEGIN = "-----BEGIN RSA PRIVATE KEY-----";
        final String PKCS1_END = "-----END RSA PRIVATE KEY-----";
        final String PKCS8_BEGIN = "-----BEGIN PRIVATE KEY-----";
        final String PKCS8_END = "-----END PRIVATE KEY-----";

        byte[] keyBytes = Base64.decodeBase64(
                stripPem(key, PKCS1_BEGIN, PKCS1_END, PKCS8_BEGIN, PKCS8_END));
        KeyFactory factory = rsaKeyFactory();

        if (key.contains(PKCS1_BEGIN)) {
            RSAPrivateCrtKeySpec keySpec = getRSAPrivateKeySpec(keyBytes);
            try {
                return factory.generatePrivate(keySpec);
            } catch (InvalidKeySpecException e) {
                throw new JuspayCryptoException("please check pkcs1 key: " + e.getMessage(), e);
            }
        } else if (key.contains(PKCS8_BEGIN)) {
            try {
                return factory.generatePrivate(new PKCS8EncodedKeySpec(keyBytes));
            } catch (InvalidKeySpecException e) {
                throw new JuspayCryptoException("please check pkcs8 key:- " + e.getMessage(), e);
            }
        } else {
            throw new JuspayCryptoException("key format not supported");
        }
    }

    /**
     * Decodes a DER PKCS#1 {@code RSAPrivateKey}: a SEQUENCE of nine INTEGERs (version, modulus,
     * public exponent, private exponent, prime1, prime2, exponent1, exponent2, CRT coefficient).
     * Any bytes after the ninth INTEGER are rejected.
     *
     * @param keyBytes DER bytes of the key
     * @return the equivalent CRT key spec (the version field is read but ignored)
     * @throws JuspayCryptoException if the structure is truncated, malformed, or has trailing data
     */
    protected static RSAPrivateCrtKeySpec getRSAPrivateKeySpec(byte[] keyBytes) throws JuspayCryptoException {
        try {
            ByteArrayInputStream in = new ByteArrayInputStream(keyBytes);
            initiateSequence(in);
            BigInteger ignoredVersion = readDERAsInteger(in);
            BigInteger modulus = readDERAsInteger(in);
            BigInteger publicExp = readDERAsInteger(in);
            BigInteger privateExp = readDERAsInteger(in);
            BigInteger prime1 = readDERAsInteger(in);
            BigInteger prime2 = readDERAsInteger(in);
            BigInteger exp1 = readDERAsInteger(in);
            BigInteger exp2 = readDERAsInteger(in);
            BigInteger crtCoefficient = readDERAsInteger(in);

            if (in.read() != -1) {
                throw new JuspayCryptoException("private key too long, if not please check for spaces unwanted chars etc");
            }

            return new RSAPrivateCrtKeySpec(
                    modulus, publicExp, privateExp, prime1, prime2,
                    exp1, exp2, crtCoefficient);
        } catch (IOException e) {
            throw new JuspayCryptoException(e.getMessage(), e);
        }
    }

    /**
     * Decodes a DER PKCS#1 {@code RSAPublicKey}: a SEQUENCE of modulus and public exponent.
     * Any bytes after the second INTEGER are rejected.
     *
     * @param keyBytes DER bytes of the key
     * @return the equivalent public key spec
     * @throws JuspayCryptoException if the structure is truncated, malformed, or has trailing data
     */
    private static RSAPublicKeySpec getRSAPublicKeySpec(byte[] keyBytes) throws JuspayCryptoException {
        try {
            ByteArrayInputStream in = new ByteArrayInputStream(keyBytes);
            initiateSequence(in);
            BigInteger modulus = readDERAsInteger(in);
            BigInteger publicExp = readDERAsInteger(in);

            if (in.read() != -1) {
                throw new JuspayCryptoException("public key too long, if not please check for spaces unwanted chars etc");
            }

            return new RSAPublicKeySpec(modulus, publicExp);
        } catch (IOException e) {
            throw new JuspayCryptoException(e.getMessage(), e);
        }
    }

    /**
     * Removes line breaks, the given PEM markers, and any remaining whitespace from {@code key}.
     * The result is the bare Base64 body.
     *
     * @throws JuspayCryptoException if {@code key} is {@code null}
     */
    private static String stripPem(String key, String... markers) throws JuspayCryptoException {
        if (key == null) {
            throw new JuspayCryptoException("key must not be null");
        }
        // Line breaks go first (both LF and CR), then the markers. The markers contain spaces, so
        // general whitespace can only be stripped after they are gone.
        String body = key.replaceAll("[\\r\\n]", "");
        for (String marker : markers) {
            body = body.replace(marker, "");
        }
        return body.replaceAll("\\s+", "");
    }

    /** Returns an RSA {@link KeyFactory}, wrapping a missing algorithm in a domain exception. */
    private static KeyFactory rsaKeyFactory() throws JuspayCryptoException {
        try {
            return KeyFactory.getInstance("RSA");
        } catch (NoSuchAlgorithmException e) {
            throw new JuspayCryptoException("No such algorithm found:- " + e.getMessage(), e);
        }
    }

    /**
     * Reads one DER TLV from {@code in} and returns its value as a signed two's-complement
     * {@link BigInteger}. The tag's low five bits must be the INTEGER tag number.
     */
    private static BigInteger readDERAsInteger(ByteArrayInputStream in) throws JuspayCryptoException, IOException {
        final byte INTEGER = 0x02;

        int tag = getTag(in);
        int length = getLength(in);
        // Check the declared length against the bytes actually left before allocating. Otherwise
        // a corrupt length field could force a huge allocation.
        if (length > in.available()) {
            throw new JuspayCryptoException("key is too short");
        }

        if (getType(tag) != INTEGER) {
            throw new JuspayCryptoException("parsed value is not integer");
        }

        // A zero-length INTEGER is not valid DER, and BigInteger rejects an empty magnitude.
        if (length == 0) {
            throw new JuspayCryptoException("parsed integer is empty");
        }

        byte[] value = new byte[length];
        int n = in.read(value);
        if (n < length) {
            throw new JuspayCryptoException("key is too short");
        }

        return new BigInteger(value);
    }

    /**
     * Consumes the outer SEQUENCE header. It requires tag number 16 with the constructed bit set,
     * and a declared length that does not exceed the remaining input.
     */
    private static void initiateSequence(ByteArrayInputStream in) throws JuspayCryptoException, IOException {
        final byte SEQUENCE = 0x10;
        final byte CONSTRUCTED = 0x20;

        int tag = getTag(in);
        int length = getLength(in);

        if (in.available() < length) {
            throw new JuspayCryptoException("key is too short");
        }

        if (getType(tag) != SEQUENCE) {
            throw new JuspayCryptoException("cannot initiate key sequence, key is invalid");
        }

        if ((tag & CONSTRUCTED) != CONSTRUCTED) {
            throw new JuspayCryptoException("key is invalid");
        }
    }

    /** Reads a single tag octet, failing if the input is exhausted. */
    private static int getTag(ByteArrayInputStream in) throws JuspayCryptoException {
        int tag = in.read();
        if (tag == -1) {
            throw new JuspayCryptoException("key is too short");
        }
        return tag;
    }

    /** Returns the tag number held in the low five bits of a tag octet. */
    private static int getType(int tag) {
        return tag & 0x1F;
    }

    /**
     * Reads a DER length field: the short form (one octet below 0x80), or the long form with 1-4
     * length octets.
     *
     * @return a non-negative length
     * @throws JuspayCryptoException on truncated input, the indefinite form (0x80), the reserved
     *     0xFF octet, more than four length octets, or a length that does not fit in an int
     */
    private static int getLength(ByteArrayInputStream in) throws JuspayCryptoException, IOException {
        final int LONG_FORM_FLAG = 0x80;

        int first = in.read();
        if (first == -1) {
            throw new JuspayCryptoException("key is too short, cannot get length");
        }

        if ((first & LONG_FORM_FLAG) == 0) {
            return first;
        }

        int size = first & ~LONG_FORM_FLAG;

        // size == 0 is BER's indefinite length, which DER forbids.
        if (size == 0 || first >= 0xFF || size > 4) {
            throw new JuspayCryptoException("der parsing not supported, please check your private key");
        }

        byte[] bytes = new byte[size];
        int n = in.read(bytes);
        if (n < size) {
            throw new JuspayCryptoException("key is too short");
        }

        BigInteger value = new BigInteger(1, bytes);
        // Four octets can encode up to 2^32 - 1, which would wrap negative as an int.
        if (value.bitLength() > 31) {
            throw new JuspayCryptoException("der length too large, please check your key");
        }
        return value.intValue();
    }
}
