package top.doudou.mybatis.plus;


import com.baomidou.mybatisplus.annotation.TableField;
import com.google.common.collect.Lists;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.binding.MapperProxyFactory;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.stereotype.Component;
import top.doudou.core.exception.CustomException;
import top.doudou.core.exception.ExceptionUtils;
import top.doudou.core.properties.CustomLogProperties;
import top.doudou.core.random.RandomUtils;
import top.doudou.core.system.SystemMonitorUtil;
import top.doudou.core.util.FieldUtils;
import top.doudou.core.util.HumpUtil;
import top.doudou.core.util.file.WriteLogToFile;
import top.doudou.mybatis.plus.entity.SqlLogDto;
import top.doudou.mybatis.plus.entity.SqlResultDto;
import top.doudou.mybatis.plus.utils.MyBatisUtils;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Modifier;
import java.sql.Statement;
import java.text.DateFormat;
import java.util.*;
import java.util.regex.Matcher;

import static top.doudou.base.util.ServletUtils.getRequest;
import static top.doudou.core.constant.CommonConstant.REQUEST_UUID;

/**
 * MyBatis plugin that logs executed SQL statements (with their parameters inlined) and,
 * optionally, a preview of their result sets.
 * <p>
 * MyBatis has four core objects that a plugin can intercept:
 * <ol>
 *   <li>ParameterHandler: handles the SQL parameter object;</li>
 *   <li>ResultSetHandler: handles the SQL result sets;</li>
 *   <li>StatementHandler: the database statement object that executes the SQL;</li>
 *   <li>Executor: the MyBatis executor that performs insert/update/delete/select.</li>
 * </ol>
 * This interceptor hooks {@code Executor.update}, both {@code Executor.query} overloads and
 * {@code ResultSetHandler.handleResultSets}. Only query and update statements are formatted;
 * other statements need additional {@code @Signature}s.
 * <p>
 * While a statement runs, its details are collected through the static {@code SqlLogDto.setCur*}
 * methods; {@code printSql()} then prints and/or writes them according to
 * {@link CustomLogProperties} and finally calls {@code SqlLogDto.removeCurrent()}.
 *
 * @author 傻男人<244191347@qq.com>
 * @date 2020-07-27
 */
@Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class,
                ResultHandler.class, CacheKey.class, BoundSql.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = ResultSetHandler.class, method = "handleResultSets", args = {Statement.class})
})
@Component
@Slf4j
@ConditionalOnClass(MapperProxyFactory.class)
@EnableConfigurationProperties(CustomLogProperties.class)
public class MybatisInterceptor implements Interceptor {

    // NOTE(review): injected but not referenced anywhere in this class.
    @Autowired
    private java.util.concurrent.Executor executorService;

    /** Flags controlling whether SQL and results are printed and/or written to a log file. */
    @Autowired
    private CustomLogProperties customLogProperties;

    /**
     * Routes an intercepted call: result-set handling goes to {@link #showResult}, executor
     * {@code update}/{@code query} calls go to {@link #showSql}; anything else proceeds unchanged.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object target = MyBatisUtils.getNoProxyTarget(invocation.getTarget());
        if (target instanceof Statement) {
            return showResult(invocation, target);
        }
        if (target instanceof MappedStatement) {
            return showSql(invocation, target);
        }
        // Invocation#getTarget() is the intercepted Executor / ResultSetHandler itself; according
        // to the @Signature declarations above, the Statement / MappedStatement is the first argument.
        Object[] args = invocation.getArgs();
        Object firstArg = (args != null && args.length > 0) ? args[0] : null;
        if (firstArg instanceof Statement) {
            return showResult(invocation, firstArg);
        }
        if (firstArg instanceof MappedStatement) {
            return showSql(invocation, firstArg);
        }
        return invocation.proceed();
    }

    /**
     * Records the rendered SQL, its id, execution time and row count (or single result) for an
     * executor {@code update}/{@code query} call, then prints it via {@link #printSql()}.
     *
     * @param invocation the intercepted executor call
     * @param statement  the {@link MappedStatement} being executed
     * @return whatever the intercepted call returned
     */
    private Object showSql(Invocation invocation, Object statement) throws InvocationTargetException, IllegalAccessException {
        try {
            MappedStatement mappedStatement = (MappedStatement) statement;
            Object[] args = invocation.getArgs();
            Object parameter = args.length > 1 ? args[1] : null;
            // The 6-argument query overload already carries the BoundSql the executor will use;
            // otherwise build one from the parameter object.
            Object lastArg = args[args.length - 1];
            BoundSql boundSql = lastArg instanceof BoundSql
                    ? (BoundSql) lastArg
                    : mappedStatement.getBoundSql(parameter);
            Configuration configuration = mappedStatement.getConfiguration();
            SqlLogDto.setCurSentence(getSqLSentence(configuration, boundSql));
            SqlLogDto.setCurSqlId(mappedStatement.getId());
        } catch (Exception e) {
            // Best effort: failing to render the SQL for the log must never break the statement.
            log.warn("failed to render sql for logging", e);
        }
        long start = System.currentTimeMillis();
        Object result;
        try {
            result = invocation.proceed();
            SqlLogDto.setCurCost(System.currentTimeMillis() - start);
            getUUIDFromRequest();
            if (result instanceof List) {
                SqlLogDto.setCurRows(((List<?>) result).size());
            } else {
                SqlLogDto.setCurResult(result);
            }
        } catch (Throwable t) {
            // printSql() (which normally clears the per-statement state) is skipped on failure,
            // so clear it here to keep stale data from being attached to the next statement.
            SqlLogDto.removeCurrent();
            throw t;
        }
        printSql();
        return result;
    }

