/*
 * Copyright (C) 2020-2024, Xie YuBin
 * The GNU Free Documentation License covers this file. The original version
 * of this license can be found at http://www.gnu.org/licenses/gfdl.html.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Free Documentation License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Free Documentation License for more details.
 *
 * You should have received a copy of the GNU Free Documentation License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

package cn.sinozg.applet.common.utils;

import cn.sinozg.applet.common.core.model.StandardSm2Engine;
import org.apache.commons.codec.DecoderException;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.codec.binary.Hex;
import org.bouncycastle.asn1.gm.GMObjectIdentifiers;
import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.InvalidCipherTextException;
import org.bouncycastle.crypto.digests.SM3Digest;
import org.bouncycastle.crypto.engines.SM2Engine;
import org.bouncycastle.crypto.macs.HMac;
import org.bouncycastle.crypto.params.KeyParameter;
import org.bouncycastle.crypto.params.ParametersWithRandom;
import org.bouncycastle.jcajce.provider.asymmetric.util.ECUtil;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.Security;
import java.security.Signature;
import java.security.SignatureException;
import java.security.spec.ECGenParameterSpec;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.Arrays;

/**
 * Utility methods for the Chinese national SM2 (signature / public-key encryption), SM3 (hash / HMAC)
 * and SM4 (block cipher) algorithms, backed by the Bouncy Castle provider.
 *
 * <p>Encodings used throughout: text is converted to bytes as UTF-8; SM2 ciphertexts, SM2 signatures
 * and SM4 ciphertexts are exchanged as Base64; SM3 digests/HMACs and generated SM4 keys are hex.
 *
 * @author xieyubin
 * @Description SM2 / SM3 / SM4 helpers
 * @Copyright Copyright (c) 2024
 * @since 2024-12-03 10:35
 */
public class SmUtil {

    /** JCA algorithm name for elliptic-curve keys. */
    private static final String EC = "EC";

    /** Named curve used when generating SM2 key pairs. */
    private static final String STD_NAME = "sm2p256v1";

    public static final String ALGORITHM_SM4 = "SM4";
    /** SM4 only supports 128-bit keys (16 bytes). */
    private static final int DEFAULT_KEY_SIZE = 128;
    /**
     * Cipher transformation "algorithm/mode/padding": SM4 in ECB mode with PKCS#5 padding.
     * Not used inside this class; kept as a public constant for callers.
     */
    public static final String ALGORITHM_NAME_ECB_PADDING5 = "SM4/ECB/PKCS5Padding";
    /** Cipher transformation used by {@link #encryptSm4} and {@link #decryptSm4}. */
    public static final String ALGORITHM_NAME_ECB_PADDING7 = "SM4/ECB/PKCS7Padding";

    /** Shared Bouncy Castle provider; the same instance is also registered globally below. */
    private static final BouncyCastleProvider BOUNCY = new BouncyCastleProvider();

    private static final Logger LOG = LoggerFactory.getLogger(SmUtil.class);

    static {
        // Global registration is needed because the SM4 methods look the provider up by name
        // (BouncyCastleProvider.PROVIDER_NAME). addProvider returns -1 and does nothing when a
        // provider with the same name is already installed.
        Security.addProvider(BOUNCY);
    }

    /**
     * Intentionally empty. Calling it forces this class to be initialized, which runs the static
     * block that registers the Bouncy Castle provider.
     */
    public static void initProvider() {
        // class initialization does the actual work
    }

