package pro.vdshb.dbcleaner

import com.zaxxer.hikari.HikariConfig
import com.zaxxer.hikari.HikariDataSource
import pro.vdshb.dbcleaner.postgresutil.*
import java.sql.Connection
import java.util.*
import javax.sql.DataSource

/**
 * PostgreSQL-specific database cleaner.
 *
 * Connection settings (`url`, `driver`, `username`, `password`) are read from the
 * `db-cleaner.properties` classpath resource when this object is initialised, and a
 * Hikari connection pool is created from them.
 */
object PostgresqlDatabaseCleaner {

    private const val PROPERTIES_RESOURCE = "db-cleaner.properties"

    private val dataSource: DataSource

    init {
        val cleanerProperties = loadCleanerProperties()
        val config = HikariConfig()
        config.jdbcUrl = requiredProperty(cleanerProperties, "url")
        config.driverClassName = requiredProperty(cleanerProperties, "driver")
        config.username = requiredProperty(cleanerProperties, "username")
        config.password = requiredProperty(cleanerProperties, "password")
        dataSource = HikariDataSource(config)
    }

    /**
     * Loads [PROPERTIES_RESOURCE] from the classpath.
     *
     * The stream is always closed, even when parsing fails.
     *
     * @throws IllegalStateException if the resource is not present on the classpath
     */
    private fun loadCleanerProperties(): Properties {
        val stream = checkNotNull(javaClass.classLoader.getResourceAsStream(PROPERTIES_RESOURCE)) {
            "Resource '$PROPERTIES_RESOURCE' was not found on the classpath"
        }
        return stream.use { input ->
            val loaded = Properties()
            loaded.load(input)
            loaded
        }
    }

    /**
     * Returns the value stored under [key], failing with a message that names the missing key.
     *
     * @throws IllegalStateException if [key] is absent from [properties]
     */
    private fun requiredProperty(properties: Properties, key: String): String =
        properties.getProperty(key)
            ?: error("Property '$key' is missing in '$PROPERTIES_RESOURCE'")

    /**
     * Deletes all rows from the selected tables of [schemaToClean] inside a single transaction.
     *
     * @param schemaToClean schema whose tables are looked up and cleaned
     * @param onlyTables when not empty, only these tables are cleaned; each of them must exist
     * @param excludeTables tables that are never cleaned, even if listed in [onlyTables]
     * @throws IllegalArgumentException if [onlyTables] names a table that was not found
     */
    fun cleanPostgresDb(
        schemaToClean: String = "public",
        onlyTables: List<String> = emptyList(),
        excludeTables: List<String> = emptyList()
    ) {
        dataSource.connection.use { connection ->
            val allTables = findAllTables(schemaToClean, connection)
            val candidateTables = if (onlyTables.isNotEmpty()) {
                throwExceptionIfIncorrectTableInList(onlyTables, allTables)
                onlyTables
            } else {
                allTables
            }
            val excluded = excludeTables.toSet()
            val tablesToClean = candidateTables.filterNot { it in excluded }
            //todo: log a warning for tables in 'excludeTables' that do not exist
            cleanTables(schemaToClean, tablesToClean, connection)
        }
    }

    /**
     * Verifies that every entry of [tablesToClean] is present in [allTables].
     *
     * @throws IllegalArgumentException naming the first table that is not present
     */
    private fun throwExceptionIfIncorrectTableInList(tablesToClean: List<String>, allTables: List<String>) {
        val knownTables = allTables.toSet()
        val unknownTable = tablesToClean.firstOrNull { it !in knownTables }
        if (unknownTable != null) {
            throw IllegalArgumentException("There is no table '$unknownTable' in database")
        }
    }

    //====================== Clean tables section ======================

