package org.opoo.tools.db.copy;

import lombok.extern.slf4j.Slf4j;
import org.opoo.tools.db.ArgumentSetter;
import org.opoo.tools.db.Column;
import org.opoo.tools.db.Id;
import org.opoo.tools.db.SqlAndParams;
import org.opoo.tools.db.Table;
import org.opoo.tools.db.util.DbUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

@Slf4j
public abstract class AbstractSameDbTableCopier implements SameDbTableCopier {

    /**
     * Copies the rows of {@code sourceTable} into {@code targetTable} in primary-key ranges.
     * <p>
     * Each iteration asks {@link #getOffsetId} for the upper bound of the next batch and then copies the
     * half-open key range {@code (previousOffsetId, offsetId]}. When {@link #getOffsetId} returns
     * {@code null} (less than one batch remains), the final, upper-unbounded range after
     * {@code previousOffsetId} is copied and the loop stops.
     *
     * @param connection  SQL connection used for every query and insert
     * @param sourceTable source table
     * @param targetTable target table
     * @param batchSize   batch size, forwarded to {@link #buildGetOffsetIdSql}
     * @return sum of the update counts returned by the batch INSERT statements
     * @throws SQLException             if any query or insert fails
     * @throws NullPointerException     if {@code sourceTable} or {@code targetTable} is {@code null}
     * @throws IllegalArgumentException if the tables differ in column count or primary-key column count
     */
    @Override
    public int copy(Connection connection, Table sourceTable, Table targetTable, int batchSize) throws SQLException {
        validate(sourceTable, targetTable);
        Id previousOffsetId = null;
        int count = 0;
        while (true) {
            final Id offsetId = getOffsetId(connection, sourceTable, batchSize, previousOffsetId);
            count += batchCopy(connection, sourceTable, targetTable, previousOffsetId, offsetId);
            if (offsetId == null) {
                // The last copy had no upper bound, so every remaining row has been processed.
                break;
            }
            previousOffsetId = offsetId;
        }
        log.info("复制表 {} => {} 完成，更新数据 {} 项", sourceTable.getName(), targetTable.getName(), count);
        return count;
    }

