package top.doudou.mybatis.plus.encrypt.interceptor;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.binding.MapperProxyFactory;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.stereotype.Component;
import top.doudou.core.exception.CustomException;
import top.doudou.mybatis.plus.encrypt.KeyCenterUtil;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.function.Function;

/**
 * @Description mybatis属性加解密拦截器
 * @version: 1.0
 * @Created 傻男人 <244191347@qq.com>
 * @Date 2022-01-13 10:22
 */
@Slf4j
@Component
@ConditionalOnClass(MapperProxyFactory.class)
@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})
})
public class MybatisEncryptAndDecryptInterceptor  implements Interceptor {

    /**
     * 插入  更新  删除都是更新的操作
     */
    public static final String UPDATE = "update";

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        boolean next = true;
        // 获取该sql语句的类型，例如update，insert
        String methodName = invocation.getMethod().getName();
        // 获取该sql语句放入的参数
        Object parameter = invocation.getArgs()[1];
        if(UPDATE.equalsIgnoreCase(methodName)){
            next = false;
            encryptField(parameter);
        }
        // 继续执行sql语句,调用当前拦截的执行方法
        Object returnValue = invocation.proceed();
        if(next){
            try {
                // 当返回值类型为数组集合时，就判断是否需要进行数据解密
                if (returnValue instanceof ArrayList<?>) {
                    List<?> list = (List<?>) returnValue;
                    // 判断结果集的数据是否为空
                    if (CollectionUtils.isEmpty(list)) {
                        return returnValue;
                    }
                    list.forEach(item->{
                        decryptField(item);
                        sensitiveHandler(item);
                    });
                }
            } catch (Exception e) {
                log.info("解密出错  error message:{}",e.getMessage());
                return returnValue;
            }
        }
        return returnValue;
    }

    /**
     * 脱敏处理
     * @param item
     */
    private void sensitiveHandler(Object item) {
    }

    @Override
    public Object plugin(Object target) {
        return Interceptor.super.plugin(target);
    }

    @Override
    public void setProperties(Properties properties) {
        Interceptor.super.setProperties(properties);
    }

    /**
     * 字段解密
     * 扫描带有解密注解的字段进行解密
     *
     * @param <T>
     */
    public <T> void decryptField(T t) {
        // 获取对象的域
        Field[] declaredFields = t.getClass().getDeclaredFields();
        Arrays.stream(declaredFields).forEach(field->{
            if ((field.isAnnotationPresent(DecryptField.class) || field.isAnnotationPresent(EncryptAndDecryptField.class))
                    && field.getType().equals(String.class)) {
                String value = null;
                KeyCenterUtil keyCenterUtil = null;
                if(field.isAnnotationPresent(DecryptField.class)){
                    DecryptField annotation = field.getAnnotation(DecryptField.class);
                    value = annotation.value();
                    keyCenterUtil = newInstance(annotation.keyCenterUtil());
                }else {
                    EncryptAndDecryptField annotation = field.getAnnotation(EncryptAndDecryptField.class);
                    value = annotation.decryptKey();
                    keyCenterUtil = newInstance(annotation.keyCenterUtil());
                }
                field.setAccessible(true);
                try {
                    String fieldValue = (String) field.get(t);
                    if (StringUtils.isNotEmpty(fieldValue)) {
                        field.set(t, keyCenterUtil.decrypt(fieldValue,value));
                    }
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                }
            }
        });
    }

    /**
     * 字段加密
     * 扫描带有解密注解的字段进行加密
     * @param <T>
     */
    public <T> void encryptField(T t) {
        Field[] declaredFields = t.getClass().getDeclaredFields();
        Arrays.stream(declaredFields).forEach(field->{
            if ((field.isAnnotationPresent(EncryptField.class) || field.isAnnotationPresent(EncryptAndDecryptField.class))
                    && field.getType().equals(String.class)) {
                String value = null;
                KeyCenterUtil keyCenterUtil = null;
                if(field.isAnnotationPresent(EncryptField.class)){
                    EncryptField annotation = field.getAnnotation(EncryptField.class);
                    value = annotation.value();
                    keyCenterUtil = newInstance(annotation.keyCenterUtil());
                }else {
                    EncryptAndDecryptField annotation = field.getAnnotation(EncryptAndDecryptField.class);
                    value = annotation.encryptKey();
                    keyCenterUtil = newInstance(annotation.keyCenterUtil());
                }
                field.setAccessible(true);
                try {
                    // 获取这个值
                    String fieldValue = (String) field.get(t);
                    // 判断这个值是否为空
                    if (StringUtils.isNotEmpty(fieldValue)) {
                        field.set(t, keyCenterUtil.encrypt(fieldValue,value));
                    }
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                }

            }
        });
    }

    private <T>T newInstance(Class<T> target){
        try {
            return target.newInstance();
        }catch (Exception e){
            log.error("target：{} 反射出现错误,错误的原因：{}",target.getClass().getName(),e.getMessage());
            throw new CustomException(e);
        }
    }
}
