package org.opoo.tools.tablediff;

import lombok.Data;
import lombok.extern.slf4j.Slf4j;

import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;

import static java.sql.Types.BIGINT;
import static java.sql.Types.BLOB;
import static java.sql.Types.BOOLEAN;
import static java.sql.Types.DATE;
import static java.sql.Types.DECIMAL;
import static java.sql.Types.DOUBLE;
import static java.sql.Types.FLOAT;
import static java.sql.Types.INTEGER;
import static java.sql.Types.NUMERIC;
import static java.sql.Types.REAL;
import static java.sql.Types.SMALLINT;
import static java.sql.Types.TIME;
import static java.sql.Types.TIMESTAMP;
import static java.sql.Types.TINYINT;

@Deprecated
@Slf4j
public abstract class AbstractTableDiffer implements TableDiffer {

    @Override
    public void diff(Comparison comparison, Observer observer) throws SQLException {
        final Table table1 = comparison.getTable1();
        final Table table2 = comparison.getTable2();
        final Columns columns = comparison.getColumns();
        final Id startId = comparison.getStartId();
        try (final Connection conn1 = table1.getDataSource().getConnection();
             final Connection conn2 = table2.getDataSource().getConnection();
             final PreparedStatement ps1 = preparedStatement(conn1, table1.getName(), columns, startId);
             final PreparedStatement ps2 = preparedStatement(conn2, table2.getName(), columns, startId);
             final ResultSet rs1 = ps1.executeQuery();
             final ResultSet rs2 = ps2.executeQuery()) {
            final Context context = buildContext(rs1, columns);
            Status status = Status.CONTINUE;
            while (status == Status.CONTINUE) {
                status = compare(context, rs1, rs2, columns, observer);
            }
        }
    }

    protected PreparedStatement preparedStatement(Connection conn, String tableName, Columns columns, Id startId) throws SQLException {
        final StringBuilder sqlBuilder = new StringBuilder();
        final List<Object> params = new ArrayList<>();

        buildSql(tableName, columns, startId, sqlBuilder, params);

        final String sql = sqlBuilder.toString();
        log.debug("Executing query: {}", sql);
        final PreparedStatement ps = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
        // Streaming 的关键
        ps.setFetchSize(Integer.MIN_VALUE);

        if (startId != null) {
            log.debug("Start id from(exclusive): {}", startId);
            for (int i = 0; i < params.size(); i++) {
                ps.setObject(i + 1, params.get(i));
            }
        }
        return ps;
    }

    protected static void buildSql(String tableName, Columns columns, Id startId, StringBuilder stringBuilder, List<Object> params) {
        final String[] pkNames = columns.getPkNames();
        final String pkNamesJoinString = String.join(", ", pkNames);
        stringBuilder.append("SELECT ")
                .append(pkNamesJoinString)
                .append(", ")
                .append(String.join(", ", columns.getNames()))
                .append(" FROM ")
                .append(tableName);
        // 有开始 ID
        if (startId != null) {
            stringBuilder.append(" WHERE ");
            buildSqlWhere(pkNames, startId.getValues(), stringBuilder, params);
        }
        stringBuilder.append(" ORDER BY ").append(pkNamesJoinString);
    }

    protected static void buildSqlWhere(String[] pkNames, Object[] startIdValues, StringBuilder stringBuilder, List<Object> params) {
        for (int i = 0; i < pkNames.length; i++) {
            if (i > 0) {
                stringBuilder.append(" or ");
                stringBuilder.append("(");
            }
            for (int j = 0; j <= i; j++) {
                if (j > 0) {
                    stringBuilder.append(" and ");
                }
                stringBuilder.append(pkNames[j]);
                params.add(startIdValues[j]);
                if (j < i) {
                    stringBuilder.append(" = ?");
                } else {
                    stringBuilder.append(" > ?");
                }
            }
            if (i > 0) {
                stringBuilder.append(")");
            }
        }
    }

