package cn.sinozg.applet.common.interceptor;

import cn.sinozg.applet.common.annotation.OptIgnore;
import cn.sinozg.applet.common.constant.BaseConstants;
import cn.sinozg.applet.common.core.base.BaseEntity;
import cn.sinozg.applet.common.utils.PojoUtil;
import cn.sinozg.applet.opt.cache.OptLogCache;
import cn.sinozg.applet.opt.cache.OptLogThreadCache;
import cn.sinozg.applet.opt.config.OptLogProperties;
import cn.sinozg.applet.opt.constant.OptLogConstant;
import cn.sinozg.applet.opt.enums.ModeEnum;
import cn.sinozg.applet.opt.module.OptLogCycleInfo;
import cn.sinozg.applet.opt.module.OptLogTableDetail;
import cn.sinozg.applet.opt.module.OptLogTableInfo;
import cn.sinozg.applet.opt.module.OptMapperColumn;
import cn.sinozg.applet.opt.module.OptMapperTable;
import cn.sinozg.applet.opt.module.OptTieRecord;
import cn.sinozg.applet.opt.module.OptTieRelation;
import cn.sinozg.applet.opt.util.OptUtil;
import com.baomidou.mybatisplus.core.enums.SqlMethod;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import com.baomidou.mybatisplus.core.override.MybatisMapperProxy;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.core.toolkit.MybatisUtils;
import com.baomidou.mybatisplus.core.toolkit.TableNameParser;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import java.lang.reflect.Field;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.ibatis.executor.SimpleExecutor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.SqlSession;
import org.apache.ibatis.session.defaults.DefaultSqlSession;
import org.apache.ibatis.transaction.managed.ManagedTransaction;
import org.apache.ibatis.type.TypeHandlerRegistry;

/**
 * Native MyBatis interceptor that captures data-change (operation log) records for
 * INSERT / UPDATE / DELETE statements. It covers both plain execution
 * ({@code StatementHandler#update}) and JDBC batching ({@code StatementHandler#batch}).
 * <p>
 * The interceptor runs before the intercepted statement is sent to the database:
 * <ul>
 *   <li>For UPDATE/DELETE, it re-reads the affected rows on the same connection to obtain the old values.</li>
 *   <li>It takes the new values from the entity parameter.</li>
 *   <li>It compares the two and adds the differences to {@link OptLogThreadCache}.</li>
 * </ul>
 * The records are saved together once the transaction succeeds.
 * @Description operation-log interceptor for update / batch statements
 * @Copyright Copyright (c) 2024
 * @author xieyubin
 * @since 2024-02-27 18:24:18
 */
@Slf4j
@Intercepts({
        @Signature(type = StatementHandler.class, method = "batch", args = {Statement.class}),
        @Signature(type = StatementHandler.class, method = "update", args = {Statement.class}),
})
public class OptLogMybatisInterceptor implements Interceptor {

    /** Placeholder character MyBatis leaves in the bound SQL, one per {@link ParameterMapping}. */
    private static final char PLACEHOLDER = '?';

    /** The WHERE keyword as a whole word, in any letter case. */
    private static final Pattern WHERE_KEYWORD = Pattern.compile("\\bWHERE\\b", Pattern.CASE_INSENSITIVE);

    /** Whitespace runs (newlines / indentation, e.g. from XML mappers) collapsed into a single space. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    /** Statement keywords; text before the first occurrence of any of them is trimmed from the logged SQL. */
    private static final String[] SQL_START_KEYWORDS = {"SELECT ", "UPDATE ", "INSERT ", "DELETE "};

    private final OptLogProperties properties;

    public OptLogMybatisInterceptor(OptLogProperties properties){
        this.properties = properties;
    }

