package cn.zzq0324.radish.components.wechat.crypto;

import cn.zzq0324.radish.common.util.MessageDigestUtils;
import cn.zzq0324.radish.common.util.XmlUtils;
import cn.zzq0324.radish.components.wechat.officialaccount.dto.calllback.ReplyData;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Random;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import org.springframework.util.Assert;
import org.springframework.util.Base64Utils;

/**
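 * Message encryption/decryption (AES-256-CBC with a key derived from the 43-char encodingAesKey).
 * If {@code java.security.InvalidKeyException: Illegal key size} is thrown, the JDK lacks the
 * unlimited-strength JCE policy: install the JCE policy files (older JDK 8 updates), or use
 * JDK 8u161+ / 9+, where unlimited strength is enabled by default.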
 *
 * @author: zzq0324
 * @since : 1.0.0
 */
public class MessageCrypto {

  /** Length of the random prefix prepended to every plaintext before encryption. */
  private static final int RANDOM_PREFIX_LENGTH = 16;
  /** Offset of the message content: random prefix followed by a 4-byte big-endian length. */
  private static final int CONTENT_OFFSET = RANDOM_PREFIX_LENGTH + 4;
  /** Characters the random prefix is drawn from. */
  private static final String RANDOM_ALPHABET =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
  /** Shared, cryptographically strong random source (SecureRandom is safe for concurrent use). */
  private static final Random RANDOM = new SecureRandom();

  // AES key decoded from encodingAesKey (43 base64 chars + "=" decode to 32 bytes)
  private final byte[] aesKey;
  // Token mixed into every signature
  private final String token;
  // Application ID appended to each plaintext and verified after decryption
  private final String appId;

  /**
   * Creates a crypto helper.
   *
   * @param token          token used when computing signatures
   * @param encodingAesKey 43-character base64 AES key (without the trailing "=")
   * @param appId          application ID embedded in and checked against each message
   * @throws IllegalArgumentException if encodingAesKey is null or not 43 characters long
   */
  public MessageCrypto(String token, String encodingAesKey, String appId) {
    Assert.notNull(encodingAesKey, "encodingAesKey must not be null");
    Assert.isTrue(encodingAesKey.length() == 43, "illegal encodingAesKey, length must be 43!");

    this.token = token;
    this.appId = appId;
    this.aesKey = Base64Utils.decodeFromString(encodingAesKey + "=");
  }

  /**
   * Verifies the message signature and decrypts the payload.
   *
   * @param timestamp    timestamp from the request URL
   * @param nonce        random string from the request URL
   * @param msgSignature signature from the request URL, computed over token, timestamp, nonce and
   *                     the request data
   * @param postData     request data (base64 ciphertext)
   * @return the decrypted message content
   * @throws RuntimeException if the signature or the embedded appId does not match, or
   *                          decryption fails
   */
  public String decrypt(String timestamp, String nonce, String msgSignature, String postData) {
    // The expected signature covers the ciphertext as well as token/timestamp/nonce
    String signature = MessageDigestUtils.getSHA1(getSha1Data(token, timestamp, nonce, postData));

    if (!signatureMatches(signature, msgSignature)) {
      throw new RuntimeException("msgSignature not match!");
    }

    return decrypt(postData);
  }

  /**
   * Encrypts a reply message and packages it for sending back to the platform.
   * <ol>
   * 	<li>AES-CBC encrypt the message</li>
   * 	<li>compute the signature over token, timestamp, nonce and the ciphertext</li>
   * 	<li>serialize ciphertext and signature as XML</li>
   * </ol>
   *
   * @param replyMsg  reply message, an XML string
   * @param timestamp timestamp included in the signature and the reply
   * @param nonce     random string included in the signature and the reply
   * @return XML string containing the ciphertext, msg_signature, timestamp and nonce
   */
  public String encrypt(String replyMsg, String timestamp, String nonce) {
    String encrypt = encrypt(getRandomStr(), replyMsg);

    String signature = MessageDigestUtils.getSHA1(getSha1Data(token, timestamp, nonce, encrypt));

    ReplyData replyData = new ReplyData(encrypt, signature, timestamp, nonce);

    return XmlUtils.toXML(replyData);
  }

  // Encodes an int as 4 bytes in network (big-endian) byte order
  private byte[] getNetworkBytesOrder(int sourceNumber) {
    byte[] orderBytes = new byte[4];
    orderBytes[3] = (byte) (sourceNumber & 0xFF);
    orderBytes[2] = (byte) (sourceNumber >> 8 & 0xFF);
    orderBytes[1] = (byte) (sourceNumber >> 16 & 0xFF);
    orderBytes[0] = (byte) (sourceNumber >> 24 & 0xFF);
    return orderBytes;
  }

  // Decodes the first 4 bytes (network / big-endian byte order) into an int
  private int recoverNetworkBytesOrder(byte[] orderBytes) {
    int sourceNumber = 0;
    for (int i = 0; i < 4; i++) {
      sourceNumber <<= 8;
      sourceNumber |= orderBytes[i] & 0xff;
    }
    return sourceNumber;
  }

  // Generates a 16-character alphanumeric random string using a secure random source
  private String getRandomStr() {
    StringBuilder sb = new StringBuilder(RANDOM_PREFIX_LENGTH);
    for (int i = 0; i < RANDOM_PREFIX_LENGTH; i++) {
      sb.append(RANDOM_ALPHABET.charAt(RANDOM.nextInt(RANDOM_ALPHABET.length())));
    }
    return sb.toString();
  }

  /**
   * Encrypts plaintext.
   *
   * @param randomStr random prefix placed before the message
   * @param text      plaintext to encrypt
   * @return base64-encoded ciphertext
   */
  private String encrypt(String randomStr, String text) {
    ByteGroup byteCollector = new ByteGroup();
    byte[] randomStrBytes = randomStr.getBytes(StandardCharsets.UTF_8);
    byte[] textBytes = text.getBytes(StandardCharsets.UTF_8);
    byte[] networkBytesOrder = getNetworkBytesOrder(textBytes.length);
    byte[] appIdBytes = appId.getBytes(StandardCharsets.UTF_8);

    // Layout: randomStr + networkBytesOrder(text length) + text + appId
    byteCollector.addBytes(randomStrBytes);
    byteCollector.addBytes(networkBytesOrder);
    byteCollector.addBytes(textBytes);
    byteCollector.addBytes(appIdBytes);

    // ... + pad: the cipher uses NoPadding, so padding is applied here by PKCS7Encoder
    byte[] padBytes = PKCS7Encoder.encode(byteCollector.size());
    byteCollector.addBytes(padBytes);

    byte[] unencrypted = byteCollector.toBytes();

    try {
      byte[] encrypted = newCipher(Cipher.ENCRYPT_MODE).doFinal(unencrypted);
      return Base64Utils.encodeToString(encrypted);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Decrypts ciphertext and checks the embedded appId.
   *
   * @param text base64-encoded ciphertext
   * @return the decrypted message content
   * @throws RuntimeException if decryption fails, the decrypted layout is malformed, or the
   *                          embedded appId does not match
   */
  private String decrypt(String text) {
    byte[] original;
    try {
      byte[] encrypted = Base64Utils.decodeFromString(text);
      original = newCipher(Cipher.DECRYPT_MODE).doFinal(encrypted);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }

    // Strip padding added during encryption
    byte[] bytes = PKCS7Encoder.decode(original);
    if (bytes.length < CONTENT_OFFSET) {
      throw new RuntimeException("illegal decrypted message, length: " + bytes.length);
    }

    // Layout: 16-byte random prefix, 4-byte content length, content, appId
    int xmlLength =
        recoverNetworkBytesOrder(Arrays.copyOfRange(bytes, RANDOM_PREFIX_LENGTH, CONTENT_OFFSET));
    // Validate before slicing: an out-of-range length would otherwise make copyOfRange
    // zero-pad (possibly allocating a huge array) instead of failing cleanly
    if (xmlLength < 0 || xmlLength > bytes.length - CONTENT_OFFSET) {
      throw new RuntimeException("illegal message length: " + xmlLength);
    }

    int appIdOffset = CONTENT_OFFSET + xmlLength;
    String messageContent = new String(bytes, CONTENT_OFFSET, xmlLength, StandardCharsets.UTF_8);
    String fromAppId =
        new String(bytes, appIdOffset, bytes.length - appIdOffset, StandardCharsets.UTF_8);

    if (!fromAppId.equals(appId)) {
      throw new RuntimeException("AppId not match");
    }

    return messageContent;
  }

  /**
   * Builds an AES/CBC/NoPadding cipher keyed with aesKey, using the first 16 key bytes as the IV.
   *
   * @param mode {@link Cipher#ENCRYPT_MODE} or {@link Cipher#DECRYPT_MODE}
   */
  private Cipher newCipher(int mode) throws GeneralSecurityException {
    Cipher cipher = Cipher.getInstance("AES/CBC/NoPadding");
    SecretKeySpec keySpec = new SecretKeySpec(aesKey, "AES");
    IvParameterSpec iv = new IvParameterSpec(aesKey, 0, 16);
    cipher.init(mode, keySpec, iv);
    return cipher;
  }

  /**
   * Verifies the signature computed from token, timestamp and nonce.
   *
   * @throws RuntimeException if the signature does not match
   */
  public void validateSign(String timestamp, String nonce, String signature) {
    String calculateSign = MessageDigestUtils.getSHA1(getSha1Data(token, timestamp, nonce));

    if (!signatureMatches(calculateSign, signature)) {
      throw new RuntimeException("signature not match!");
    }
  }

  /**
   * Compares signatures in constant time (for equal lengths) to avoid leaking how many leading
   * characters matched; a null actual signature never matches.
   */
  private static boolean signatureMatches(String expected, String actual) {
    return actual != null
        && MessageDigest.isEqual(
            expected.getBytes(StandardCharsets.UTF_8), actual.getBytes(StandardCharsets.UTF_8));
  }

  // Sorts the inputs lexicographically (on a copy, leaving the caller's array untouched)
  // and concatenates them into the string to be hashed
  private static String getSha1Data(String... data) {
    String[] sorted = data.clone();
    Arrays.sort(sorted);
    return String.join("", sorted);
  }
}
