package org.dronda.lib.ktor.client.plugins

import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpClientPlugin
import io.ktor.client.plugins.HttpSend
import io.ktor.client.plugins.plugin
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.HttpRequestPipeline
import io.ktor.client.request.header
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.util.AttributeKey
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.NonCancellable
import kotlinx.coroutines.withContext

/**
 * Ktor client plugin that attaches a bearer token to every outgoing request and refreshes it when needed.
 *
 * A refresh is triggered either before sending (when `decodeClaims` reports the current token as expired) or
 * after a response with status [HttpStatusCode.Unauthorized], in which case the request is retried once.
 * Concurrent callers share a single in-flight refresh instead of each starting their own.
 */
public class BearerAuthPlugin private constructor(private val config: Config) {
    /** User-facing configuration of [BearerAuthPlugin]. */
    public class Config {
        /** Supplies the current token. A `null` result means no `Authorization` header is written. */
        public var authorizeRequest: () -> String? = { "" }

        /**
         * Refreshes the token that [authorizeRequest] returns. When `null`, no refresh is ever attempted.
         *
         * NOTE(review): requests sent through the client this plugin is installed on wait for the in-flight
         * refresh before being sent — confirm this callback does not use that same client.
         */
        public var refreshToken: (suspend () -> Unit)? = null
    }

    /** The refresh currently in flight, or `null` when none is running. */
    private val refreshDeferred = atomic<CompletableDeferred<Unit>?>(null)

    /**
     * Returns the token to attach to a request.
     *
     * If a refresh is in flight, waits for it and returns the token supplied afterwards. Otherwise reads the
     * current token and, when it is reported as expired, refreshes once and re-reads it.
     *
     * @return the token, or `null` when the configured supplier provides none.
     * @throws Throwable whatever the awaited or triggered refresh failed with.
     */
    internal suspend fun authorizeRequest(): String? {
        // await() yields Unit, so a non-null result means we actually waited for a running refresh.
        val waitedForRefresh = refreshDeferred.value?.await() != null
        val token = config.authorizeRequest()
        if (waitedForRefresh || token == null || decodeClaims(token)?.isExpired != true) return token

        refreshToken()
        return config.authorizeRequest()
    }

    /**
     * Runs the configured refresh callback, or joins the refresh that is already in flight.
     *
     * The callback runs under [NonCancellable], so a started refresh finishes even if the caller is cancelled.
     * On failure the in-flight marker is still cleared and all waiters receive the same exception, so a failed
     * refresh can never leave later callers suspended forever.
     *
     * @return `null` when no refresh callback is configured, [Unit] once a refresh has completed.
     * @throws Throwable whatever the refresh callback threw.
     */
    internal suspend fun refreshToken(): Unit? {
        val refresh = config.refreshToken ?: return null
        val inFlight = CompletableDeferred<Unit>()

        // Either join the refresh someone else started, or claim the slot for our own.
        while (true) {
            val current = refreshDeferred.value
            if (current != null) return current.await()
            if (refreshDeferred.compareAndSet(null, inFlight)) break
        }

        withContext(NonCancellable) {
            try {
                refresh()
                inFlight.complete(Unit)
            } catch (cause: Throwable) {
                // Release waiters with the failure instead of leaving them suspended.
                inFlight.completeExceptionally(cause)
                throw cause
            } finally {
                // Always clear the slot so the next caller can start a fresh attempt.
                refreshDeferred.compareAndSet(inFlight, null)
            }
        }
        return Unit
    }

    public companion object Plugin : HttpClientPlugin<Config, BearerAuthPlugin> {
        override fun prepare(block: Config.() -> Unit): BearerAuthPlugin {
            val config = Config().apply(block)
            return BearerAuthPlugin(config)
        }

        /**
         * Registers two interceptors on [scope]: one writes the bearer header in the
         * [HttpRequestPipeline.State] phase, the other retries a request once after a 401 response.
         */
        override fun install(plugin: BearerAuthPlugin, scope: HttpClient) {
            scope.requestPipeline.intercept(HttpRequestPipeline.State) {
                context.setBearerAuthorizationHeader(plugin.authorizeRequest())
            }

            scope.plugin(HttpSend).intercept { context ->
                val call = execute(context)

                if (call.response.status != HttpStatusCode.Unauthorized) return@intercept call
                // A request that already went through one refresh-and-retry cycle is not retried again.
                if (call.request.attributes.contains(AuthCircuitBreaker)) return@intercept call

                plugin.refreshToken() ?: return@intercept call

                val retry = HttpRequestBuilder()
                retry.takeFrom(context)
                // takeFrom copied the rejected token; swap it for the refreshed one on the retry itself.
                plugin.authorizeRequest()?.let { token ->
                    retry.headers.remove(HttpHeaders.Authorization)
                    retry.setBearerAuthorizationHeader(token)
                }
                retry.attributes.put(AuthCircuitBreaker, Unit)

                return@intercept execute(retry)
            }
        }

        override val key: AttributeKey<BearerAuthPlugin> = AttributeKey("AuthPlugin")

        /** Marks a request as already retried after a refresh. */
        private val AuthCircuitBreaker: AttributeKey<Unit> = AttributeKey("auth-circuit-breaker")
    }
}

/**
 * Sets the `Authorization` header of this request to `Bearer <token>`.
 *
 * Any `Authorization` value already present is replaced, so applying a token to a request that already
 * carries one does not produce a duplicated header. A `null` [token] leaves the request untouched.
 */
private fun HttpRequestBuilder.setBearerAuthorizationHeader(token: String?) {
    if (token == null) return
    // Assign rather than append: `header(...)` would add a second value next to an existing one.
    headers[HttpHeaders.Authorization] = "Bearer $token"
}

