package org.iartisan.runtime.jdbc;

import org.apache.ibatis.executor.ErrorContext;
import org.apache.ibatis.executor.resultset.DefaultResultSetHandler;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.TypeHandler;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.iartisan.runtime.bean.Page;
import org.iartisan.runtime.jdbc.annotations.Pagination;
import org.iartisan.runtime.jdbc.dialects.MySQLDialect;
import org.iartisan.runtime.utils.CollectionUtil;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.*;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/**
 * MyBatis interceptor that applies MySQL pagination to mapper methods annotated with
 * {@link Pagination}.
 * <p>
 * The interceptor hooks {@code StatementHandler.prepare(Connection, Integer)}. For a SELECT statement
 * whose mapper method carries {@link Pagination} it:
 * <ol>
 *     <li>locates the {@link Page} argument (either the sole parameter, or the map entry named
 *     {@code "page"}, i.e. a parameter declared with {@code @Param("page")});</li>
 *     <li>runs {@code SELECT COUNT(1) FROM ( originalSql ) TOTAL} and stores the result via
 *     {@code Page.setTotalRecords};</li>
 *     <li>replaces the statement SQL with the one returned by
 *     {@code MySQLDialect.buildPaginationSQL} and resets the handler's {@link RowBounds} to the
 *     "no offset / no limit" values so that MyBatis does not paginate in memory.</li>
 * </ol>
 * Only MySQL is supported at the moment.
 *
 * @author King
 * @since 2017/6/19
 */
@Deprecated
@Intercepts({
        @Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class, Integer.class})
})
public class PaginationInterceptor implements Interceptor {

    /** Template used to wrap the original SQL into a count query. */
    private static final String SQL_BASE_COUNT = "SELECT COUNT(1) FROM ( %s ) TOTAL";

    /** Key under which the {@link Page} argument is expected when the parameter object is a map. */
    private static final String PAGE_PARAM_NAME = "page";

    /** Shared so that the reflector cache it holds is reused across invocations. */
    private static final DefaultReflectorFactory REFLECTOR_FACTORY = new DefaultReflectorFactory();

    /**
     * Handle on {@code BoundSql.additionalParameters}, resolved once; {@code null} when the field
     * cannot be located or made accessible.
     */
    private static final Field ADDITIONAL_PARAMETERS_FIELD = findAdditionalParametersField();

    // Only MySQL pagination is supported for now.
    private final MySQLDialect mySQLDialect = MySQLDialect.newInstance();

    /**
     * Rewrites the SQL of paginated SELECT statements before the statement is prepared.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return whatever the intercepted call returns
     * @throws IllegalArgumentException if the mapper method is marked with {@link Pagination} but no
     *                                  {@link Page} argument can be found
     * @throws Throwable                anything raised by the count query or the intercepted call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // When several plugins are registered the target may be a proxy created by another plugin,
        // so peel those off before looking for the routing handler.
        Object target = unwrapProxies(invocation.getTarget());
        if (!(target instanceof RoutingStatementHandler)) {
            return invocation.proceed();
        }
        RoutingStatementHandler routingStatementHandler = (RoutingStatementHandler) target;
        MetaObject metaObject = newMetaObject(routingStatementHandler);
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        if (mappedStatement.getSqlCommandType() != SqlCommandType.SELECT
                || !isPaginated(mappedStatement.getId())) {
            return invocation.proceed();
        }

        Connection connection = (Connection) invocation.getArgs()[0];
        Page page = resolvePage(routingStatementHandler.getParameterHandler().getParameterObject());
        if (page == null) {
            // A paginated method without a Page argument is a programming error.
            throw new IllegalArgumentException("分页方法中缺少：org.iartisan.runtime.bean.Page 对象");
        }

        BoundSql boundSql = routingStatementHandler.getBoundSql();
        String querySQL = mySQLDialect.buildPaginationSQL(boundSql.getSql(),
                (page.getCurrPage() - 1) * page.getPageSize(), page.getPageSize());
        // The count must run against the original SQL, i.e. before the SQL is replaced below.
        countTotal(mappedStatement, boundSql, page, connection);
        metaObject.setValue("delegate.boundSql.sql", querySQL);
        // Disable MyBatis' in-memory pagination.
        metaObject.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET);
        metaObject.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);
        return invocation.proceed();
    }

    /**
     * Follows JDK proxies down to the object they wrap, as long as the proxy's invocation handler
     * exposes a {@code target} property (which is how MyBatis' {@link Plugin} stores the wrapped object).
     */
    private static Object unwrapProxies(Object candidate) {
        Object current = candidate;
        while (current != null && java.lang.reflect.Proxy.isProxyClass(current.getClass())) {
            MetaObject handlerMeta = newMetaObject(java.lang.reflect.Proxy.getInvocationHandler(current));
            if (!handlerMeta.hasGetter("target")) {
                break;
            }
            current = handlerMeta.getValue("target");
        }
        return current;
    }