    /**
     * Signs text with SM2-with-SM3.
     *
     * @param plainText text to sign; its UTF-8 bytes are signed
     * @param privateKey Base64-encoded PKCS#8 SM2 private key
     * @return Base64-encoded signature, verifiable with {@link #verifySm2}
     * @throws RuntimeException wrapping any key-parsing or signing failure
     */
    public static String signSm2(String plainText, String privateKey) {
        try {
            Signature signer = Signature.getInstance(GMObjectIdentifiers.sm2sign_with_sm3.toString(), BOUNCY);
            PrivateKey key = sm2Key(privateKey, false);
            signer.initSign(key);
            // explicit UTF-8 (previously the platform default) so signatures are portable
            signer.update(plainText.getBytes(StandardCharsets.UTF_8));
            return Base64.encodeBase64String(signer.sign());
        } catch (NoSuchAlgorithmException | InvalidKeySpecException | InvalidKeyException | SignatureException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Verifies an SM2-with-SM3 signature produced by {@link #signSm2}.
     *
     * @param plainText the signed text; its UTF-8 bytes are verified
     * @param signature Base64-encoded signature
     * @param publicKey Base64-encoded X.509 SM2 public key
     * @return {@code true} if the signature matches; {@code false} on mismatch or on an
     *         {@link IllegalArgumentException} (which is logged)
     * @throws RuntimeException wrapping key-parsing or signature-engine failures
     */
    public static boolean verifySm2(String plainText, String signature, String publicKey) {
        try {
            Signature verifier = Signature.getInstance(GMObjectIdentifiers.sm2sign_with_sm3.toString(), BOUNCY);
            PublicKey key = sm2Key(publicKey, true);
            verifier.initVerify(key);
            // Mirror signSm2 exactly (UTF-8 text, Base64 signature). Hex-decoding both arguments,
            // as before, made it impossible to verify anything signSm2 produced.
            verifier.update(plainText.getBytes(StandardCharsets.UTF_8));
            return verifier.verify(Base64.decodeBase64(signature));
        } catch (NoSuchAlgorithmException | InvalidKeySpecException | InvalidKeyException | SignatureException e) {
            throw new RuntimeException(e);
        } catch (IllegalArgumentException e) {
            LOG.error("验签失败", e);
            return false;
        }
    }

    /**
     * Encrypts text with SM2 (engine configured with {@code SM2Engine.Mode.C1C3C2} and SM3).
     *
     * @param plainText text to encrypt; its UTF-8 bytes are encrypted unchanged
     * @param publicKey Base64-encoded X.509 SM2 public key
     * @return Base64-encoded ciphertext, decryptable with {@link #decryptSm2}
     */
    public static String encryptSm2(String plainText, String publicKey) {
        byte[] cipherData = sm2(plainText, publicKey, true);
        return Base64.encodeBase64String(cipherData);
    }

    /**
     * Decrypts an SM2 ciphertext produced by {@link #encryptSm2}.
     *
     * @param input Base64-encoded ciphertext
     * @param privateKey Base64-encoded PKCS#8 SM2 private key
     * @return decrypted bytes decoded as UTF-8 (matching the encoding used by encryption)
     */
    public static String decryptSm2(String input, String privateKey) {
        byte[] plainData = sm2(input, privateKey, false);
        return new String(plainData, StandardCharsets.UTF_8);
    }

    /**
     * Generates an SM2 key pair on the {@code sm2p256v1} curve.
     *
     * @return the key pair, or {@code null} if generation fails (the failure is logged)
     */
    public static KeyPair keyPairSm2() {
        try {
            KeyPairGenerator generator = KeyPairGenerator.getInstance(EC, BOUNCY);
            generator.initialize(new ECGenParameterSpec(STD_NAME), new SecureRandom());
            return generator.generateKeyPair();
        } catch (Exception e) {
            LOG.error("generate sm2 key pair failed:{}", e.getMessage(), e);
            return null;
        }
    }

    /**
     * Computes HMAC-SM3 of a string with a secret key.
     *
     * @param params data to authenticate (UTF-8)
     * @param key    secret HMAC key (UTF-8)
     * @return hex string of the 32-byte MAC (64 characters)
     */
    public static String encryptPlusSm3(String params, String key) {
        byte[] data = params.getBytes(StandardCharsets.UTF_8);
        // hmac(key, input): the secret is the HMAC key and the data is the message.
        // These arguments were previously swapped.
        byte[] mac = hmac(key.getBytes(StandardCharsets.UTF_8), data);
        return Hex.encodeHexString(mac);
    }

    /**
     * Computes the SM3 digest of a string.
     *
     * @param params text to hash (UTF-8)
     * @return hex string of the 32-byte digest (64 characters)
     */
    public static String encryptSm3(String params) {
        byte[] digest = sm3Hash(params.getBytes(StandardCharsets.UTF_8));
        return Hex.encodeHexString(digest);
    }

    /**
     * Checks whether a string hashes (SM3) to the given hex digest.
     *
     * @param input         original text (UTF-8)
     * @param encodedParams expected digest as a hex string, e.g. from {@link #encryptSm3}
     * @return {@code true} if the digests match; {@code false} otherwise or on any error (logged)
     */
    public static boolean matchesSm3(String input, String encodedParams) {
        try {
            byte[] expected = Hex.decodeHex(encodedParams);
            byte[] actual = sm3Hash(input.getBytes(StandardCharsets.UTF_8));
            // constant-time comparison; also returns false for different lengths
            return MessageDigest.isEqual(actual, expected);
        } catch (Exception e) {
            LOG.error("sm3对比失败 ", e);
        }
        return false;
    }

    /**
     * Encrypts text with SM4 ({@value #ALGORITHM_NAME_ECB_PADDING7}).
     *
     * <p>The key string is not used directly: it is hashed with SHA-256 and the first 16 bytes
     * become the SM4 key. NOTE(review): ECB mode reveals repeated plaintext blocks; it is kept
     * here for compatibility with existing ciphertexts.
     *
     * @param plainText text to encrypt (UTF-8)
     * @param key       key material / passphrase
     * @return Base64-encoded ciphertext
     * @throws RuntimeException wrapping any cipher failure (also logged)
     */
    public static String encryptSm4(String plainText, String key) {
        try {
            SecretKeySpec secretKeySpec = new SecretKeySpec(sm4Key(key), ALGORITHM_SM4);
            Cipher cipher = Cipher.getInstance(ALGORITHM_NAME_ECB_PADDING7, BouncyCastleProvider.PROVIDER_NAME);
            cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec);
            byte[] cipherBytes = cipher.doFinal(plainText.getBytes(StandardCharsets.UTF_8));
            return Base64.encodeBase64String(cipherBytes);
        } catch (Exception e) {
            LOG.error("SM4 加密错误", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Decrypts an SM4 ciphertext produced by {@link #encryptSm4} with the same key string.
     *
     * @param cipherString Base64-encoded ciphertext
     * @param key          the key material / passphrase used for encryption
     * @return the decrypted bytes
     * @throws RuntimeException wrapping any cipher failure (also logged)
     */
    public static byte[] decryptSm4(String cipherString, String key) {
        try {
            SecretKeySpec secretKeySpec = new SecretKeySpec(sm4Key(key), ALGORITHM_SM4);
            Cipher cipher = Cipher.getInstance(ALGORITHM_NAME_ECB_PADDING7, BouncyCastleProvider.PROVIDER_NAME);
            cipher.init(Cipher.DECRYPT_MODE, secretKeySpec);
            return cipher.doFinal(Base64.decodeBase64(cipherString));
        } catch (Exception e) {
            LOG.error("SM4 解密错误", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Generates a random 128-bit SM4 key.
     *
     * <p>Note that {@link #encryptSm4}/{@link #decryptSm4} treat the returned string as key
     * material (it is SHA-256-derived), not as hex to be decoded.
     *
     * @return the 16-byte key as a 32-character hex string
     * @throws Exception if the SM4 key generator is unavailable
     */
    public static String generateKeySm4() throws Exception {
        KeyGenerator kg = KeyGenerator.getInstance(ALGORITHM_SM4, BouncyCastleProvider.PROVIDER_NAME);
        kg.init(DEFAULT_KEY_SIZE, new SecureRandom());
        return Hex.encodeHexString(kg.generateKey().getEncoded());
    }

    /**
     * Derives a 16-byte SM4 key from an arbitrary string.
     *
     * @param key key material (UTF-8)
     * @return the first 16 bytes of SHA-256(key)
     * @throws NoSuchAlgorithmException if SHA-256 is unavailable
     */
    private static byte[] sm4Key(String key) throws NoSuchAlgorithmException {
        byte[] hash = MessageDigest.getInstance("SHA-256").digest(key.getBytes(StandardCharsets.UTF_8));
        return Arrays.copyOf(hash, 16);
    }

    /**
     * Runs SM2 encryption or decryption.
     *
     * @param in      a {@code byte[]} used as-is, or any other object whose {@code toString()} is
     *                taken as UTF-8 plaintext (encrypt) or Base64 ciphertext (decrypt)
     * @param key     Base64-encoded key: X.509 public key when encrypting, PKCS#8 private key when decrypting
     * @param encrypt {@code true} to encrypt, {@code false} to decrypt
     * @return ciphertext or plaintext bytes
     * @throws RuntimeException wrapping key-parsing or cipher failures (also logged)
     */
    private static byte[] sm2(Object in, String key, boolean encrypt) {
        try {
            byte[] bs;
            if (in instanceof byte[] ins) {
                bs = ins;
            } else {
                String text = in.toString();
                // plaintext is encrypted byte-for-byte (no case folding, which corrupted data)
                bs = encrypt ? text.getBytes(StandardCharsets.UTF_8) : Base64.decodeBase64(text);
            }
            CipherParameters cipherParameters;
            if (encrypt) {
                cipherParameters = new ParametersWithRandom(ECUtil.generatePublicKeyParameter(sm2Key(key, true)));
            } else {
                cipherParameters = ECUtil.generatePrivateKeyParameter(sm2Key(key, false));
            }
            StandardSm2Engine engine = new StandardSm2Engine(new SM3Digest(), SM2Engine.Mode.C1C3C2);
            // forEncryption must follow the requested direction; it used to be hard-coded to true,
            // so decryption ran the engine in encryption mode with a private key
            engine.init(encrypt, cipherParameters);
            return engine.processBlock(bs, 0, bs.length);
        } catch (NoSuchAlgorithmException | InvalidKeySpecException
                 | InvalidKeyException | InvalidCipherTextException e) {
            // do not log the input: it is plaintext (or ciphertext) and may be sensitive
            LOG.error("sm2 加解密错误, encrypt={}", encrypt, e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Parses a Base64-encoded SM2 key with the Bouncy Castle EC key factory.
     *
     * <p>Bouncy Castle is used for every caller so signing and encryption parse keys the same way
     * instead of depending on the JVM's default provider order.
     *
     * @param key Base64 text of an X.509 (public) or PKCS#8 (private) encoded key
     * @param pub {@code true} for a public key, {@code false} for a private key
     * @param <T> expected key type ({@link PublicKey} or {@link PrivateKey})
     * @return the parsed key
     * @throws NoSuchAlgorithmException if the EC key factory is unavailable
     * @throws InvalidKeySpecException  if the bytes are not a valid key
     */
    private static <T extends Key> T sm2Key(String key, boolean pub)
            throws NoSuchAlgorithmException, InvalidKeySpecException {
        byte[] encoded = Base64.decodeBase64(key);
        KeyFactory factory = KeyFactory.getInstance(EC, BOUNCY);
        Key parsed = pub
                ? factory.generatePublic(new X509EncodedKeySpec(encoded))
                : factory.generatePrivate(new PKCS8EncodedKeySpec(encoded));
        return PojoUtil.cast(parsed);
    }

    /**
     * Computes the raw SM3 digest.
     *
     * @param params data to hash
     * @return the 32-byte digest
     */
    private static byte[] sm3Hash(byte[] params) {
        SM3Digest digest = new SM3Digest();
        digest.update(params, 0, params.length);
        byte[] hash = new byte[digest.getDigestSize()];
        digest.doFinal(hash, 0);
        return hash;
    }

    /**
     * Computes HMAC-SM3.
     *
     * @param key   HMAC secret key
     * @param input message to authenticate
     * @return the MAC bytes (SM3 digest size)
     */
    private static byte[] hmac(byte[] key, byte[] input) {
        HMac mac = new HMac(new SM3Digest());
        mac.init(new KeyParameter(key));
        mac.update(input, 0, input.length);
        byte[] result = new byte[mac.getMacSize()];
        mac.doFinal(result, 0);
        return result;
    }
}