    /**
     * Stores the current request id in {@link SqlLogDto}: the {@code REQUEST_UUID} attribute of
     * the current request ("" when absent), or a random {@code "temp..."} id when the request
     * cannot be accessed (e.g. the statement runs outside of a request).
     */
    private void getUUIDFromRequest() {
        String uuid;
        try {
            uuid = Optional.ofNullable(getRequest().getAttribute(REQUEST_UUID)).map(Object::toString).orElse("");
        } catch (Exception e) {
            // Deliberate fallback: no request is bound to this thread.
            uuid = "temp" + RandomUtils.randomUUID(12);
        }
        SqlLogDto.setCurRequestId(uuid);
    }

    /**
     * Renders one bound parameter value as it should appear in the logged SQL.
     * Strings and dates are single-quoted, {@code null} becomes an empty string and any other
     * value uses {@code toString()}.
     *
     * @param obj the parameter value, may be {@code null}
     * @return the textual representation to inline into the SQL
     */
    private String getParameterValue(Object obj) {
        if (obj instanceof String) {
            return "'" + obj + "'";
        }
        if (obj instanceof Date) {
            // DateFormat is not thread-safe, so a fresh instance is created per call.
            DateFormat formatter = DateFormat.getDateTimeInstance(DateFormat.MEDIUM, DateFormat.MEDIUM, Locale.CHINA);
            // Format the parameter itself (previously the current time was formatted by mistake).
            return "'" + formatter.format((Date) obj) + "'";
        }
        return obj != null ? obj.toString() : "";
    }