    /**
     * Wraps only {@link StatementHandler} instances; every other target is returned untouched.
     * @param target object MyBatis offers for plugging
     * @return proxied statement handler, or the original target
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * Records the data change first (failures are logged, never propagated), then lets the statement run.
     * @param invocation intercepted {@code update}/{@code batch} call
     * @return result of the original call
     * @throws Throwable whatever the original call throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        interceptionChain(invocation);
        return invocation.proceed();
    }

    /**
     * Records the data change of the intercepted statement.
     * <p>
     * It skips the statement when any of the following holds:
     * <ul>
     *   <li>the thread has no operation-log context;</li>
     *   <li>the statement is not INSERT/UPDATE/DELETE;</li>
     *   <li>the mapper or mapper method is marked {@link OptIgnore};</li>
     *   <li>the table is not configured for recording.</li>
     * </ul>
     * Any exception is logged and swallowed so that logging can never break the business statement.
     * @param invocation Invocation
     */
    private void interceptionChain(Invocation invocation) {
        // No operation-log context on this thread: nothing to record
        if (OptLogThreadCache.empty()) {
            return;
        }
        try {
            StatementHandler statementHandler = getTarget(invocation.getTarget());
            MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
            MappedStatement mappedStatement = (MappedStatement) metaObject.getValue(OptLogConstant.DELEGATE_MAPPED_STATEMENT);
            SqlCommandType type = mappedStatement.getSqlCommandType();
            // Only data-modifying statements (insert / update / delete) are recorded
            if (SqlCommandType.UPDATE != type && SqlCommandType.DELETE != type && SqlCommandType.INSERT != type) {
                return;
            }
            String methodId = mappedStatement.getId();
            if (isIgnoreMethod(methodId)) {
                return;
            }
            BoundSql boundSql = (BoundSql) metaObject.getValue(OptLogConstant.DELEGATE_BOUND_SQL);
            // Resolve the entity table and its recorded columns from the SQL text
            ImmutablePair<TableInfo, OptMapperTable> tablePair = entityTable(boundSql.getSql());
            if (tablePair == null) {
                return;
            }
            List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
            if (parameterMappings == null) {
                log.warn("未发现更新的字段");
                return;
            }
            // Both intercepted signatures take the JDBC Statement as their only argument;
            // its connection is reused to read the old values inside the same transaction
            Statement statement = (Statement) invocation.getArgs()[0];
            handleDataLog(methodId, boundSql, tablePair.getRight(), type, tablePair.getLeft(),
                    mappedStatement.getConfiguration(), statement);
        } catch (Exception e) {
            log.error("修改记录异常", e);
        }
    }

    /**
     * Builds the log detail for a single SQL execution and queues the resulting diff records.
     * @param methodId mapped statement id
     * @param boundSql boundSql
     * @param mapperTable recorded-column configuration of the table
     * @param type sql command type
     * @param tableInfo MyBatis-Plus table metadata
     * @param configuration configuration
     * @param statement JDBC statement about to be executed (its connection is used for the old-value lookup)
     */
    private void handleDataLog(String methodId, BoundSql boundSql, OptMapperTable mapperTable, SqlCommandType type,
                               TableInfo tableInfo, Configuration configuration, Statement statement) {
        OptLogTableDetail detail = new OptLogTableDetail();
        // Unique id of the current operation
        detail.setTraceId(OptLogThreadCache.getTraceId());
        detail.setMethodId(methodId);
        detail.setTableName(mapperTable.getTableName());
        detail.setTableDesc(mapperTable.getTableDesc());
        detail.setClazz(mapperTable.getType());
        // Resolve primary / foreign key names from the tie relation of the current operation
        OptLogCycleInfo cycleInfo = OptLogThreadCache.get();
        Class<?> tie = cycleInfo.getMain().getTie();
        OptTieRelation<?> relation = OptLogCache.optTie(tie);
        if (relation != null) {
            Map<Class<?>, OptTieRecord> recordMap = relation.getKeyMap();
            OptTieRecord tieRecord = recordMap.get(mapperTable.getType());
            if (tieRecord != null) {
                detail.setIdName(tieRecord.getIdName());
                detail.setForeignIdName(tieRecord.getForeignIdName());
            }
        }
        Map<String, Map<String, Object>> oldValues = null;
        // Update or delete: read the rows as they are before the statement runs
        if (type != SqlCommandType.INSERT) {
            oldValues = oldValuesByDb(statement, configuration, boundSql, mapperTable, tableInfo, detail.getForeignIdName());
        }
        List<Map<String, Object>> newValues = null;
        // Insert or update: new values come from the entity parameter (single entity or list)
        if (type != SqlCommandType.DELETE) {
            Object paramObject = sqlParams(boundSql);
            if (paramObject == null) {
                return;
            }
            newValues = newValueMap(paramObject, mapperTable, boundSql.getParameterMappings());
        }
        ModeEnum modeEnum = modeEnum(type, newValues, oldValues);
        if (modeEnum != null) {
            List<OptLogTableInfo> tbs = OptUtil.compare(newValues, oldValues, mapperTable.getPropertyMap(), detail, modeEnum);
            if (CollectionUtils.isNotEmpty(tbs)) {
                // Queue on the current thread; saved in one go after the transaction succeeds
                OptLogThreadCache.adds(tbs);
            }
        }
    }

