package org.opoo.tools.db.diff;

import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.opoo.tools.db.Column;
import org.opoo.tools.db.Id;
import org.opoo.tools.db.SqlEqualizer;
import org.opoo.tools.db.Table;
import org.opoo.tools.db.TableInput;
import org.opoo.tools.db.util.DbUtils;
import org.springframework.jdbc.support.JdbcUtils;
import org.springframework.util.StringUtils;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Objects;
import java.util.stream.Collectors;

@Slf4j
public class SimpleTableComparator implements TableComparator {

    @Override
    public void compare(TableInput tableInputA, TableInput tableInputB, SqlEqualizer<ResultSet> equalizer, TableListener listener)
            throws SQLException {
        final Table tableA = tableInputA.getTable();
        final Table tableB = tableInputB.getTable();
        validate(tableA, tableB);

        try (final PreparedStatement psA = preparedStatement(tableA, tableInputA.getConnection());
             final PreparedStatement psB = preparedStatement(tableB, tableInputB.getConnection());
             final ResultSet rsA = psA.executeQuery();
             final ResultSet rsB = psB.executeQuery()) {
            final Context context = createContext(tableA, tableB, rsA, rsB);
            Status status = Status.CONTINUE;
            while (status == Status.CONTINUE) {
                status = compareInternal(context, equalizer, listener);
            }
        }
    }

    protected Context createContext(Table tableA, Table tableB, ResultSet rsA, ResultSet rsB) throws SQLException {
        // 允许输入的查询中列或者表达式中没有类型，在这里进行初始化
        DbUtils.initializeColumnTypes(rsA, tableA.getColumns());
        DbUtils.initializeColumnTypes(rsB, tableB.getColumns());
        return new Context(tableA, tableB, rsA, rsB).moveNext();
    }

    protected PreparedStatement preparedStatement(Table table, Connection conn) throws SQLException {
        final String sql = buildSql(table);
        log.debug("查询: {}", sql);
        return preparedStatement(conn, sql);
    }

    protected String buildSql(Table table) {
        final StringBuilder stringBuilder = new StringBuilder();
        final String primaryKeysJoinString = Arrays.stream(table.getPrimaryKeyColumns())
                .map(Column::getName)
                .collect(Collectors.joining(", "));
        final String columnsJoinString = Arrays.stream(table.getColumns())
                .map(Column::getName)
                .collect(Collectors.joining(", "));

        stringBuilder.append("SELECT ")
                .append(columnsJoinString)
                .append(" FROM ")
                .append(table.getName());

        if (StringUtils.hasText(table.getWhereCondition())) {
            stringBuilder.append(" WHERE ").append(table.getWhereCondition());
        }

        stringBuilder.append(" ORDER BY ").append(primaryKeysJoinString);
        return stringBuilder.toString();
    }

    protected PreparedStatement preparedStatement(Connection conn, String sql) throws SQLException {
        final DatabaseMetaData databaseMetaData = conn.getMetaData();
        final String databaseProductName = databaseMetaData.getDatabaseProductName();
        final PreparedStatement ps = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
        if (databaseProductName.startsWith("Oracle")) {
            ps.setFetchSize(10000);
        } else if (databaseProductName.startsWith("MySQL") || databaseProductName.startsWith("MariaDB")) {
            // Streaming 的关键
            ps.setFetchSize(Integer.MIN_VALUE);
        }
        return ps;
    }

