/*
 * Copyright 2023-2025 Licensed under the AGPL License
 */
package plus.hiver.common.aop;

import lombok.RequiredArgsConstructor;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;
import plus.hiver.common.annotation.Decrypt;
import plus.hiver.common.config.sm2.Sm2Service;
import plus.hiver.common.exception.HiverException;

import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.nio.charset.StandardCharsets;
import java.util.Base64;

/**
 * AOP aspect that SM2-decrypts controller method arguments annotated with {@link Decrypt}.
 *
 * <p>
 * Respect intellectual property: if you copy this code, keep the copyright notice. Produced by
 * Hiver Technology, https://hiver.cc. Unlawful use is prohibited and is at the user's own risk.
 * </p>
 *
 * <p>
 * The advice runs around methods of classes annotated with {@code @Controller} or
 * {@code @RestController} that declare at least one {@code @Decrypt} parameter.
 * </p>
 *
 * <ul>
 *   <li>If {@code Sm2Service.isEnabled()} returns {@code false}, the call proceeds with its
 *       original arguments.</li>
 *   <li>Otherwise each {@code @Decrypt} argument that is a {@code String} is Base64-decoded
 *       (bytes read as UTF-8), then passed through {@code Sm2Service.decodeFromUrl}, then through
 *       {@code Sm2Service.decrypt}. The result replaces the argument.</li>
 *   <li>Annotated arguments that are not {@code String}s (including {@code null}) are left
 *       unchanged.</li>
 * </ul>
 *
 * @author Yazhi Li
 */
@Aspect
@Component
@RequiredArgsConstructor
public class DecryptAspect {
    /** Controller methods that have at least one parameter annotated with {@code @Decrypt}. */
    private static final String DECRYPT_POINTCUT =
            "execution(* *(.., @plus.hiver.common.annotation.Decrypt (*), ..)) && "
                    + "(@within(org.springframework.stereotype.Controller) || "
                    + "@within(org.springframework.web.bind.annotation.RestController))";

    private final Sm2Service sm2Service;

    /**
     * Replaces every {@code @Decrypt} argument with its decrypted value, then invokes the target.
     *
     * @param joinPoint the intercepted controller method invocation
     * @return the value returned by the target method
     * @throws Throwable anything thrown by the target method, or a {@link HiverException} when an
     *                   annotated argument cannot be decrypted
     */
    @Around(DECRYPT_POINTCUT)
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        if (sm2Service.isEnabled()) {
            Method target = ((MethodSignature) joinPoint.getSignature()).getMethod();
            Object[] arguments = joinPoint.getArgs();
            // Parameters and arguments are positionally aligned; walk them in lock-step.
            int position = 0;
            for (Parameter param : target.getParameters()) {
                if (param.isAnnotationPresent(Decrypt.class)) {
                    arguments[position] = decryptArgument(arguments[position], param);
                }
                position++;
            }
            return joinPoint.proceed(arguments);
        }
        // Decryption is switched off: let the call through with its untouched arguments.
        return joinPoint.proceed();
    }

    /**
     * Decrypts a single annotated argument.
     *
     * @param value the raw argument value; only {@code String} values are decrypted
     * @param param the reflective parameter, used to name the argument in error messages
     * @return the decrypted value, or {@code value} itself when it is not a {@code String}
     * @throws HiverException wrapping any failure raised while decoding or decrypting
     */
    private Object decryptArgument(Object value, Parameter param) {
        if (value instanceof String) {
            String cipherText = (String) value;
            try {
                // Strip the Base64 wrapper first, then let Sm2Service undo the URL encoding,
                // and only then run the SM2 decryption itself.
                String urlDecoded = sm2Service.decodeFromUrl(fromBase64(cipherText));
                return sm2Service.decrypt(urlDecoded);
            } catch (Exception e) {
                throw new HiverException("参数解密失败: " + param.getName(), e);
            }
        }
        return value;
    }

    /**
     * Decodes a standard (non-URL-safe) Base64 string and reads the bytes as UTF-8 text.
     *
     * @param encoded the Base64 text
     * @return the decoded UTF-8 string
     * @throws IllegalArgumentException if {@code encoded} is not valid Base64
     */
    private static String fromBase64(String encoded) {
        byte[] bytes = Base64.getDecoder().decode(encoded);
        return new String(bytes, StandardCharsets.UTF_8);
    }
}