    /**
     * Determines the log mode from the command type and which value sets are present.
     * @param type sql command type
     * @param newValues new values
     * @param oldValues old values
     * @return the mode, or {@code null} when the combination is not loggable
     */
    private ModeEnum modeEnum(SqlCommandType type, List<Map<String, Object>> newValues, Map<String, Map<String, Object>> oldValues) {
        ModeEnum mode = null;
        if (SqlCommandType.UPDATE == type && CollectionUtils.isNotEmpty(newValues) && MapUtils.isNotEmpty(oldValues)) {
            mode = ModeEnum.UPDATE;
        } else if (SqlCommandType.INSERT == type && CollectionUtils.isNotEmpty(newValues) && MapUtils.isEmpty(oldValues)) {
            mode = ModeEnum.ADD;
        } else if (SqlCommandType.DELETE == type && CollectionUtils.isEmpty(newValues) && MapUtils.isNotEmpty(oldValues)) {
            mode = ModeEnum.DELETE;
        }
        return mode;
    }

    /**
     * Reads the rows matched by the WHERE clause of the intercepted SQL and converts them to a
     * {@code id -> recorded-field map}.
     * <p>
     * The SQL may contain several {@code ;}-separated statements; each one is queried separately.
     * @param statement JDBC statement being intercepted (its connection runs the lookup)
     * @param configuration configuration
     * @param boundSql boundSql
     * @param mapperTable mapperTable
     * @param tableInfo tableInfo
     * @param foreignIdName foreignIdName
     * @return map keyed by id, or {@code null} when nothing was found or the lookup failed
     */
    private Map<String, Map<String, Object>> oldValuesByDb(Statement statement, Configuration configuration, BoundSql boundSql,
                                                           OptMapperTable mapperTable, TableInfo tableInfo, String foreignIdName) {
        // Normalise whitespace on the template only, so whitespace inside parameter values stays intact
        String template = WHITESPACE.matcher(boundSql.getSql()).replaceAll(StringUtils.SPACE);
        String statementSql = tableInfo.getCurrentNamespace() + BaseConstants.SPOT + SqlMethod.SELECT_LIST.getMethod();
        Map<String, Map<String, Object>> listMap = new HashMap<>(16);
        try {
            Connection connection = statement.getConnection();
            List<String> literals = parameterLiterals(configuration, boundSql, usesBackslashEscapes(connection));
            if (log.isInfoEnabled()) {
                String executedSql = fillPlaceholders(template, literals, 0);
                int index = indexOfSqlStart(executedSql);
                if (index > 0) {
                    executedSql = executedSql.substring(index);
                }
                log.info("执行SQL：{}", executedSql);
            }
            try (SqlSession sqlSession = openSession(configuration, connection)) {
                // Split the template (not the filled SQL) so that a ';' inside a value cannot split a statement
                int offset = 0;
                for (String sql : StringUtils.split(template, ';')) {
                    String condition = whereCondition(sql, literals, offset);
                    offset += StringUtils.countMatches(sql, PLACEHOLDER);
                    if (condition == null) {
                        log.warn("SQL 中未找到 WHERE 条件，跳过原始数据查询：{}", sql);
                        continue;
                    }
                    Map<String, Map<String, Object>> oneMap = oneSql(condition, sqlSession, statementSql, boundSql, mapperTable, foreignIdName);
                    if (oneMap != null) {
                        listMap.putAll(oneMap);
                    }
                }
            }
        } catch (Exception e) {
            log.error("获取sql 语句错误！" , e);
        }
        return listMap.isEmpty() ? null : listMap;
    }

