package org.opoo.tools.db.copy;

import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.opoo.tools.db.Column;
import org.opoo.tools.db.SqlBiConsumer;
import org.opoo.tools.db.Table;
import org.opoo.tools.db.TableInput;
import org.springframework.jdbc.core.StatementCreatorUtils;
import org.springframework.jdbc.support.JdbcUtils;
import org.springframework.util.StringUtils;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.stream.Collectors;

/**
 * Default {@link TableCopier} that reads every selected row of a source table with a single
 * {@code SELECT} on the source connection and writes it to a target table with a batched
 * {@code INSERT} on the target connection.
 *
 * <p>Unless a custom record copier is supplied, columns are matched by position: the i-th
 * column of the source table is written to the i-th column of the target table.
 */
@Slf4j
public class SimpleTableCopier implements TableCopier {

    /** Fetch size applied to query statements when the source database is Oracle. */
    private static final int ORACLE_FETCH_SIZE = 10000;

    /**
     * Copies all rows selected from {@code source} into {@code target}.
     *
     * @param source       source table plus the connection to read from
     * @param target       target table plus the connection to write to
     * @param batchSize    number of rows added to the insert batch before it is executed;
     *                     a value below 1 makes every row execute as its own batch
     * @param recordCopier optional per-row copier; when {@code null}, columns are copied by position
     * @return the number of rows written to the target
     * @throws SQLException if building, querying, binding or inserting fails, or if the source and
     *                      target column counts differ while using the default positional copy
     */
    @Override
    public int copy(TableInput source, TableInput target, int batchSize, SqlBiConsumer<ResultSet, PreparedStatement> recordCopier)
            throws SQLException {
        final Table sourceTable = source.getTable();
        final Table targetTable = target.getTable();
        final CopyParams copyParams = new CopyParams(sourceTable, targetTable, batchSize, recordCopier);

        final String querySql = buildQuerySql(sourceTable);
        final String insertSql = buildInsertSql(sourceTable, targetTable);
        log.info("查询: {}", querySql);
        log.info("写入: {}", insertSql);

        // Resources are closed in reverse order: result set, insert statement, query statement.
        try (final PreparedStatement queryPs = createQueryStatement(source.getConnection(), querySql);
             final PreparedStatement insertPs = createInsertStatement(target.getConnection(), insertSql);
             final ResultSet rs = queryPs.executeQuery()) {
            log.info("复制 {} => {}, ResultSet 到 PreparedStatement ...", sourceTable.getName(), targetTable.getName());
            return copyAllRecords(copyParams, rs, insertPs);
        }
    }

    /**
     * Builds the {@code SELECT} statement that reads the source table.
     *
     * <p>The optional where condition is appended when it has text. Rows are ordered by the
     * primary-key columns when the table has any; otherwise no {@code ORDER BY} clause is emitted.
     *
     * @param table source table description
     * @return the query SQL
     */
    protected String buildQuerySql(Table table) {
        final String whereCondition = table.getWhereCondition();

        final StringBuilder sql = new StringBuilder()
                .append("SELECT ")
                .append(joinColumnNames(table.getColumns()))
                .append(" FROM ")
                .append(table.getName());

        if (StringUtils.hasText(whereCondition)) {
            // NOTE(review): the condition is spliced in verbatim (no parameter binding);
            // presumably it comes from trusted configuration — verify against callers.
            sql.append(" WHERE ").append(whereCondition);
        }

        final Column[] primaryKeyColumns = table.getPrimaryKeyColumns();
        if (primaryKeyColumns != null && primaryKeyColumns.length > 0) {
            // Only order when there is a key: a bare "ORDER BY" would be invalid SQL.
            sql.append(" ORDER BY ").append(joinColumnNames(primaryKeyColumns));
        }

        return sql.toString();
    }

    /**
     * Builds the parameterized {@code INSERT} statement for the target table, with one
     * {@code ?} placeholder per target column.
     *
     * @param sourceTable source table description (unused by this implementation; available to subclasses)
     * @param targetTable target table description
     * @return the insert SQL
     */
    protected String buildInsertSql(Table sourceTable, Table targetTable) {
        final Column[] columns = targetTable.getColumns();
        final String[] questionMarks = new String[columns.length];
        Arrays.fill(questionMarks, "?");

        return "INSERT INTO " + targetTable.getName()
                + " (" + joinColumnNames(columns) + ")"
                + " VALUES (" + String.join(", ", questionMarks) + ")";
    }