    protected Status compare(Context context, ResultSet rs1, ResultSet rs2, Columns columns, Observer observer) throws SQLException {
        final Id id1 = context.id1 == null ? getNextId(rs1, context) : context.id1;
        final Id id2 = context.id2 == null ? getNextId(rs2, context) : context.id2;
        if (id1 == null && id2 == null) {
            return Status.FINISHED;
        }

        log.debug("id1 = {}, id2 = {}", id1, id2);
        if (id1 == null) {
            // 只有 id1 是 null，rs2 剩下的全是多的，包括 id2
            observer.updateOnlyIn2(id2);
            addRemainIds(rs2, context, observer::updateOnlyIn2);
            return Status.FINISHED;
        }

        if (id2 == null) {
            // 只有id2是null，rs1剩下的全是多的，包括id1
            observer.updateOnlyIn1(id1);
            addRemainIds(rs1, context, observer::updateOnlyIn1);
            return Status.FINISHED;
        }

        final int compareTo = id1.compareTo(id2);
        if (compareTo < 0) {
            // id1 还不够大，对不齐，保存并取下一个id1d1);
            observer.updateOnlyIn1(id1);
            context.update(null, id2);
        } else if (compareTo > 0) {
            // id2 还不够大，对不齐，保存并取下一个id2
            observer.updateOnlyIn2(id2);
            context.update(id1, null);
        } else {
            // ID 是一样大的，则对比其它字段
            if (equals(rs1, rs2, id1, context, columns)) {
                observer.updateIdentical(id1);
            } else {
                observer.updateDifferent(id1);
            }
            context.update(null, null);
        }
        return Status.CONTINUE;
    }

    protected Id getNextId(ResultSet resultSet, Context context) throws SQLException {
        if (resultSet.next()) {
            final int pkNamesCount = context.getPkNamesCount();
            final Object[] id = new Object[pkNamesCount];
            for (int i = 0; i < pkNamesCount; i++) {
                id[i] = resultSet.getObject(i + 1, context.getPkTypes()[i]);
            }
            return new Id(id);
        }
        return null;
    }

    protected Context buildContext(ResultSet resultSet, Columns columns) throws SQLException {
        final ResultSetMetaData metaData = resultSet.getMetaData();
        final int pkNamesCount = columns.getPkNames().length;
        Class<?>[] pkTypes = columns.getPkTypes();
        Class<?>[] types = columns.getTypes();

        if (pkTypes == null) {
            pkTypes = new Class<?>[pkNamesCount];
            for (int i = 0; i < pkNamesCount; i++) {
                pkTypes[i] = sqlTypeToClass(metaData.getColumnType(i + 1));
            }
        }

        if (types == null) {
            final int namesCount = columns.getNames().length;
            types = new Class<?>[namesCount];
            for (int i = 0; i < namesCount; i++) {
                types[i] = sqlTypeToClass(metaData.getColumnType(pkNamesCount + i + 1));
            }
        }
        log.debug("pk types = {}, other column types ={}", pkTypes, types);
        return new Context(pkTypes, types);
    }

    protected Class<?> sqlTypeToClass(int sqlType) {
        switch (sqlType) {
            case BOOLEAN:
                return Boolean.class;

            case TINYINT:
            case SMALLINT:
            case INTEGER:
                return Integer.class;

            case BIGINT:
                return Long.class;

            case DECIMAL:
                return BigDecimal.class;

            case FLOAT:
                return Float.class;

            case DOUBLE:
                return Double.class;

            case DATE:
            case TIME:
            case TIMESTAMP:
                return Date.class;

            case BLOB:
                return byte[].class;

            default:
                return String.class;
        }
    }

    protected void addRemainIds(ResultSet rs, Context context, Consumer<Id> idConsumer) throws SQLException {
        for (; ; ) {
            final Id nextId = getNextId(rs, context);
            if (nextId == null) {
                break;
            }
            idConsumer.accept(nextId);
        }
    }

    protected boolean equals(ResultSet rs1, ResultSet rs2, final Id id, Context context, Columns columns) throws SQLException {
        final int pkNamesCount = context.getPkNamesCount();
        final int namesCount = context.getNamesCount();
        final Class<?>[] columnTypes = context.getTypes();
        for (int i = 0; i < namesCount; i++) {
            final Class<?> columnType = columnTypes[i];
            final Object val1 = rs1.getObject(pkNamesCount + i + 1, columnType);
            final Object val2 = rs2.getObject(pkNamesCount + i + 1, columnType);
            if (!Objects.equals(val1, val2)) {
                log.info("Difference found: id = {}, column name = '{}', value1 ='{}', value2 = '{}'", id, columns.getNames()[i], val1, val2);
                return false;
            }
        }
        return true;
    }

    @Data
    protected static class Context {
        private final Class<?>[] pkTypes;
        private final Class<?>[] types;
        private final int pkNamesCount;
        private final int namesCount;

        protected Context(Class<?>[] pkTypes, Class<?>[] types) {
            this.pkTypes = pkTypes;
            this.types = types;
            pkNamesCount = pkTypes.length;
            namesCount = types.length;
        }

        private Id id1;
        private Id id2;

        public void update(Id id1, Id id2) {
            this.id1 = id1;
            this.id2 = id2;
        }
    }

    protected enum Status {
        FINISHED, CONTINUE;
    }
}