    /**
     * Builds the executed SQL with every {@code ?} placeholder replaced by its parameter value.
     * Whitespace runs are collapsed to single spaces. Placeholders are filled left to right in a
     * single pass, so a {@code ?} contained in an inlined value is never mistaken for the next
     * placeholder.
     *
     * @param configuration the MyBatis configuration (type handlers, meta objects)
     * @param boundSql      the SQL and its parameter mappings
     * @return the SQL text for logging
     */
    private String getSqLSentence(Configuration configuration, BoundSql boundSql) {
        Object parameterObject = boundSql.getParameterObject();
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        String sql = boundSql.getSql().replaceAll("[\\s]+", " ");
        if (CollectionUtils.isEmpty(parameterMappings) || parameterObject == null) {
            return sql;
        }
        TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
        // A parameter with its own type handler (String, Long, ...) is bound as a whole;
        // otherwise its properties are read through a MetaObject.
        MetaObject metaObject = typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())
                ? null
                : configuration.newMetaObject(parameterObject);
        StringBuilder rendered = new StringBuilder(sql.length() + parameterMappings.size() * 8);
        int cursor = 0;
        for (ParameterMapping parameterMapping : parameterMappings) {
            int placeholder = sql.indexOf('?', cursor);
            if (placeholder < 0) {
                break;
            }
            rendered.append(sql, cursor, placeholder)
                    .append(renderParameter(parameterMapping.getProperty(), parameterObject, metaObject, boundSql));
            cursor = placeholder + 1;
        }
        rendered.append(sql, cursor, sql.length());
        return rendered.toString();
    }

    /**
     * Resolves the value bound to one parameter mapping, using the same lookup order as MyBatis'
     * {@code DefaultParameterHandler}. Additional parameters (foreach / bind) come first. Next is
     * the parameter object itself, when it has a type handler (signalled by a {@code null}
     * {@code metaObject}). Last is a readable property of the parameter object. Returns
     * {@code "缺失"} (missing) when none applies.
     */
    private String renderParameter(String propertyName, Object parameterObject, MetaObject metaObject, BoundSql boundSql) {
        if (boundSql.hasAdditionalParameter(propertyName)) {
            return getParameterValue(boundSql.getAdditionalParameter(propertyName));
        }
        if (metaObject == null) {
            return getParameterValue(parameterObject);
        }
        if (metaObject.hasGetter(propertyName)) {
            return getParameterValue(metaObject.getValue(propertyName));
        }
        return "缺失";
    }

    /**
     * Wraps only the component types named in {@code @Intercepts}; everything else is returned as is.
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof ResultSetHandler || target instanceof Executor) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /** This interceptor takes no plugin properties. */
    @Override
    public void setProperties(Properties properties) {
    }

    /**
     * Lets MyBatis map the result set, then attaches a header/row preview of the mapped objects
     * to the current {@link SqlLogDto}. Failures while building the preview are logged and never
     * affect the returned results.
     *
     * @param invocation   the intercepted {@code handleResultSets} call
     * @param statementObj the JDBC statement (not used)
     * @return the mapped results, unchanged
     */
    @SuppressWarnings("unchecked")
    private Object showResult(Invocation invocation, Object statementObj) throws InvocationTargetException, IllegalAccessException {
        List<Object> results = (List<Object>) invocation.proceed();
        try {
            SqlLogDto current = SqlLogDto.getCurrent();
            // Without an in-flight statement (set up by showSql) there is nothing to attach the
            // preview to, and state created here might never be cleared.
            if (current != null && CollectionUtils.isNotEmpty(results)) {
                SqlLogDto.setCurFormatSqlRetDto(buildResultDto(results, current));
            }
        } catch (CustomException customException) {
            log.error(customException.getMessage());
        } catch (Exception e) {
            log.error(ExceptionUtils.toString(e));
        }
        // Returned outside of any finally block so that Errors are not swallowed.
        return results;
    }

    /**
     * Builds the result preview. The columns are the fields of the first non-null row's class.
     * When the SQL names its select list explicitly, only fields matching it are kept; otherwise
     * every field is kept.
     *
     * @throws CustomException when every row is {@code null}
     */
    private SqlResultDto buildResultDto(List<Object> results, SqlLogDto current) {
        Class<?> cls = getRowClass(results);
        List<Field> fields = FieldUtils.getAllNoStaticFinalTransientFields(cls);
        String queryField = getQueryField(current.getSentence());
        List<String> queryFieldList = Lists.newArrayList();
        if (!(StringUtils.isBlank(queryField) || "*".equals(queryField.trim()))) {
            for (String item : queryField.split(",")) {
                queryFieldList.add(item.trim());
            }
        }

        // Decide the columns (and header) once, from the field list.
        StringJoiner header = new StringJoiner(",  ", " Columns:  ", "");
        List<Field> columns = new ArrayList<>();
        for (Field field : fields) {
            if (Modifier.isStatic(field.getModifiers())) {
                continue;
            }
            if (!queryFieldList.isEmpty()) {
                String matched = matchField(field.getName(), queryFieldList);
                if (StringUtils.isEmpty(matched)) {
                    continue;
                }
                header.add(matched);
            } else {
                TableField annotation = field.getAnnotation(TableField.class);
                boolean hasColumnName = annotation != null && StringUtils.isNotEmpty(annotation.value());
                header.add(hasColumnName ? annotation.value() : field.getName());
            }
            // Every header column gets a value in each row (previously "select *" rows stayed empty).
            columns.add(field);
        }
        if (CollectionUtils.isEmpty(fields)) {
            // Scalar results (e.g. count(*)) have no fields: show the select list instead.
            header.add(queryField);
        }

        SqlResultDto sqlResultDto = new SqlResultDto();
        sqlResultDto.setHeader(header.toString());
        for (Object result : results) {
            if (null == result) {
                continue;
            }
            StringJoiner row = new StringJoiner(",  ", " Row:  ", "");
            for (Field column : columns) {
                row.add(String.valueOf(FieldUtils.getFieldValueNoException(column, result)));
            }
            if (results.size() == 1) {
                row.add(result.toString());
            }
            sqlResultDto.setColumnList(row.toString());
        }
        return sqlResultDto;
    }

    /**
     * Finds the select-list entry that names the given Java field, comparing case-insensitively.
     * When exactly one side is snake_case, it is converted with {@code HumpUtil.camelCaseName}
     * first.
     *
     * @param fieldName the Java field name
     * @param list      the trimmed entries of the SQL select list
     * @return the matching entry, or {@code null} when none matches
     */
    private String matchField(String fieldName, List<String> list) {
        if (CollectionUtils.isEmpty(list)) {
            return null;
        }
        boolean fieldSnake = fieldName.contains("_");
        for (String item : list) {
            boolean itemSnake = item.contains("_");
            String left = (fieldSnake && !itemSnake) ? HumpUtil.camelCaseName(fieldName) : fieldName;
            String right = (itemSnake && !fieldSnake) ? HumpUtil.camelCaseName(item) : item;
            if (StringUtils.equalsIgnoreCase(left, right)) {
                return item;
            }
        }
        return null;
    }

    /**
     * Extracts the lower-cased select list (the text between {@code select} and the first
     * {@code from} keyword) from a SQL sentence.
     *
     * @param sentence the SQL, may be blank
     * @return the select list, or {@code ""} when the sentence is blank or not a select
     */
    private String getQueryField(String sentence) {
        if (StringUtils.isBlank(sentence)) {
            return "";
        }
        String normalized = sentence.trim().toLowerCase(Locale.ROOT);
        if (!normalized.startsWith("select")) {
            return "";
        }
        // Split on the keyword only, not on identifiers that merely contain "from" (e.g. from_date).
        return normalized.split("\\bfrom\\b", 2)[0].replaceFirst("select", "");
    }

    /**
     * Returns the class of the first non-null row. The first row alone is not used because it
     * may be {@code null} (e.g. a single-column query whose first value is null).
     *
     * @throws CustomException when every row is {@code null}
     */
    private Class<?> getRowClass(List<Object> list) {
        for (Object obj : list) {
            if (null != obj) {
                return obj.getClass();
            }
        }
        throw new CustomException("所有属性均为空");
    }

    /**
     * Writes the current {@link SqlLogDto} to the SQL log file when file logging is enabled.
     * When printing is enabled, returns the formatted SQL block (id, sentence, cost, rows).
     * Non-select sentences of 800+ characters are replaced by a short notice.
     *
     * @return the formatted SQL lines, or {@code null} when there is nothing to print
     */
    private StringJoiner getFormatSql() {
        SqlLogDto sqlLogDto = SqlLogDto.getCurrent();
        if (null == sqlLogDto) {
            return null;
        }
        // The two flags are independent: file output no longer depends on printing being disabled.
        if (customLogProperties.isWriteSqlToFile()) {
            WriteLogToFile.asyncLogToFile(customLogProperties.getFilePath(customLogProperties.getSqlLogName()), sqlLogDto.toString());
        }
        if (!customLogProperties.isPrintSql()) {
            return null;
        }
        StringJoiner stringJoiner = new StringJoiner(SystemMonitorUtil.getLineBreak(), "", "");
        String requestId = sqlLogDto.getRequestId();
        requestId = StringUtils.isNotBlank(requestId) ? "[" + requestId + "] " : "";
        stringJoiner.add("--------------< sql执行的语句" + requestId + " >--------------")
                .add("===>  sql id                " + sqlLogDto.getSqlId());
        String sentence = sqlLogDto.getSentence();
        boolean isSelect = sentence != null && sentence.trim().regionMatches(true, 0, "select", 0, 6);
        if (StringUtils.isNotEmpty(sentence) && sentence.length() >= 800 && !isSelect) {
            stringJoiner.add("===>  sql sentence          sql sentence length greater than 800");
        } else {
            stringJoiner.add("===>  sql sentence          " + sentence);
        }
        stringJoiner.add("===>  sql cost(ms): " + sqlLogDto.getCost() + (null != sqlLogDto.getRows() ? "     rows: " + sqlLogDto.getRows() : ""));
        return stringJoiner;
    }

    /**
     * Formats the result preview of the current statement.
     *
     * @return {@code null} when no preview was captured; an empty joiner when result printing
     *         is disabled; otherwise the header and row lines
     */
    private StringJoiner getSqlResult() {
        SqlLogDto current = SqlLogDto.getCurrent();
        SqlResultDto sqlResultDto = null == current ? null : current.getFormatSqlRetDto();
        if (null == sqlResultDto) {
            return null;
        }
        StringJoiner stringJoiner = new StringJoiner(SystemMonitorUtil.getLineBreak(), "", "");
        if (customLogProperties.isPrintSqlResult()) {
            if (StringUtils.isNotBlank(sqlResultDto.getHeader())) {
                stringJoiner.add("===> " + sqlResultDto.getHeader());
            }
            List<String> columnList = sqlResultDto.getColumnList();
            if (CollectionUtils.isNotEmpty(columnList)) {
                columnList.forEach(item -> stringJoiner.add("===> " + item));
            }
        }
        return stringJoiner;
    }

    /**
     * Logs the SQL block (plus the result preview, if any) and always clears the per-statement
     * {@link SqlLogDto} state afterwards.
     */
    private void printSql() {
        try {
            StringJoiner formatSql = getFormatSql();
            if (null == formatSql) {
                return;
            }
            StringJoiner sqlResult = getSqlResult();
            if (null != sqlResult) {
                formatSql.merge(sqlResult);
            }
            log.info(formatSql.toString());
        } finally {
            SqlLogDto.removeCurrent();
        }
    }
}