    /**
     * Prepares a forward-only, read-only query statement, tuning the fetch size per database product.
     *
     * <p>Oracle gets a fixed large fetch size. MySQL and MariaDB get {@link Integer#MIN_VALUE},
     * which enables row-by-row streaming in their drivers. Any other product, including one whose
     * product name is unknown ({@code null}), keeps the driver defaults.
     *
     * @param conn source connection
     * @param sql  query SQL
     * @return the prepared statement; the caller is responsible for closing it
     * @throws SQLException if metadata lookup or statement preparation fails
     */
    protected PreparedStatement createQueryStatement(Connection conn, String sql) throws SQLException {
        final DatabaseMetaData databaseMetaData = conn.getMetaData();
        final String databaseProductName = databaseMetaData.getDatabaseProductName();
        final PreparedStatement ps = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
        try {
            if (databaseProductName == null) {
                return ps;
            }
            if (databaseProductName.startsWith("Oracle")) {
                ps.setFetchSize(ORACLE_FETCH_SIZE);
            } else if (databaseProductName.startsWith("MySQL") || databaseProductName.startsWith("MariaDB")) {
                // Key to streaming results on MySQL / MariaDB / TDSQL for MySQL and similar.
                ps.setFetchSize(Integer.MIN_VALUE);
            }
            return ps;
        } catch (SQLException | RuntimeException e) {
            // The caller never receives the statement in this case, so close it here to avoid a leak.
            JdbcUtils.closeStatement(ps);
            throw e;
        }
    }

    /**
     * Prepares the insert statement on the target connection.
     *
     * @param conn target connection
     * @param sql  insert SQL
     * @return the prepared statement; the caller is responsible for closing it
     * @throws SQLException if preparation fails
     */
    protected PreparedStatement createInsertStatement(Connection conn, String sql) throws SQLException {
        return conn.prepareStatement(sql);
    }

    /**
     * Iterates the result set, binds each row to the insert statement and executes the batch
     * whenever {@code batchSize} rows have accumulated, then flushes any remainder.
     *
     * @param params copy parameters
     * @param rs     source rows, positioned before the first row
     * @param ps     target insert statement
     * @return the total number of rows added to executed batches
     * @throws SQLException if reading, binding or batch execution fails
     */
    protected int copyAllRecords(final CopyParams params, final ResultSet rs, final PreparedStatement ps) throws SQLException {
        final int batchSize = params.getBatchSize();

        final long start = System.currentTimeMillis();
        int total = 0;
        int count = 0;
        while (rs.next()) {
            copyRecord(params, rs, ps);
            ps.addBatch();
            count++;

            if (count >= batchSize) {
                ps.executeBatch();
                total += count;
                count = 0;
                logBatch(start, total);
            }
        }

        // Flush the final, partially filled batch.
        if (count > 0) {
            ps.executeBatch();
            total += count;
            logBatch(start, total);
        }
        log.info("复制 {} => {} 完成，共 {} 项", params.getSourceTable().getName(), params.getTargetTable().getName(), total);
        return total;
    }

    /**
     * Binds the current row of {@code rs} to the parameters of {@code ps}.
     *
     * <p>If a custom record copier was supplied it does all the work. Otherwise each source column
     * is read with its declared type and bound to the target column at the same position using the
     * target column's SQL type.
     *
     * @param params copy parameters
     * @param rs     source result set positioned on the row to copy
     * @param ps     target insert statement
     * @throws SQLException if reading or binding fails, or if the default positional copy is used
     *                      and the source and target column counts differ
     */
    protected void copyRecord(final CopyParams params, final ResultSet rs, final PreparedStatement ps) throws SQLException {
        final SqlBiConsumer<ResultSet, PreparedStatement> recordCopier = params.getRecordCopier();
        if (recordCopier != null) {
            // Copy one record using the custom handler.
            recordCopier.apply(rs, ps);
            return;
        }

        final Table sourceTable = params.getSourceTable();
        final Table targetTable = params.getTargetTable();
        final Column[] sourceColumns = sourceTable.getColumns();
        final Column[] targetColumns = targetTable.getColumns();
        if (sourceColumns.length != targetColumns.length) {
            // Positional copy needs a 1:1 column mapping; fail clearly instead of throwing an index
            // error or leaving insert parameters unbound.
            throw new SQLException("Column count mismatch: source table " + sourceTable.getName()
                    + " has " + sourceColumns.length + " columns but target table " + targetTable.getName()
                    + " has " + targetColumns.length);
        }
        for (int i = 0; i < sourceColumns.length; i++) {
            final int index = i + 1; // JDBC column and parameter indexes are 1-based
            final Object value = JdbcUtils.getResultSetValue(rs, index, sourceColumns[i].getType());
            StatementCreatorUtils.setParameterValue(ps, index, targetColumns[i].getSqlType(), value);
        }
    }

    /** Joins the names of {@code columns} with {@code ", "}. */
    private static String joinColumnNames(Column[] columns) {
        return Arrays.stream(columns).map(Column::getName).collect(Collectors.joining(", "));
    }

    /** Logs cumulative progress (rows copied so far and elapsed milliseconds) at debug level. */
    private void logBatch(long start, int total) {
        log.debug("批次复制完成，当前记录 {}, 耗时 {}ms", total, System.currentTimeMillis() - start);
    }

    /** Immutable bundle of the parameters of one copy operation. */
    @Data
    protected static class CopyParams {
        private final Table sourceTable;
        private final Table targetTable;
        private final int batchSize;
        /** Optional custom per-row copier; {@code null} selects positional column copying. */
        private final SqlBiConsumer<ResultSet, PreparedStatement> recordCopier;
    }
}