    /**
     * Opens a session on the connection that is executing the intercepted statement.
     * <p>
     * The lookup therefore runs inside the same transaction and sees the rows as they are
     * right before the intercepted statement executes.
     * <ul>
     *   <li>A bare {@link SimpleExecutor} is used deliberately: it never flushes pending JDBC batches,
     *       and it is not wrapped by the second-level cache or by executor plugins.</li>
     *   <li>{@code closeConnection = false} makes closing this session leave the borrowed connection open.</li>
     * </ul>
     * @param configuration configuration
     * @param connection connection of the intercepted statement
     * @return a session that must be closed by the caller
     */
    private SqlSession openSession(Configuration configuration, Connection connection) {
        return new DefaultSqlSession(configuration, new SimpleExecutor(configuration, new ManagedTransaction(connection, false)));
    }

    /**
     * Resolves one SQL literal per parameter mapping.
     * <p>
     * The lookup order matches MyBatis' {@code DefaultParameterHandler}:
     * <ol>
     *   <li>additional (e.g. foreach) parameters;</li>
     *   <li>a {@code null} parameter object;</li>
     *   <li>a scalar parameter object that has a type handler;</li>
     *   <li>otherwise a property of the parameter object.</li>
     * </ol>
     * @param configuration configuration
     * @param boundSql boundSql
     * @param backslashEscapes whether the database treats backslash as an escape character in string literals
     * @return literal for each placeholder, in placeholder order
     */
    private List<String> parameterLiterals(Configuration configuration, BoundSql boundSql, boolean backslashEscapes) {
        List<ParameterMapping> mappings = boundSql.getParameterMappings();
        if (CollectionUtils.isEmpty(mappings)) {
            return new ArrayList<>();
        }
        Object paramObj = boundSql.getParameterObject();
        TypeHandlerRegistry registry = configuration.getTypeHandlerRegistry();
        MetaObject metaObject = null;
        List<String> literals = new ArrayList<>(mappings.size());
        for (ParameterMapping pm : mappings) {
            String proName = pm.getProperty();
            Object value;
            if (boundSql.hasAdditionalParameter(proName)) {
                value = boundSql.getAdditionalParameter(proName);
            } else if (paramObj == null) {
                value = null;
            } else if (registry.hasTypeHandler(paramObj.getClass())) {
                // A single scalar parameter feeds every placeholder
                value = paramObj;
            } else {
                if (metaObject == null) {
                    metaObject = configuration.newMetaObject(paramObj);
                }
                value = metaObject.hasGetter(proName) ? metaObject.getValue(proName) : null;
            }
            literals.add(sqlValue(value, backslashEscapes));
        }
        return literals;
    }

    /**
     * Detects MySQL / MariaDB, whose default SQL mode treats backslash as an escape character inside
     * string literals.
     * <p>
     * NOTE(review): with {@code NO_BACKSLASH_ESCAPES} enabled, escaped backslashes no longer match.
     * That can only cause a missed lookup, not an injection.
     * @param connection connection
     * @return true when backslashes must be escaped
     */
    private boolean usesBackslashEscapes(Connection connection) {
        try {
            String product = connection.getMetaData().getDatabaseProductName();
            return StringUtils.containsIgnoreCase(product, "mysql") || StringUtils.containsIgnoreCase(product, "mariadb");
        } catch (SQLException e) {
            log.debug("获取数据库类型失败", e);
            return false;
        }
    }