    protected Status compareInternal(Context context, SqlEqualizer<ResultSet> equalizer, TableListener listener) throws SQLException {
        final Id idA = context.getIdA();
        final Id idB = context.getIdB();

        if (idA == null && idB == null) {
            return Status.FINISHED;
        }

        log.debug("Comparing idA = {}, idB = {}", idA, idB);
        if (idA == null) {
            // 只有 A 是 null，则 B 剩下的全是 B 独有的，包括当前的 idB
            listener.onOnlyInB(idB);
            while (context.nextIdB()) {
                listener.onOnlyInB(context.getIdB());
            }
            return Status.FINISHED;
        }

        if (idB == null) {
            // 只有 B 是 null，A 剩下的全是 A 独有的，包括 idA
            listener.onOnlyInA(idA);
            while (context.nextIdA()) {
                listener.onOnlyInA(context.getIdA());
            }
            return Status.FINISHED;
        }

        final int compareTo = idA.compareTo(idB);
        if (compareTo < 0) {
            // idA 还不够大，对不齐，保存并取下一个 idA;
            listener.onOnlyInA(idA);
            context.nextIdA();
            return Status.CONTINUE;
        }

        if (compareTo > 0) {
            // idB 还不够大，对不齐，保存并取下一个 idB
            listener.onOnlyInB(idB);
            context.nextIdB();
            return Status.CONTINUE;
        }

        // ID 相同时，比较其他非ID字段
        compareNormalColumns(context, idA, equalizer, listener);
        return Status.CONTINUE;
    }

    protected void compareNormalColumns(Context context, Id id, SqlEqualizer<ResultSet> equalizer, TableListener listener) throws SQLException {
        boolean isNormalColumnsEqual;
        // 指定了自定义的比较器
        if (equalizer != null) {
            isNormalColumnsEqual = equalizer.equals(context.getResultSetA(), context.getResultSetB());
        } else {
            isNormalColumnsEqual = isNormalColumnsEqual(context);
        }

        // 其它字段也相同
        if (isNormalColumnsEqual) {
            listener.onIdentical(id);
        } else {
            listener.onDifferent(id);
        }
        // 两个ID都向前移动
        context.moveNext();
    }

    protected boolean isNormalColumnsEqual(Context context) throws SQLException {
        final Table tableA = context.getTableA();
        final Table tableB = context.getTableB();
        final Column[] normalColumnsA = tableA.getNormalColumns();
        final Column[] normalColumnsB = tableB.getNormalColumns();
        final int length = tableA.getPrimaryKeyColumns().length;
        for (int i = 0; i < normalColumnsA.length; i++) {
            Column columnA = normalColumnsA[i];
            Column columnB = normalColumnsB[i];
            final Object valueA = JdbcUtils.getResultSetValue(context.getResultSetA(), length + i + 1, columnA.getType());
            final Object valueB = JdbcUtils.getResultSetValue(context.getResultSetB(), length + i + 1, columnB.getType());
            if (!Objects.deepEquals(valueA, valueB)) {
                log.info("Difference found: id = {}, tableA column name = '{}', valueA ='{}', tableB column name = '{}', valueB = '{}'",
                        context.getIdA(), columnA.getName(), valueA, columnB.getName(), valueB);
                return false;
            }
        }

        return true;
    }

    private void validate(Table tableA, Table tableB) {
        if (tableA.getColumns().length != tableB.getColumns().length) {
            throw new IllegalArgumentException("tableA and tableB must have the same number of columns");
        }

        if (tableA.getPrimaryKeyColumns().length != tableB.getPrimaryKeyColumns().length) {
            throw new IllegalArgumentException("tableA and tableB must have the same number of primary key columns");
        }
    }

    @Data
    protected static class Context {
        private final Table tableA;
        private final Table tableB;
        private final ResultSet resultSetA;
        private final ResultSet resultSetB;

        private Id idA;
        private Id idB;

        protected Context moveNext() throws SQLException {
            nextIdA();
            nextIdB();
            return this;
        }

        protected boolean nextIdA() throws SQLException {
            idA = DbUtils.getNextId(tableA.getPrimaryKeyColumns(), resultSetA);
            return idA != null;
        }

        protected boolean nextIdB() throws SQLException {
            idB = DbUtils.getNextId(tableB.getPrimaryKeyColumns(), resultSetB);
            return idB != null;
        }
    }

    protected enum Status {
        FINISHED, CONTINUE
    }
}