    /**
     * Queries the primary key that forms the upper bound of the next batch.
     *
     * @param connection       SQL connection
     * @param table            table to scan
     * @param batchSize        batch size
     * @param previousOffsetId upper bound of the previous batch, or {@code null} for the first batch
     * @return the upper-bound primary key of this batch, or {@code null} if fewer than one batch remains
     * @throws SQLException if the query fails
     */
    @Nullable
    protected Id getOffsetId(Connection connection, Table table, int batchSize, @Nullable Id previousOffsetId) throws SQLException {
        final SqlAndParams sqlAndParams = buildGetOffsetIdSql(table, batchSize, previousOffsetId);
        final String sql = sqlAndParams.getSql();
        final List<Object> params = sqlAndParams.getParams();
        log.debug("查询区间上限: {}", sql);

        try (final PreparedStatement ps = connection.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)) {
            new ArgumentSetter(params).setValues(ps);
            try (final ResultSet resultSet = ps.executeQuery()) {
                return DbUtils.getNextId(table, resultSet);
            }
        }
    }

    /**
     * Copies one batch, covering the half-open primary-key range {@code (previousOffsetId, offsetId]}.
     *
     * @param connection       SQL connection
     * @param sourceTable      source table
     * @param targetTable      target table
     * @param previousOffsetId upper bound of the previous batch, or {@code null} for the first batch
     * @param offsetId         upper bound of this batch, or {@code null} if fewer than one batch remains
     *                         (the range is then unbounded above)
     * @return the update count reported by the INSERT statement
     * @throws SQLException if the insert fails
     */
    protected int batchCopy(Connection connection, Table sourceTable, Table targetTable,
                            @Nullable Id previousOffsetId, @Nullable Id offsetId) throws SQLException {
        final SqlAndParams sqlAndParams = buildCopySql(sourceTable, targetTable, previousOffsetId, offsetId);
        final String sql = sqlAndParams.getSql();
        log.debug("区间复制 ({}, {}]: {}", previousOffsetId, offsetId, sql);
        try (final PreparedStatement ps = connection.prepareStatement(sql)) {
            new ArgumentSetter(sqlAndParams.getParams()).setValues(ps);
            return ps.executeUpdate();
        }
    }

    /**
     * Builds the SQL statement (and its parameters) that selects the upper-bound primary key of the
     * current batch.
     *
     * @param table            table
     * @param batchSize        batch size
     * @param previousOffsetId upper bound of the previous batch, or {@code null} for the first batch
     * @return SQL statement and parameter list
     */
    protected abstract SqlAndParams buildGetOffsetIdSql(Table table, int batchSize, @Nullable Id previousOffsetId);

    /**
     * Builds the {@code INSERT INTO target(...) SELECT ... FROM source [WHERE ...]} statement for one batch.
     * <p>
     * The WHERE clause may contain:
     * <ul>
     *     <li>the lower bound {@code > previousOffsetId} (when not {@code null}),</li>
     *     <li>the upper bound {@code <= offsetId} (when not {@code null}),</li>
     *     <li>the table's extra where condition (when it has text).</li>
     * </ul>
     * Each present condition is wrapped in parentheses and joined with {@code AND}. The parentheses
     * keep conditions that may contain {@code OR}, such as a multi-column key comparison, from mixing
     * with the surrounding {@code AND}s. Parameters are collected in the same order as their
     * placeholders appear.
     *
     * @param sourceTable      source table
     * @param targetTable      target table
     * @param previousOffsetId upper bound of the previous batch, or {@code null} for the first batch
     * @param offsetId         upper bound of the current batch, or {@code null} if fewer than one batch remains
     * @return SQL statement and parameter list
     */
    protected SqlAndParams buildCopySql(Table sourceTable, Table targetTable, @Nullable Id previousOffsetId, @Nullable Id offsetId) {
        final String[] primaryKeyNames = Arrays.stream(sourceTable.getPrimaryKeyColumns()).map(Column::getName).toArray(String[]::new);
        final String whereCondition = sourceTable.getWhereCondition();
        final List<Object> params = new ArrayList<>();
        final List<String> conditions = new ArrayList<>();

        final StringBuilder sql = new StringBuilder()
                .append("INSERT INTO ")
                .append(targetTable.getName())
                .append("(")
                .append(Arrays.stream(targetTable.getColumns()).map(Column::getName).collect(Collectors.joining(", ")))
                .append(") SELECT ")
                .append(Arrays.stream(sourceTable.getColumns()).map(Column::getName).collect(Collectors.joining(", ")))
                .append(" FROM ")
                .append(sourceTable.getName());

        if (previousOffsetId != null) {
            final SqlAndParams lowerBound = DbUtils.buildGreaterThanCondition(primaryKeyNames, previousOffsetId.getValues());
            conditions.add("(" + lowerBound.getSql() + ")");
            params.addAll(lowerBound.getParams());
        }
        if (offsetId != null) {
            final SqlAndParams upperBound = DbUtils.buildLessThanOrEqualsCondition(primaryKeyNames, offsetId.getValues());
            conditions.add("(" + upperBound.getSql() + ")");
            params.addAll(upperBound.getParams());
        }
        if (StringUtils.hasText(whereCondition)) {
            // The extra condition is concatenated as-is (no parameters), so it must come from trusted config.
            conditions.add("(" + whereCondition + ")");
        }
        if (!conditions.isEmpty()) {
            sql.append(" WHERE ").append(String.join(" AND ", conditions));
        }

        return new SqlAndParams(sql.toString(), params);
    }

    /**
     * Validates that the two tables can be copied column-for-column.
     *
     * @throws NullPointerException     if either table is {@code null}
     * @throws IllegalArgumentException if the column counts or primary-key column counts differ
     */
    private void validate(Table sourceTable, Table targetTable) {
        Objects.requireNonNull(sourceTable, "sourceTable");
        Objects.requireNonNull(targetTable, "targetTable");

        if (sourceTable.getColumns().length != targetTable.getColumns().length) {
            throw new IllegalArgumentException("sourceTable and targetTable must have the same number of columns");
        }

        if (sourceTable.getPrimaryKeyColumns().length != targetTable.getPrimaryKeyColumns().length) {
            throw new IllegalArgumentException("sourceTable and targetTable must have the same number of primary key columns");
        }
    }
}