    /**
     * Extracts the condition after the last {@code WHERE} of one statement and fills its placeholders.
     * @param statementSql one statement of the template SQL, placeholders not yet filled
     * @param literals literals of the whole bound SQL, in placeholder order
     * @param offset index in {@code literals} of this statement's first placeholder
     * @return the filled condition, or {@code null} when the statement has no WHERE clause
     */
    private String whereCondition(String statementSql, List<String> literals, int offset) {
        Matcher matcher = WHERE_KEYWORD.matcher(statementSql);
        int start = -1;
        int end = -1;
        while (matcher.find()) {
            start = matcher.start();
            end = matcher.end();
        }
        if (start < 0) {
            return null;
        }
        // Placeholders before WHERE (e.g. in the SET clause) consume literals that the condition must skip
        int skipped = StringUtils.countMatches(statementSql.substring(0, start), PLACEHOLDER);
        return fillPlaceholders(statementSql.substring(end), literals, offset + skipped);
    }

    /**
     * Replaces each placeholder of {@code template}, left to right, with the next literal.
     * <p>
     * Scanning the template by position means that {@code ?}, {@code $} or {@code \} inside an
     * inlined value can never be mistaken for another placeholder or a regex group.
     * @param template SQL text with placeholders
     * @param literals literals in placeholder order
     * @param offset index of the literal for the first placeholder of {@code template}
     * @return SQL with placeholders filled; placeholders without a literal are kept as-is
     */
    private String fillPlaceholders(String template, List<String> literals, int offset) {
        StringBuilder sb = new StringBuilder(template.length() + 16);
        int next = offset;
        for (int i = 0; i < template.length(); i++) {
            char c = template.charAt(i);
            if (c == PLACEHOLDER && next >= 0 && next < literals.size()) {
                sb.append(literals.get(next++));
            } else {
                sb.append(c);
            }
        }
        return sb.toString();
    }

    /**
     * Runs the old-value lookup for one statement through the mapper's {@code selectList}.
     * @param condition filled WHERE condition of the statement
     * @param sqlSession Session
     * @param statementSql id of the mapper's selectList statement
     * @param boundSql boundSql
     * @param mapperTable mapping configuration
     * @param foreignIdName foreign key name
     * @return rows keyed by id, or {@code null} when nothing matched
     */
    private Map<String, Map<String, Object>> oneSql(String condition, SqlSession sqlSession, String statementSql, BoundSql boundSql,
                                                    OptMapperTable mapperTable, String foreignIdName) {
        Map<String, Object> map = new HashMap<>(16);
        // eq("1", 1) gives the wrapper an always-true condition, so the appended text can start with AND
        map.put(Constants.WRAPPER, Wrappers.query().eq("1", 1).last("AND " + condition));
        List<?> data = sqlSession.selectList(statementSql, map);
        if (CollectionUtils.isEmpty(data)) {
            return null;
        }
        Map<String, Map<String, Object>> mapList = new HashMap<>(16);
        for (Object d : data) {
            Map<String, Object> vs = valueMap(d, mapperTable, boundSql.getParameterMappings(), foreignIdName);
            mapList.put(MapUtils.getString(vs, mapperTable.getIdPropertyName()), vs);
        }
        return mapList;
    }

    /**
     * Renders a value as an SQL literal.
     * <ul>
     *   <li>{@code null} becomes {@code NULL}.</li>
     *   <li>Numbers and booleans are inlined as-is.</li>
     *   <li>Everything else is single-quoted, with embedded quotes doubled (and backslashes doubled when requested).</li>
     *   <li>A plain {@link java.util.Date} is rendered in JDBC timestamp format instead of its {@code toString()} form.</li>
     * </ul>
     * @param obj value
     * @param backslashEscapes whether to escape backslashes as well
     * @return SQL literal
     */
    private String sqlValue(Object obj, boolean backslashEscapes) {
        if (obj == null) {
            return "NULL";
        }
        if (obj instanceof Number || obj instanceof Boolean) {
            return obj.toString();
        }
        String text = obj.getClass() == java.util.Date.class
                ? new java.sql.Timestamp(((java.util.Date) obj).getTime()).toString()
                : obj.toString();
        if (backslashEscapes) {
            text = text.replace("\\", "\\\\");
        }
        return "'" + text.replace("'", "''") + "'";
    }