    /**
     * Cleans [tablesToClean] in one transaction on [connection].
     *
     * The constraints and constraint triggers found for the tables are switched to deferred
     * before the rows are deleted and set back to their recorded flags before the commit.
     * Any failure rolls the transaction back and is rethrown; auto-commit is switched back on
     * afterwards in every case.
     */
    private fun cleanTables(schemaToClean: String, tablesToClean: List<String>, connection: Connection) {
        connection.autoCommit = false
        try {
            // all constraints to deferrable
            val constraints = findAllTableConstraints(schemaToClean, tablesToClean, connection)
            setConstraintsDeferred(constraints, connection)
            val triggers = findAllConstraintTriggers(constraints, connection)
            setTriggersDeferred(triggers, connection)

            // clean DB
            tablesToClean.forEach { cleanTable(connection, it) }

            // restore constraints to previous state
            restoreConstraints(constraints, connection)
            restoreTriggers(triggers, connection)
            connection.commit()
        } catch (ex: Exception) {
            try {
                connection.rollback()
            } catch (rollbackEx: Exception) {
                // keep the original failure as the primary exception; a failed rollback must not hide it
                ex.addSuppressed(rollbackEx)
            }
            throw ex
        } finally {
            connection.autoCommit = true
        }
    }

    /** Marks every constraint in [constraints] as deferrable and initially deferred in `pg_constraint`. */
    private fun setConstraintsDeferred(constraints: List<PostgresConstraint>, connection: Connection) {
        if (constraints.isEmpty()) {
            return
        }
        val constraintsOids = constraints.joinToString(separator = ",") { it.oid.toString() }
        connection.createStatement().use { statement ->
            statement.execute(
                """UPDATE pg_catalog.pg_constraint SET condeferrable = true, condeferred = true
                                 WHERE oid IN ($constraintsOids)"""
            )
        }
    }

    /** Marks every trigger in [triggers] as deferrable and initially deferred in `pg_trigger`. */
    private fun setTriggersDeferred(triggers: List<PostgresTrigger>, connection: Connection) {
        if (triggers.isEmpty()) {
            return
        }
        val triggerOids = triggers.joinToString(separator = ",") { it.oid.toString() }
        connection.createStatement().use { statement ->
            statement.execute(
                """UPDATE pg_catalog.pg_trigger SET tgdeferrable = true, tginitdeferred = true
                                 WHERE oid IN ($triggerOids)"""
            )
        }
    }

    /** Writes the `condeferrable` / `condeferred` flags recorded in [constraints] back to `pg_constraint`. */
    private fun restoreConstraints(constraints: List<PostgresConstraint>, connection: Connection) {
        if (constraints.isEmpty()) {
            return
        }
        connection
            .prepareStatement("UPDATE pg_catalog.pg_constraint SET condeferrable = ?, condeferred = ? WHERE oid = ?")
            .use { preparedStatement ->
                constraints.forEach { constraint ->
                    preparedStatement.setBoolean(1, constraint.condeferrable)
                    preparedStatement.setBoolean(2, constraint.condeferred)
                    preparedStatement.setLong(3, constraint.oid)
                    preparedStatement.execute()
                }
            }
    }

    /** Writes the `tgdeferrable` / `tginitdeferred` flags recorded in [triggers] back to `pg_trigger`. */
    private fun restoreTriggers(triggers: List<PostgresTrigger>, connection: Connection) {
        if (triggers.isEmpty()) {
            return
        }
        connection
            .prepareStatement("UPDATE pg_catalog.pg_trigger SET tgdeferrable = ?, tginitdeferred = ? WHERE oid = ?")
            .use { preparedStatement ->
                triggers.forEach { trigger ->
                    preparedStatement.setBoolean(1, trigger.tgdeferrable)
                    preparedStatement.setBoolean(2, trigger.tginitdeferred)
                    preparedStatement.setLong(3, trigger.oid)
                    preparedStatement.execute()
                }
            }
    }

    /**
     * Deletes every row from [tableName].
     *
     * NOTE(review): the name is interpolated as-is (unquoted and without a schema prefix), so it is
     * presumably resolved through the connection's search path — confirm this is intended when a
     * schema other than the default is cleaned.
     */
    @Suppress("SqlWithoutWhere")
    private fun cleanTable(connection: Connection, tableName: String) {
        val cleanQuery = "DELETE FROM $tableName"
        connection.createStatement().use { statement -> statement.execute(cleanQuery) }
    }

}
