package com.walker.jdbc.dao;

import com.walker.db.DatabaseType;
import com.walker.db.page.GenericPager;
import com.walker.db.page.ListPageContext;
import com.walker.db.page.PageSearch;
import com.walker.infrastructure.utils.StringUtils;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;

import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class SqlDaoSupport extends JdbcDaoImpl{

    @Override
    public void update(String sql){
        assert(StringUtils.isNotEmpty(sql));
        this.getJdbcTemplate().update(sql);
    }

    @Override
    public int update(String sql, Object[] args){
        assert(StringUtils.isNotEmpty(sql));
        return this.getJdbcTemplate().update(sql, args);
    }

    @Override
    public <T> List<T> sqlQuery(String sql, RowMapper<T> rowMapper){
        return sqlQuery(sql, null, rowMapper);
    }
    @Override
    public <T> List<T> sqlQuery(String sql, Object[] args, RowMapper<T> rowMapper){
        assert (StringUtils.isNotEmpty(sql));
        assert (rowMapper != null);
        return this.getJdbcTemplate().query(sql, rowMapper, args);
    }

    /**
     * 返回map集合记录，不再依赖hibernate
     * @param sql
     * @param args
     * @return
     */
    @Override
    public List<Map<String, Object>> sqlQueryListMap(String sql, Object[] args){
        return this.getJdbcTemplate().queryForList(sql, args);
    }

    /**
     * 通过动态泛型的形式，给定返回对象类型。这样就可以在每个DAO中随时使用。
     * @param sql
     * @param args
     * @param rowMapper
     * @return
     */
    @Override
    public <E> GenericPager<E> sqlGeneralQueryPager(String sql, Object[] args, RowMapper<E> rowMapper){
        // 2023-04-01 前后端界面使用:PageSearch分页参数，如果存在优先使用。
        PageSearch pageSearch = ListPageContext.getPageSearch();
        if(pageSearch != null){
            return sqlGeneralQueryPager(sql, args, rowMapper
                    , pageSearch.getPageIndex()
                    , pageSearch.getPageSize());
        } else {
            return sqlGeneralQueryPager(sql, args, rowMapper
                    , ListPageContext.getCurrentPageIndex()
                    , ListPageContext.getCurrentPageSize());
        }
//        return sqlGeneralQueryPager(sql, args, rowMapper
//                , ListPageContext.getCurrentPageIndex()
//                , ListPageContext.getCurrentPageSize());
    }

    @Override
    public <E> GenericPager<E> sqlGeneralQueryPager(String sql, Object[] args, RowMapper<E> rowMapper
            , int pageIndex){
        // 2023-04-01 前后端界面使用:PageSearch分页参数，如果存在优先使用。
        PageSearch pageSearch = ListPageContext.getPageSearch();
        if(pageSearch != null){
            return this.sqlGeneralQueryPager(sql, args, rowMapper, pageIndex, pageSearch.getPageSize());
        } else {
            return this.sqlGeneralQueryPager(sql, args, rowMapper, pageIndex, ListPageContext.getCurrentPageSize());
        }
//        return sqlGeneralQueryPager(sql, args, rowMapper, pageIndex, ListPageContext.getCurrentPageSize());
    }

    /**
     * 以动态给定的泛型类型来分页返回数据集合。这样就可以在每个DAO中随时使用。
     * @param sql
     * @param args
     * @param rowMapper
     * @param pageIndex 当前页
     * @param pageSize 分页大小
     * @return
     */
    @Override
    public <T> GenericPager<T> sqlGeneralQueryPager(String sql
            , Object[] args, RowMapper<T> rowMapper, int pageIndex, int pageSize){
//        Assert.hasText(sql);
//        Assert.notNull(rowMapper);
//		int count = jdbcTemplate.queryForInt(getHibernateCountQuery(sql, false), args);
        int count = 0;
        Integer countObj = this.getJdbcTemplate().queryForObject(getJdbcCountQuery(sql, this.getPaginationHelper().getType()), Integer.class, args);
        if(countObj != null){
            count = countObj.intValue();
        }
        GenericPager<T> pager = ListPageContext.createGenericPager(pageIndex, pageSize, count);
//        String querySql = getSqlPagingQuery(sql, (int)pager.getFirstRowIndexInPage(), pageSize);
        String querySql = this.getPaginationHelper().getSqlPagingQuery(sql, null);
        logger.debug("......jdbc分页sql: " + querySql);
        /* 设置分页参数 */
        // 对于sqlserver，分页参数设置修改：第一个和最后一个
        List<T> datas = this.getJdbcTemplate().query(querySql, rowMapper, getSqlPageArgs(args, (int)pager.getFirstRowIndexInPage(), pageSize, this.getPaginationHelper().getType()));
        return pager.setDatas(datas);
    }

    /**
     * 给定统计公式，返回单个统计值。
     * @param sql
     * @param args
     * @param clazz
     * @return
     */
    @Override
    public <T> T sqlMathQuery(String sql, Object[] args, Class<T> clazz){
        return this.getJdbcTemplate().queryForObject(sql, clazz, args);
    }

    /**
     * 查询自定义rowMapper对象，该方法主要使用<code>namedJdbcTemplate</code>来查询。<br>
     * 因为对于有些类似：where in (:ids)的查询必须使用命名参数，使用<code>jdbcTemplate</code>则无法查询
     * @param sql
     * @param rowMapper
     * @param paramSource
     * @return
     */
    @Override
    public <T> List<T> sqlListObjectWhereIn(String sql, RowMapper<T> rowMapper, SqlParameterSource paramSource){
        return this.getNamedParameterJdbcTemplate().query(sql, paramSource, rowMapper);
    }

    /**
     * 该方法主要使用<code>namedJdbcTemplate</code>来查询。<br>
     * @param sql
     * @param paramSource
     * @date 2023-02-05
     */
    public List<Map<String, Object>> queryListObjectWhereIn(String sql, SqlParameterSource paramSource){
        return this.getNamedParameterJdbcTemplate().queryForList(sql, paramSource);
    }

    /**
     * 组装并返回分页需要的参数数组
     * @param args
     * @param firstRowIndex
     * @param pageSize
     * @return
     */
    protected Object[] getSqlPageArgs(Object[] args, int firstRowIndex, int pageSize, String type){
        Object[] params = null;
        int j = 0;

        // sqlserver单独处理吧，就他很扯淡，而且分页参数是拼写死的，所以这里不需要添加分页参数
        if(type.equals(DatabaseType.NAME_SQLSERVER)){
            if(firstRowIndex == 0){
                // sqlserver在offset=0时，hibernate不需要设置该参数！
                if(args == null || args.length == 0){
//					params = new Object[1];
//					params[0] = pageSize;
                } else {
//					params = new Object[args.length+1];
//					params[j] = pageSize;
//					j++;
//					for(; j<args.length+1; j++){
//						params[j] = args[j-1];
//					}
                    params = args;
                }
            } else {
                if(args == null || args.length == 0){
                    params = new Object[2];
                    params[0] = firstRowIndex;
                    params[1] = firstRowIndex + pageSize;
                } else {
                    params = new Object[args.length+2];
//					params[j] = pageSize;
//					j++;
                    for(; j<args.length+1; j++){
                        params[j] = args[j-1];
                    }
//					logger.debug("j = " + j);
//					params[j] = firstRowIndex;
                    params[j] = firstRowIndex;
                    params[j+1] = firstRowIndex + pageSize;
                }
            }
        } else {
            // 其他数据库
            if(args == null || args.length == 0){
                params = new Object[2];
            } else {
                params = new Object[args.length+2];
                for(; j<args.length; j++)
                    params[j] = args[j];
            }
        }

        if(type.equals(DatabaseType.NAME_MYSQL) || type.equals(DatabaseType.NAME_DERBY)){
            params[j] = firstRowIndex;
            params[j+1] = pageSize;
        } else if(type.equals(DatabaseType.NAME_ORACLE)){
            params[j] = firstRowIndex + pageSize;
            params[j+1] = firstRowIndex;
        } else if(type.equals(DatabaseType.NAME_POSTGRES)){
            params[j] = pageSize;
            params[j+1] = firstRowIndex;
        } else if(type.equals(DatabaseType.NAME_SQLSERVER)){
            // 前面已经处理，这里不需要了
//    		throw new UnsupportedOperationException("not implements setPageParameters for sqlserver.");
        }
        return params;
    }

    protected static final String SELECT_COUNT = "select count(*) as dby_num ";
    protected static final String SELECT_COUNT_1 = "from (";
    protected static final String SELECT_COUNT_2 = ") as dby_temp";
    protected static final String SELECT_COUNT_3 = ")";
    protected static final String SQL_FROM = " from";
    public static final String SQL_ORDERBY = "order by";

    /**
     * 返回记录总数
     * @param sql 原始SQL语句
     * @return
     */
    protected String getJdbcCountQuery(String sql, String type){
        if(StringUtils.isEmpty(sql)){
            throw new RuntimeException("sql is required!");
        }

        StringBuilder sb = new StringBuilder();
        sb.append(SELECT_COUNT)
                .append(SELECT_COUNT_1);
        if(type.equals(DatabaseType.NAME_SQLSERVER)){
            // 因为sqlserver字句中不能存在order by，要去掉
            int index = sql.toLowerCase().indexOf(SQL_ORDERBY);
            if(index > 0){
                sb.append(sql.substring(0, index));
            } else {
                sb.append(sql);
            }
        } else {
            sb.append(sql);
        }

        if(type.equals(DatabaseType.NAME_ORACLE) || type.equals(DatabaseType.NAME_DAMENG)){
            // 如果是oracle数据库，那么select count(*) from(...) as dby_num 会报错
            // 需要去掉 as dby_num
            sb.append(SELECT_COUNT_3);
        } else {
            sb.append(SELECT_COUNT_2);
        }
        return sb.toString();
    }

    /**
     * 批量更新数据，不带任何参数
     * @param sql 给定的SQL语句
     */
    public void batchUpdate(String sql){
        batchUpdate(sql, null);
    }

    /**
     * 批量更新数据
     * @param sql 给定的SQL语句
     * @param parameters 参数集合，集合中每个参数都是数组，每次更新使用一个参数
     */
    public void batchUpdate(String sql, final List<Object[]> parameters){
        this.getJdbcTemplate().batchUpdate(sql, new BatchPreparedStatementSetter() {

            @Override
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                if(parameters == null) return;
                Object[] param = parameters.get(i);
                if(param == null || param.length == 0){
                    return;
                }

                int _psize = param.length;
                Object _p = null;
                Class<?> _pc = null;
                for(int j=1; j<_psize+1; j++){
                    _p = param[j-1];
                    if(_p == null){
                        throw new IllegalArgumentException("parameter in arrays can't be null!");
                    }
                    _pc = _p.getClass();

                    if(_pc.isPrimitive()){
                        if(_pc == int.class){
                            ps.setInt(j, ((Integer)_p).intValue());
                        } else if(_pc == float.class){
                            ps.setFloat(j, ((Float)_p).floatValue());
                        } else if(_pc == boolean.class){
                            ps.setBoolean(j, ((Boolean)_p).booleanValue());
                        } else if(_pc == long.class){
                            ps.setLong(j, ((Long)_p).longValue());
                        } else if(_pc == double.class){
                            ps.setDouble(j, ((Double)_p).doubleValue());
                        }
                    } else if(_pc == String.class){
                        ps.setString(j, _p.toString());
                    } else if(_pc == Integer.class){
                        ps.setInt(j, ((Integer)_p).intValue());
                    } else if(_pc == Float.class){
                        ps.setFloat(j, ((Float)_p).floatValue());
                    } else if(_pc == Boolean.class){
                        ps.setBoolean(j, ((Boolean)_p).booleanValue());
                    } else if(_pc == Long.class){
                        ps.setLong(j, ((Long)_p).longValue());
                    } else if(_pc == Double.class){
                        ps.setDouble(j, ((Double)_p).doubleValue());
                    } else {
                        ps.setObject(j, _p);
                    }
                }
            }

            @Override
            public int getBatchSize() {
                if(parameters == null) return 0;
                return parameters.size();
            }
        });
    }

    /**
     * 获取动态拼接SQL的查询条件对象
     * @param temp 业务传入的条件数组
     * @return
     */
    protected Object[] getSearchConditionParams(Object[] temp){
        if(temp == null || temp.length == 0) return null;
        // 加入参数
        List<Object> params = new ArrayList<Object>(temp.length);
        for(Object _p : temp){
            if(_p != null){
                params.add(_p);
            }
        }
        if(params.size() > 0){
            temp = params.toArray();
        } else
            temp = null;
        return temp;
    }
}