    /**
     * Finds where the actual statement starts, i.e. the first SELECT/UPDATE/INSERT/DELETE keyword
     * (case-insensitive and locale-independent).
     * @param sql sql
     * @return index of the first keyword, or -1 when none is present
     */
    private int indexOfSqlStart(String sql) {
        int first = -1;
        for (String keyword : SQL_START_KEYWORDS) {
            int idx = StringUtils.indexOfIgnoreCase(sql, keyword);
            if (idx >= 0 && (first < 0 || idx < first)) {
                first = idx;
            }
        }
        return first;
    }

    /**
     * Checks whether the mapper interface or the mapper method carries {@link OptIgnore}.
     * <p>
     * A namespace that is not a loadable class (e.g. an XML-only mapper) cannot carry the annotation,
     * so it is treated as not ignored.
     * @param methodId mapped statement id ({@code namespace.method})
     * @return true when the statement must not be recorded
     * @throws Exception when the method-level ignore lookup fails
     */
    private boolean isIgnoreMethod(String methodId) throws Exception {
        int lastIndex = methodId.lastIndexOf(BaseConstants.SPOT);
        if (lastIndex < 0) {
            return false;
        }
        String mapperClassName = methodId.substring(0, lastIndex);
        String mapperMethodName = methodId.substring(lastIndex + 1);
        Class<?> mapperClass;
        try {
            mapperClass = Class.forName(mapperClassName);
        } catch (ClassNotFoundException e) {
            log.debug("未找到 mapper 类，按不忽略处理：{}", mapperClassName);
            return false;
        }
        // Annotation on the mapper interface skips every statement of that mapper
        if (mapperClass.getAnnotation(OptIgnore.class) != null) {
            return true;
        }
        // Otherwise check the individual mapper method
        return OptLogCache.ignoreMethod(mapperClass, mapperMethodName);
    }

    /**
     * Resolves the MyBatis-Plus table metadata and recorded-column configuration for the SQL.
     * <p>
     * The first table name found by {@link TableNameParser} is used.
     * @param sql sql
     * @return table info and mapping configuration, or {@code null} when the table is unknown or not recorded
     */
    private ImmutablePair<TableInfo, OptMapperTable> entityTable(String sql) {
        TableNameParser parser = new TableNameParser(sql);
        Collection<String> tables = parser.tables();
        if (CollectionUtils.isEmpty(tables)) {
            log.warn("没有获取到表名称！！");
            return null;
        }
        String tableName = tables.iterator().next();
        TableInfo tableInfo = TableInfoHelper.getTableInfo(tableName);
        if (tableInfo == null) {
            log.warn("实体对象表为空！");
            return null;
        }
        Class<?> entityType = tableInfo.getEntityType();
        OptMapperTable mapperTable = OptLogCache.getCacheTableInfo(entityType, properties);
        if (!mapperTable.isRecord()) {
            if (!mapperTable.isLogTable()) {
                log.warn("没有需要记录的字段！");
            }
            return null;
        }
        return ImmutablePair.of(tableInfo, mapperTable);
    }

    /**
     * Extracts the entity parameter, which must be a {@link BaseEntity} or a {@link List}.
     * <p>
     * For a parameter map, the first {@link BaseEntity} value is preferred, then the first {@link List}.
     * MyBatis registers the same object under several keys (e.g. {@code et} and {@code param1}),
     * and the map also holds non-entity values such as wrappers.
     * @param bound bound
     * @return the entity or list, or {@code null} when unsupported
     */
    private Object sqlParams(BoundSql bound) {
        Object paramObject = bound.getParameterObject();
        Object params = null;
        if (paramObject instanceof Map) {
            Map<String, Object> map = PojoUtil.cast(paramObject);
            params = pickEntityParam(map.values());
        } else if (paramObject instanceof BaseEntity) {
            params = paramObject;
        }
        if (!(params instanceof List) && !(params instanceof BaseEntity)) {
            log.warn("非实体对象类型参数不支持！");
            return null;
        }
        return params;
    }

