package org.accidia.echo.mysql.keyvalue;

import com.google.common.base.Strings;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import org.accidia.echo.dao.IProtobufDao;
import org.accidia.echo.mysql.MySqlDataSource;
import org.accidia.echo.protos.Protos.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.support.JdbcDaoSupport;

import java.io.IOException;
import java.util.*;

import static com.google.common.base.Preconditions.checkArgument;

/**
 * {@link IProtobufDao} implementation that stores protobuf messages as key/value rows in MySQL.
 *
 * <p>Messages are serialized with {@link Message#toByteArray()} into a {@code VALUE} column. Four
 * schema objects are derived from the upper-cased simple class name of the default message
 * instance (call it {@code NAME}):
 * <ul>
 *   <li>{@code NAME} - live objects ({@code KEY}, {@code VALUE}, {@code TIMESTAMP})</li>
 *   <li>{@code NAME_ARCHIVE} - archived objects, same layout as {@code NAME}</li>
 *   <li>{@code NAME_LIST} - list memberships ({@code LIST_KEY}, {@code OBJECT_KEY},
 *       {@code OBJECT_WEIGHT}, {@code TIMESTAMP})</li>
 *   <li>{@code NAME_ALL_VIEW} - a {@code UNION} view of the live and archive tables</li>
 * </ul>
 *
 * <p>Paging methods take a zero-based {@code start} offset and a {@code count}; a {@code count}
 * of {@code -1} means "no limit".
 */
