package top.doudou.mybatis.plus.utils;

import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

/**
 * Utility methods for MyBatis interceptors: unwrapping nested plugin proxies, rewriting the
 * SQL held by a {@link StatementHandler}, deriving a row-count query from a select statement,
 * and running an extra query on the intercepted connection.
 *
 * @author 傻男人<244191347@qq.com>
 * @Date 2021-10-22 10:33
 * @Version V1.0
 */
public class MyBatisUtils {

    /** Keyword located by {@link #getFromIndex(int, String)} (matched case-insensitively). */
    private static final String FROM_KEYWORD = "from";

    /**
     * Returns the innermost target object, unwrapping nested proxies.
     *
     * <p>As long as the current object exposes a property named {@code h} (the invocation
     * handler field of a JDK dynamic proxy), it is replaced by the value of {@code h.target}.
     * This assumes the handler has a {@code target} property, as MyBatis' plugin handler
     * does — NOTE(review): confirm before using it with other kinds of proxies.
     *
     * @param target the possibly proxied object
     * @return the first object in the chain that has no {@code h} property
     */
    public static Object getNoProxyTarget(Object target){
        // Bind a reflective view onto the object and peel off one proxy layer per iteration.
        MetaObject metaObject = SystemMetaObject.forObject(target);
        while (metaObject.hasGetter("h")) {
            target = metaObject.getValue("h.target");
            metaObject = SystemMetaObject.forObject(target);
        }
        return target;
    }


    /**
     * Replaces the SQL that the given {@link StatementHandler} will execute.
     *
     * <p>The {@code sql} property of the handler's {@link BoundSql} is overwritten
     * reflectively through MyBatis' {@link MetaObject} utility.
     *
     * @param statementHandler the handler whose bound SQL is rewritten
     * @param sql              the new SQL text
     */
    public static void changeExecuteSql(StatementHandler statementHandler, String sql){
        BoundSql boundSql = statementHandler.getBoundSql();
        MetaObject metaObject = SystemMetaObject.forObject(boundSql);
        metaObject.setValue("sql", sql);
    }

    /**
     * Finds the position of the main (top-level) {@code FROM} keyword of a statement.
     *
     * <p>The keyword is matched case-insensitively and only as a whole word, so identifiers
     * such as {@code from_date} or {@code createdFrom} are not mistaken for it. A match is
     * accepted only when the parentheses before it (counted from the start of {@code sql},
     * not from {@code beginIndex}) are balanced, which skips the {@code FROM} of sub-queries
     * in the select list. Parentheses inside string literals are counted like any other.
     *
     * @param beginIndex index at which the search for the keyword starts; negative values
     *                   are treated as 0
     * @param sql        the SQL statement to scan; {@code null} yields -1
     * @return -1 if there is no top-level {@code FROM}, otherwise the index of its first
     *         character
     */
    public static int getFromIndex(int beginIndex, String sql){
        if (sql == null) {
            return -1;
        }
        int start = Math.max(beginIndex, 0);
        // Parenthesis nesting depth of the characters before the current position.
        int depth = 0;
        for (int i = 0; i < sql.length(); i++) {
            char c = sql.charAt(i);
            if (c == '(') {
                depth++;
            } else if (c == ')') {
                depth--;
            } else if (i >= start && depth == 0 && isKeywordAt(sql, i, FROM_KEYWORD)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Tells whether {@code keyword} occurs at {@code index} in {@code sql} as a whole word,
     * ignoring case.
     */
    private static boolean isKeywordAt(String sql, int index, String keyword) {
        if (!sql.regionMatches(true, index, keyword, 0, keyword.length())) {
            return false;
        }
        int end = index + keyword.length();
        boolean boundaryBefore = index == 0 || !isIdentifierChar(sql.charAt(index - 1));
        boolean boundaryAfter = end == sql.length() || !isIdentifierChar(sql.charAt(end));
        return boundaryBefore && boundaryAfter;
    }

    /** Returns whether {@code c} can be part of an unquoted SQL identifier. */
    private static boolean isIdentifierChar(char c) {
        return Character.isLetterOrDigit(c) || c == '_' || c == '$';
    }

    /**
     * Builds a query that counts the rows of {@code sql}.
     *
     * <p>The select list (everything before the top-level {@code FROM}) is replaced with
     * {@code select count(*) as total}, and the rest of the statement is kept verbatim.
     * Consequently the count is wrong for statements using {@code DISTINCT} or
     * {@code GROUP BY}, and a trailing {@code ORDER BY} / {@code LIMIT} is carried over
     * into the count query.
     *
     * @param sql the original select statement
     * @return the count query
     * @throws IllegalArgumentException if {@code sql} has no top-level {@code FROM} clause
     */
    public static String getSqlTotal(String sql){
        int index = getFromIndex(0, sql);
        if (index < 0) {
            throw new IllegalArgumentException("No top-level FROM clause found in sql: " + sql);
        }
        return "select count(*) as total " + sql.substring(index);
    }

    /**
     * Runs {@code sql} as a query on the connection of the intercepted call.
     *
     * <p>The connection is read from {@code invocation.getArgs()[0]}, so the intercepted
     * method must take the {@link Connection} as its first argument (e.g.
     * {@code StatementHandler#prepare}); it is not closed here. The statement is
     * parameterized by {@code statementHandler}, so {@code sql} must use the same
     * placeholders as the handler's own SQL.
     *
     * <p>The returned {@link ResultSet} is open and must be closed by the caller. Its
     * {@link PreparedStatement} is marked {@code closeOnCompletion()}, so it is closed
     * together with the result set.
     *
     * @param invocation       the intercepted invocation supplying the connection
     * @param statementHandler the handler used to bind parameters to the statement
     * @param sql              the query to execute
     * @return the open result set of the query
     * @throws SQLException if preparing, parameterizing or executing the query fails
     */
    public static ResultSet executeSql(Invocation invocation, StatementHandler statementHandler, String sql) throws SQLException{
        Connection connection = (Connection) invocation.getArgs()[0];
        PreparedStatement ps = connection.prepareStatement(sql);
        try {
            statementHandler.parameterize(ps);
            ResultSet resultSet = ps.executeQuery();
            // The result set must stay open for the caller; let the statement close with it
            // instead of closing it here (closing it would also close the result set).
            ps.closeOnCompletion();
            return resultSet;
        } catch (Throwable failure) {
            // Nothing is handed to the caller on failure, so release the statement here
            // without masking the original exception.
            try {
                ps.close();
            } catch (SQLException closeFailure) {
                failure.addSuppressed(closeFailure);
            }
            throw failure;
        }
    }

}