    /**
     * Picks the first {@link BaseEntity} among the values, falling back to the first {@link List}.
     * @param values parameter map values
     * @return chosen value, or {@code null} when none qualifies
     */
    private Object pickEntityParam(Collection<Object> values) {
        Object firstList = null;
        for (Object value : values) {
            if (value instanceof BaseEntity) {
                return value;
            }
            if (firstList == null && value instanceof List) {
                firstList = value;
            }
        }
        return firstList;
    }

    /**
     * Builds the new-value maps from the parameter; they may lack a primary key.
     * <p>
     * This covers:
     * <ul>
     *   <li>a single update, e.g. {@code where for_key = ? and stat = ?};</li>
     *   <li>batch updates;</li>
     *   <li>single or batch inserts.</li>
     * </ul>
     * @param object entity or list of entities
     * @param mapperTable entity mapping configuration
     * @param mappings parameter mappings of the statement
     * @return one map per entity
     */
    private List<Map<String, Object>> newValueMap(Object object, OptMapperTable mapperTable, List<ParameterMapping> mappings) {
        List<Map<String, Object>> list = new ArrayList<>();
        if (object instanceof List) {
            List<?> data = PojoUtil.cast(object);
            data.forEach(d -> list.add(valueMap(d, mapperTable, mappings, null)));
        } else {
            list.add(valueMap(object, mapperTable, mappings, null));
        }
        return list;
    }

    /**
     * Reads the recorded properties referenced by the parameter mappings from a bean.
     * <p>
     * The foreign key and the id are always added when configured.
     * @param bean entity
     * @param mapperTable mapping configuration
     * @param mappings parameter mappings whose property names select the fields
     * @param foreignIdName foreign key property name, may be blank
     * @return ordered property -> value map; partially filled if reflection fails
     */
    private Map<String, Object> valueMap(Object bean, OptMapperTable mapperTable, List<ParameterMapping> mappings, String foreignIdName) {
        Map<String, OptMapperColumn> propertyMap = mapperTable.getPropertyMap();
        Map<String, Object> valueMap = new LinkedHashMap<>();
        try {
            for (ParameterMapping mapping : mappings) {
                String propertyName = mapping.getProperty();
                // Strip the parameter alias prefix, e.g. "et.name" -> "name"
                if (StringUtils.contains(propertyName, BaseConstants.SPOT)) {
                    propertyName = StringUtils.substringAfter(propertyName, BaseConstants.SPOT);
                }
                if (propertyMap.containsKey(propertyName)) {
                    valueMap.put(propertyName, getProperty(bean, propertyName));
                }
            }
            if (StringUtils.isNotBlank(foreignIdName) && !valueMap.containsKey(foreignIdName)) {
                valueMap.put(foreignIdName, getProperty(bean, foreignIdName));
            }
            String idName = mapperTable.getIdPropertyName();
            if (StringUtils.isNotBlank(idName) && !valueMap.containsKey(idName)) {
                valueMap.put(idName, getProperty(bean, idName));
            }
        } catch (Exception e) {
            log.error("反射获取值错误！", e);
        }
        return valueMap;
    }

    /**
     * Reads a field value reflectively.
     * @param bean entity
     * @param propertyName property name
     * @return value, or {@code null} when the field does not exist
     * @throws IllegalAccessException when the field cannot be read
     */
    private Object getProperty(Object bean, String propertyName) throws IllegalAccessException {
        Field field = PojoUtil.fieldByName(bean.getClass(), propertyName);
        if (field == null) {
            return null;
        }
        field.setAccessible(true);
        return field.get(bean);
    }

    /**
     * Unwraps nested JDK proxies (plugins) down to the real target object.
     * @param target possibly proxied object
     * @param <T> expected target type
     * @return innermost target
     */
    private <T> T getTarget(Object target) {
        if (Proxy.isProxyClass(target.getClass())) {
            MetaObject metaObject = SystemMetaObject.forObject(target);
            return getTarget(metaObject.getValue(OptLogConstant.H_TARGET));
        }
        return PojoUtil.cast(target);
    }
}