public class MySqlKeyValueProtobufDao extends JdbcDaoSupport
        implements IProtobufDao {

    private static final Logger logger = LoggerFactory.getLogger(MySqlKeyValueProtobufDao.class);

    /**
     * Row count used in {@code LIMIT offset, count} when the caller asked for "all rows" but with a
     * non-zero offset. MySQL does not accept a negative row count; its documented idiom is to pass
     * a very large number instead.
     */
    private static final long NO_ROW_LIMIT = Long.MAX_VALUE;

    private final Message messageDefaultInstance;
    private final MySqlKeyValueProtobufRowMapper mySqlProtobufRowMapper = new MySqlKeyValueProtobufRowMapper();
    // Each row is mapped into a fresh builder obtained from the default instance.
    private final RowMapper<Message> rowMapper = (resultSet, rowNum)
            -> this.mySqlProtobufRowMapper.mapResultSet(resultSet, getMessageBuilder());
    private final DataSource mySqlDataSource;

    /**
     * Creates a DAO for the given message type and makes sure its tables and view exist.
     *
     * @param messageDefaultInstance default instance of the message type to persist
     * @param dataSource connection settings passed to {@link MySqlDataSource#getInstance}
     * @return a ready-to-use DAO
     * @throws RuntimeException wrapping any checked exception raised during construction
     */
    public static MySqlKeyValueProtobufDao newInstance(final Message messageDefaultInstance,
                                                       final DataSource dataSource) {
        try {
            return new MySqlKeyValueProtobufDao(messageDefaultInstance, dataSource).createTablesIfNotExist();
        } catch (Descriptors.DescriptorValidationException | ReflectiveOperationException | IOException e) {
            logger.warn("failed to create new mysql protobuf dao instance:", e);
            throw new RuntimeException(e);
        }
    }

    protected MySqlKeyValueProtobufDao(final Message messageDefaultInstance, final DataSource dataSource)
            throws Descriptors.DescriptorValidationException, ReflectiveOperationException, IOException {
        logger.debug("MySqlProtobufDao()");

        this.messageDefaultInstance = messageDefaultInstance;
        this.mySqlDataSource = dataSource;
        setDataSource(MySqlDataSource.getInstance(dataSource).getConnectoinPoolDataSource());
    }

    /**
     * Looks up a single message by key.
     *
     * @param includeArchive when true, also searches the archive (via the all-view)
     * @return the message, or {@code null} if no row matches
     */
    @Override
    public Message findByKey(final String key, boolean includeArchive) {
        logger.debug("findByKey()");
        checkArgument(!Strings.isNullOrEmpty(key), "null/empty key");
        return doFindByKey(key, includeArchive);
    }

    /**
     * Field projection is not supported by this key/value store: the requested field list is
     * ignored and the whole message is returned, exactly as {@link #findByKey} does.
     */
    @Override
    public Message findFieldsByKey(final String key, final List<String> fieldsIgnored, boolean includeArchive) {
        logger.debug("findFieldsByKey()");
        checkArgument(!Strings.isNullOrEmpty(key), "null/empty key");
        return doFindByKey(key, includeArchive);
    }

    /** Returns a page of object keys in the list, newest list entry first. */
    @Override
    public List<String> findList(final String listKey, final int start, final int count) {
        logger.debug("findList()");
        checkListArguments(listKey, start, count);
        return doFindList(listKey, start, count, " ORDER BY `TIMESTAMP` DESC ");
    }

    @Override
    public List<String> findAllList(final String listKey) {
        return findList(listKey, 0, -1);
    }

    /** Returns a page of the list's messages, newest list entry first. */
    @Override
    public List<Message> findListObjects(final String listKey, final int start, final int count, boolean includeArchive) {
        logger.debug("findListObjects()");
        checkListArguments(listKey, start, count);
        return doFindListObjects(listKey, start, count, " ORDER BY `TIMESTAMP` DESC ", includeArchive);
    }

    @Override
    public List<Message> findAllListObjects(final String listKey, boolean includeArchive) {
        return findListObjects(listKey, 0, -1, includeArchive);
    }

    /** Returns a page of the list's messages ordered by weight ascending, ties newest first. */
    @Override
    public List<Message> findOrderedListObjectsAscending(final String listKey, final int start, final int count, boolean includeArchive) {
        checkListArguments(listKey, start, count);
        return doFindListObjects(listKey, start, count, " ORDER BY `OBJECT_WEIGHT` ASC, `TIMESTAMP` DESC ", includeArchive);
    }

    @Override
    public List<Message> findAllOrderedListObjectsAscending(final String listKey, boolean includeArchive) {
        return findOrderedListObjectsAscending(listKey, 0, -1, includeArchive);
    }

    /** Returns a page of the list's messages ordered by weight descending, ties newest first. */
    @Override
    public List<Message> findOrderedListObjectsDescending(final String listKey, final int start, final int count, boolean includeArchive) {
        checkListArguments(listKey, start, count);
        return doFindListObjects(listKey, start, count, " ORDER BY `OBJECT_WEIGHT` DESC, `TIMESTAMP` DESC ", includeArchive);
    }

    @Override
    public List<Message> findAllOrderedListObjectsDescending(final String listKey, boolean includeArchive) {
        return findOrderedListObjectsDescending(listKey, 0, -1, includeArchive);
    }

    /** Returns a page of object keys ordered by weight ascending, ties newest first. */
    @Override
    public List<String> findOrderedListAscending(final String listKey, final int start, final int count) {
        checkListArguments(listKey, start, count);
        return doFindList(listKey, start, count, " ORDER BY `OBJECT_WEIGHT` ASC, `TIMESTAMP` DESC ");
    }

    @Override
    public List<String> findAllOrderedListAscending(final String listKey) {
        return findOrderedListAscending(listKey, 0, -1);
    }

    /** Returns a page of object keys ordered by weight descending, ties newest first. */
    @Override
    public List<String> findOrderedListDescending(final String listKey, final int start, final int count) {
        checkListArguments(listKey, start, count);
        return doFindList(listKey, start, count, " ORDER BY `OBJECT_WEIGHT` DESC, `TIMESTAMP` DESC ");
    }

    @Override
    public List<String> findAllOrderedListDescending(final String listKey) {
        return findOrderedListDescending(listKey, 0, -1);
    }

    /** Inserts or overwrites (MySQL {@code REPLACE}) the message stored under {@code key}. */
    @Override
    public void store(final String key, final Message object) {
        logger.debug("store()");
        checkArgument(!Strings.isNullOrEmpty(key), "null/empty key");
        checkArgument(object != null, "null object");
        doStore(key, object, getObjectTableName());
    }

    /** Adds {@code objectKey} to the list (without a weight); re-adding replaces the entry. */
    @Override
    public void addToList(final String listKey, final String objectKey) {
        checkArgument(!Strings.isNullOrEmpty(listKey), "null/empty listKey");
        checkArgument(!Strings.isNullOrEmpty(objectKey), "null/empty objectKey");
        doAddToList(listKey, objectKey);
    }

    @Override
    public void removeFromList(String listKey, String objectKey) {
        checkArgument(!Strings.isNullOrEmpty(listKey), "null/empty listKey");
        checkArgument(!Strings.isNullOrEmpty(objectKey), "null/empty objectKey");
        doRemoveFromList(listKey, objectKey);
    }

    /** Adds {@code objectKey} to the list with the given weight; re-adding replaces the entry. */
    @Override
    public void addToOrderedList(final String listKey, final String objectKey, final long weight) {
        checkArgument(!Strings.isNullOrEmpty(listKey), "null/empty listKey");
        checkArgument(!Strings.isNullOrEmpty(objectKey), "null/empty objectKey");
        doAddToOrderedList(listKey, objectKey, weight);
    }

    @Override
    public void storeOrUpdate(final String key, final Message object) {
        // TODO this should be different from store
        store(key, object);
    }

    /** Moves the live object under {@code key} into the archive table. */
    @Override
    public void archive(final String key) {
        checkArgument(!Strings.isNullOrEmpty(key), "null/empty key");
        doArchive(key);
    }

    /** Moves the object under {@code key} back from the archive into the live table. */
    @Override
    public void unArchive(final String key) {
        checkArgument(!Strings.isNullOrEmpty(key), "null/empty key");
        doUnArchive(key);
    }

    @Override
    public Message getMessageDefaultInstance() {
        return this.messageDefaultInstance;
    }

    /**
     * Copies the live object into the archive table, then deletes it from the live table.
     * NOTE(review): the two statements are not wrapped in a transaction here; a failure between
     * them leaves the object in both tables.
     *
     * @throws IllegalArgumentException if no live object exists for {@code key}
     */
    protected void doArchive(final String key) {
        // move the object to the archive table
        final Message message = findByKey(key, false);
        checkArgument(message != null, "invalid message");
        doStore(key, message, getArchiveTableName());
        doDelete(key, getObjectTableName());
    }

    /**
     * Copies the object (looked up through the all-view) into the live table, then deletes it from
     * the archive table. NOTE(review): not transactional, same caveat as {@link #doArchive}.
     *
     * @throws IllegalArgumentException if no object exists for {@code key}
     */
    protected void doUnArchive(final String key) {
        // move the object out of the archive table back into the live object table
        final Message message = findByKey(key, true);
        checkArgument(message != null, "invalid message");
        doStore(key, message, getObjectTableName());
        doDelete(key, getArchiveTableName());
    }

    protected Message doFindByKey(final String key, boolean includeArchive) {
        final String sql = "SELECT `KEY`,`VALUE` FROM `" + (includeArchive ? getAllViewName() : getObjectTableName()) + "` WHERE `KEY` = ? LIMIT 1";
        final List<Message> messages = getJdbcTemplate().query(sql, getRowMapper(), key);
        if (messages == null || messages.isEmpty()) {
            return null;
        }
        return messages.get(0);
    }

    protected void doStore(final String key, final Message object, final String tableName) {
        final String sql = "REPLACE INTO `" + tableName + "` SET `KEY` = ?, `VALUE` = ?";
        getJdbcTemplate().update(sql, key, messageToBytes(object));
    }

    protected void doDelete(final String key, final String tableName) {
        final String sql = "DELETE FROM `" + tableName + "` WHERE `KEY` = ?";
        getJdbcTemplate().update(sql, key);
    }

    /**
     * Selects the object keys of a list.
     *
     * @param orderby a trusted, caller-supplied ORDER BY clause (concatenated verbatim)
     */
    protected List<String> doFindList(final String listKey, final int start, final int count, final String orderby) {
        final StringBuilder sqlStringBuilder = new StringBuilder();
        final List<Object> parameters = new ArrayList<>();
        parameters.add(listKey);
        sqlStringBuilder.append("SELECT `OBJECT_KEY` FROM `")
                .append(getListTableName())
                .append("` WHERE `LIST_KEY` = ? ")
                .append(orderby);
        appendLimit(sqlStringBuilder, parameters, start, count);

        final String sql = sqlStringBuilder.toString();
        logger.debug("sql to run is {} and list key is {}", sql, listKey);

        final List<String> objectKeys = getJdbcTemplate().query(sql, (rs, rowNum) -> rs.getString("OBJECT_KEY"), parameters.toArray());
        logger.debug("object keys: {}", objectKeys);
        return objectKeys != null ? objectKeys : Collections.emptyList();
    }

    /**
     * Selects the messages of a list by joining the list table with the object table (or the
     * all-view when {@code includeArchive} is true).
     *
     * @param orderby a trusted, caller-supplied ORDER BY clause (concatenated verbatim)
     */
    protected List<Message> doFindListObjects(final String listKey, final int start, final int count, final String orderby, boolean includeArchive) {
        final String objectTable = includeArchive ? getAllViewName() : getObjectTableName();
        final String listTable = getListTableName();

        final StringBuilder sqlStringBuilder = new StringBuilder();
        sqlStringBuilder.append("SELECT `").append(objectTable).append("`.`VALUE`,`").append(listTable).append("`.`TIMESTAMP`")
                .append(" FROM `").append(objectTable).append("`,`").append(listTable).append("` ")
                .append(" WHERE ")
                .append("`").append(objectTable).append("`").append(".`KEY` = `")
                .append(listTable).append("`.`OBJECT_KEY` ")
                .append(" AND ")
                .append("`").append(listTable).append("`").append(".`LIST_KEY` = ? ")
                .append(orderby);

        final List<Object> parameters = new ArrayList<>();
        parameters.add(listKey);
        appendLimit(sqlStringBuilder, parameters, start, count);

        final String sql = sqlStringBuilder.toString();
        logger.debug("sql to run is {}", sql);
        final List<Message> messageList = getJdbcTemplate().query(sql, getRowMapper(), parameters.toArray());
        return messageList != null ? messageList : Collections.emptyList();
    }

    protected void doAddToList(final String listKey, final String objectKey) {
        final StringBuilder sqlStringBuilder = new StringBuilder();
        sqlStringBuilder.append("REPLACE INTO ")
                .append("`").append(getListTableName()).append("`")
                .append(" SET `LIST_KEY` = ?, `OBJECT_KEY` = ?");
        getJdbcTemplate().update(sqlStringBuilder.toString(), listKey, objectKey);
    }

    protected void doRemoveFromList(final String listKey, final String objectKey) {
        final StringBuilder sqlStringBuilder = new StringBuilder();
        sqlStringBuilder.append("DELETE FROM ")
                .append("`").append(getListTableName()).append("`")
                .append(" WHERE `LIST_KEY` = ? AND `OBJECT_KEY` = ?");
        getJdbcTemplate().update(sqlStringBuilder.toString(), listKey, objectKey);
    }

    protected void doAddToOrderedList(final String listKey, final String objectKey, final long weight) {
        final StringBuilder sqlStringBuilder = new StringBuilder();
        sqlStringBuilder.append("REPLACE INTO ")
                .append("`").append(getListTableName()).append("`")
                .append(" SET `LIST_KEY` = ?, `OBJECT_KEY` = ?, `OBJECT_WEIGHT` = ?");
        getJdbcTemplate().update(sqlStringBuilder.toString(), listKey, objectKey, weight);
    }

    /**
     * Creates the object, archive and list tables if missing, and (re)creates the all-view.
     * The view uses {@code SELECT *} on both tables, so the object and archive tables must keep
     * identical column layouts.
     *
     * @return this DAO, for chaining
     */
    protected MySqlKeyValueProtobufDao createTablesIfNotExist() {
        final String sql = "CREATE TABLE IF NOT EXISTS `" + getObjectTableName() +
                "`( `KEY` VARCHAR(128) NOT NULL, " +
                " `VALUE` LONGBLOB NOT NULL, " +
                " `TIMESTAMP` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, " +
                "  PRIMARY KEY (`KEY`) ) " +
                "  ENGINE=InnoDB DEFAULT CHARSET=utf8";

        final String archiveSql = "CREATE TABLE IF NOT EXISTS `" + getArchiveTableName() +
                "` ( `KEY` varchar(128) NOT NULL, " +
                " `VALUE` LONGBLOB NOT NULL, " +
                " `TIMESTAMP` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP," +
                "  PRIMARY KEY (`KEY`) ) " +
                "  ENGINE=InnoDB DEFAULT CHARSET=utf8";

        final String listSql = "CREATE TABLE IF NOT EXISTS `" + getListTableName() +
                "` ( `LIST_KEY` VARCHAR(128) NOT NULL, " +
                " `OBJECT_KEY` VARCHAR(128) NOT NULL, " +
                " `OBJECT_WEIGHT` BIGINT, " +
                " `TIMESTAMP` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, " +
                "  PRIMARY KEY (`LIST_KEY`,`OBJECT_KEY`) ) " +
                "  ENGINE=InnoDB DEFAULT CHARSET=utf8";

        final String allViewSql = "CREATE OR REPLACE VIEW " +
                " `" + getAllViewName() + "` AS " +
                " SELECT * FROM `" + getObjectTableName() + "` " +
                " UNION " +
                " SELECT * FROM `" + getArchiveTableName() + "` ;";

        getJdbcTemplate().update(sql);
        getJdbcTemplate().update(archiveSql);
        getJdbcTemplate().update(listSql);
        getJdbcTemplate().update(allViewSql);
        return this;
    }

    protected Message.Builder getMessageBuilder() {
        return this.messageDefaultInstance.newBuilderForType();
    }

    // By convention the object table is named after the message class's simple name, upper-cased.
    // Locale.ROOT keeps the name stable regardless of the JVM default locale (e.g. Turkish dotted I).
    protected String getObjectTableName() {
        return getMessageDefaultInstance().getClass().getSimpleName().toUpperCase(Locale.ROOT);
    }

    // The list table is the object table name with a "_LIST" suffix.
    protected String getListTableName() {
        return getObjectTableName() + "_LIST";
    }

    protected String getArchiveTableName() {
        return getObjectTableName() + "_ARCHIVE";
    }

    protected String getAllViewName() {
        return getObjectTableName() + "_ALL_VIEW";
    }

    protected byte[] messageToBytes(final Message message) {
        return message.toByteArray();
    }

    protected RowMapper<Message> getRowMapper() {
        return this.rowMapper;
    }

    public DataSource getMySqlDataSource() {
        return this.mySqlDataSource;
    }

    /** Validates the common paging arguments shared by every list query. */
    private static void checkListArguments(final String listKey, final int start, final int count) {
        checkArgument(!Strings.isNullOrEmpty(listKey), "null/empty listkey");
        checkArgument(start >= 0, "invalid start");
        checkArgument(count >= -1, "invalid count");
    }

    /**
     * Appends a {@code LIMIT ?, ?} clause (and its bind values) unless the caller asked for the
     * full, un-offset result ({@code start == 0 && count == -1}). A {@code count} of {@code -1}
     * with a non-zero offset is translated to {@link #NO_ROW_LIMIT}, because MySQL rejects a
     * negative row count.
     */
    private static void appendLimit(final StringBuilder sql, final List<Object> parameters,
                                    final int start, final int count) {
        if (start == 0 && count == -1) {
            return;
        }
        sql.append("LIMIT ?, ?");
        parameters.add(start);
        if (count == -1) {
            parameters.add(NO_ROW_LIMIT);
        } else {
            parameters.add(count);
        }
    }
}