    /** Creates a {@link MetaObject} using the default factories and the shared reflector factory. */
    private static MetaObject newMetaObject(Object object) {
        return MetaObject.forObject(object, SystemMetaObject.DEFAULT_OBJECT_FACTORY,
                SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, REFLECTOR_FACTORY);
    }

    /**
     * Extracts the {@link Page} from the statement's parameter object.
     *
     * @return the page, or {@code null} when the parameter object neither is a {@link Page} nor is a
     *         map holding one under {@code "page"}
     */
    private static Page resolvePage(Object parameterObject) {
        if (parameterObject instanceof Page) {
            return (Page) parameterObject;
        }
        if (parameterObject instanceof Map) {
            Map<?, ?> parameters = (Map<?, ?>) parameterObject;
            // MyBatis' parameter maps throw on get() for unknown keys, so probe with containsKey first.
            if (parameters.containsKey(PAGE_PARAM_NAME)) {
                Object candidate = parameters.get(PAGE_PARAM_NAME);
                if (candidate instanceof Page) {
                    return (Page) candidate;
                }
            }
        }
        return null;
    }

    /**
     * Decides whether the statement with the given id belongs to a mapper method annotated with
     * {@link Pagination}. The id is split at its last dot into a class name and a method name.
     * Ids without a dot, or whose prefix does not name a loadable class, are treated as not paginated.
     */
    private boolean isPaginated(String statementId) {
        int separator = statementId.lastIndexOf('.');
        if (separator < 0) {
            return false;
        }
        Class<?> mapperClass;
        try {
            mapperClass = Class.forName(statementId.substring(0, separator));
        } catch (ClassNotFoundException ignored) {
            // No class behind this namespace, hence no annotated method either.
            return false;
        }
        return needPagination(mapperClass, statementId.substring(separator + 1));
    }

    /**
     * Runs the count query for the original SQL and stores the result on the page.
     *
     * @param mappedStatement statement whose configuration drives parameter binding
     * @param boundSql        original (un-paginated) SQL and its parameter mappings
     * @param page            receives the total via {@code setTotalRecords}; 0 if the query returns no row
     * @param connection      connection on which the count query is executed; it is not closed here
     */
    private void countTotal(MappedStatement mappedStatement, BoundSql boundSql, Page page, Connection connection) throws SQLException, IllegalAccessException {
        String countSql = String.format(SQL_BASE_COUNT, boundSql.getSql());
        int total = 0;
        // Both the statement and the result set are closed, even when binding or execution fails.
        try (PreparedStatement statement = connection.prepareStatement(countSql)) {
            setParameters(mappedStatement, statement, boundSql);
            try (ResultSet resultSet = statement.executeQuery()) {
                if (resultSet.next()) {
                    total = resultSet.getInt(1);
                }
            }
        }
        page.setTotalRecords(total);
    }

    /** Looks up and opens {@code BoundSql.additionalParameters}; returns {@code null} if that fails. */
    private static Field findAdditionalParametersField() {
        try {
            Field field = BoundSql.class.getDeclaredField("additionalParameters");
            field.setAccessible(true);
            return field;
        } catch (NoSuchFieldException | SecurityException ignored) {
            // Without the field, binding relies solely on BoundSql's public accessors.
            return null;
        }
    }

    /**
     * Binds the values of all non-OUT parameter mappings of {@code boundSql} to {@code ps}.
     * <p>
     * For each mapping the value is taken, in this order, from: the bound SQL's additional parameters;
     * {@code null} if there is no parameter object; the parameter object itself if a type handler is
     * registered for its class; otherwise the property of the parameter object, falling back to a
     * direct lookup in the additional-parameters map when that property is {@code null}.
     *
     * @param mappedStatement statement providing the configuration and type handler registry
     * @param ps              prepared statement to bind to; parameter indexes start at 1
     * @param boundSql        source of the parameter object and parameter mappings
     * @throws SQLException           if a type handler fails to set a parameter
     * @throws IllegalAccessException if the additional-parameters field cannot be read
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void setParameters(MappedStatement mappedStatement, PreparedStatement ps, BoundSql boundSql) throws SQLException, IllegalAccessException {
        final Object parameterObject = boundSql.getParameterObject();
        final Configuration configuration = mappedStatement.getConfiguration();
        final TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
        // Dynamic parameters are read reflectively from the bound SQL.
        Map<String, Object> additionalParameters = null;
        if (ADDITIONAL_PARAMETERS_FIELD != null) {
            additionalParameters = (Map<String, Object>) ADDITIONAL_PARAMETERS_FIELD.get(boundSql);
        }
        ErrorContext.instance().activity("setting parameters").object(mappedStatement.getParameterMap().getId());
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        if (parameterMappings == null) {
            return;
        }
        for (int i = 0; i < parameterMappings.size(); i++) {
            ParameterMapping parameterMapping = parameterMappings.get(i);
            if (parameterMapping.getMode() == ParameterMode.OUT) {
                continue;
            }
            Object value;
            String propertyName = parameterMapping.getProperty();
            if (boundSql.hasAdditionalParameter(propertyName)) {
                // Additional parameters take precedence over the parameter object.
                value = boundSql.getAdditionalParameter(propertyName);
            } else if (parameterObject == null) {
                value = null;
            } else if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                value = parameterObject;
            } else {
                value = configuration.newMetaObject(parameterObject).getValue(propertyName);
                if (value == null && CollectionUtil.isNotEmpty(additionalParameters)) {
                    value = additionalParameters.get(propertyName);
                }
            }
            TypeHandler typeHandler = parameterMapping.getTypeHandler();
            JdbcType jdbcType = parameterMapping.getJdbcType();
            if (value == null && jdbcType == null) {
                jdbcType = configuration.getJdbcTypeForNull();
            }
            typeHandler.setParameter(ps, i + 1, value, jdbcType);
        }
    }

    /**
     * Tells whether {@code clazz} has a public method called {@code methodName} that is annotated
     * with {@link Pagination}. All methods with that name are inspected, so the answer does not
     * depend on the order in which overloads are returned by reflection.
     *
     * @param clazz      mapper type to inspect
     * @param methodName simple name of the mapper method
     * @return {@code true} if at least one matching method carries {@link Pagination}
     */
    public boolean needPagination(Class<?> clazz, String methodName) {
        for (Method method : clazz.getMethods()) {
            if (method.getName().equals(methodName) && method.isAnnotationPresent(Pagination.class)) {
                return true;
            }
        }
        return false;
    }

    /** Wraps {@code target} with this interceptor through {@link Plugin#wrap(Object, Interceptor)}. */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /** No configurable properties; the argument is ignored. */
    @Override
    public void setProperties(Properties properties) {
        // Intentionally empty.
    }
}
