package no.nav.security.mock.oauth2.token

import com.nimbusds.jose.JOSEObjectType
import com.nimbusds.oauth2.sdk.GrantType
import com.nimbusds.oauth2.sdk.TokenRequest
import no.nav.security.mock.oauth2.extensions.clientIdAsString
import no.nav.security.mock.oauth2.extensions.grantType
import no.nav.security.mock.oauth2.extensions.scopesWithoutOidcScopes
import no.nav.security.mock.oauth2.extensions.tokenExchangeGrantOrNull
import java.time.Duration
import java.util.UUID

/**
 * Supplies the issuer- and request-specific values used when a token is minted.
 *
 * Implementations in this file: [DefaultOAuth2TokenCallback] (values fixed at construction) and
 * [RequestMappingTokenCallback] (values chosen by matching the incoming token request).
 */
interface OAuth2TokenCallback {
    /** Identifier of the issuer this callback serves. */
    fun issuerId(): String

    /**
     * Token lifetime. Both implementations in this file default to one hour expressed as seconds,
     * so the unit is presumably seconds — TODO confirm against the token provider.
     */
    fun tokenExpiry(): Long

    /** Value to use for the token's type header (the implementations default to `JWT`). */
    fun typeHeader(tokenRequest: TokenRequest): String

    /** Subject for the token, or `null` when the implementation has no subject to offer. */
    fun subject(tokenRequest: TokenRequest): String?

    /** Audience values for the token; the list may be empty. */
    fun audience(tokenRequest: TokenRequest): List<String>

    /** Additional claims for the token, keyed by claim name. */
    fun addClaims(tokenRequest: TokenRequest): Map<String, Any>
}

// TODO: for JwtBearerGrant and TokenExchange it should be possible to override sub; make sub nullable and return some default
/**
 * Default [OAuth2TokenCallback] whose values are fixed when the instance is constructed.
 *
 * @param issuerId value returned from [issuerId]; also emitted as the `tid` claim by [addClaims].
 * @param subject subject for every grant except `client_credentials`, where the client id is used instead.
 *   Defaults to a random UUID chosen once per instance.
 * @param typeHeader value returned from [typeHeader] for every request.
 * @param audience explicit audience. When `null`, the audience is derived from the request (see [audience]).
 * @param claims extra claims returned from [addClaims]; `azp` and `tid` entries are always overwritten.
 * @param expiry value returned from [tokenExpiry] (presumably seconds; the default 3600 is one hour in seconds).
 */
open class DefaultOAuth2TokenCallback @JvmOverloads constructor(
    private val issuerId: String = "default",
    private val subject: String = UUID.randomUUID().toString(),
    private val typeHeader: String = JOSEObjectType.JWT.type,
    // Needs to be nullable in order to know whether a list has been explicitly set; an empty list is an allowable value.
    private val audience: List<String>? = null,
    private val claims: Map<String, Any> = emptyMap(),
    private val expiry: Long = 3600,
) : OAuth2TokenCallback {

    override fun issuerId(): String = issuerId

    /** Returns the request's client id for `client_credentials` grants, otherwise the configured [subject]. */
    override fun subject(tokenRequest: TokenRequest): String =
        if (tokenRequest.grantType() == GrantType.CLIENT_CREDENTIALS) {
            tokenRequest.clientIdAsString()
        } else {
            subject
        }

    override fun typeHeader(tokenRequest: TokenRequest): String = typeHeader

    /**
     * Resolves the audience, in order of precedence:
     * 1. the audience passed to the constructor (even when it is an empty list),
     * 2. the `audience` of the token-exchange grant, when the request carries one,
     * 3. the request's scopes as filtered by `scopesWithoutOidcScopes()`, when the request has a scope,
     * 4. otherwise `listOf("default")`.
     */
    override fun audience(tokenRequest: TokenRequest): List<String> {
        val audienceParam = tokenRequest.tokenExchangeGrantOrNull()?.audience
        return when {
            audience != null -> audience
            audienceParam != null -> audienceParam
            tokenRequest.scope != null -> tokenRequest.scopesWithoutOidcScopes()
            else -> listOf("default")
        }
    }

    /**
     * Returns the configured [claims] plus `azp` (the request's client id) and `tid` (the issuer id).
     * `azp` and `tid` take precedence over same-named entries in the configured claims.
     */
    override fun addClaims(tokenRequest: TokenRequest): Map<String, Any> =
        claims + mapOf(
            "azp" to tokenRequest.clientIdAsString(),
            "tid" to issuerId,
        )

    override fun tokenExpiry(): Long = expiry
}

/**
 * [OAuth2TokenCallback] driven by a set of [RequestMapping]s.
 *
 * The first mapping (in the set's iteration order) whose [RequestMapping.isMatch] accepts the token
 * request supplies the claims and type header. When no mapping matches, no claims are produced and
 * the `JWT` type header is used.
 *
 * @param issuerId value returned from [issuerId].
 * @param requestMappings candidate mappings, consulted in iteration order.
 * @param tokenExpiry value returned from [tokenExpiry]; defaults to one hour expressed in seconds.
 */
data class RequestMappingTokenCallback(
    val issuerId: String,
    val requestMappings: Set<RequestMapping>,
    val tokenExpiry: Long = Duration.ofHours(1).toSeconds(),
) : OAuth2TokenCallback {
    override fun issuerId(): String = issuerId

    /** The `sub` claim of the matching mapping, or `null` when it is absent or not a string. */
    override fun subject(tokenRequest: TokenRequest): String? =
        requestMappings.getClaimOrNull(tokenRequest, "sub")

    override fun typeHeader(tokenRequest: TokenRequest): String =
        requestMappings.getTypeHeader(tokenRequest)

    /**
     * The `aud` claim of the matching mapping, as a list.
     *
     * RFC 7519 allows `aud` to be a single string or an array of strings, so both forms are accepted.
     * A single string becomes a one-element list, and for a collection only its string elements are kept.
     * A missing claim or any other value type yields an empty list.
     */
    override fun audience(tokenRequest: TokenRequest): List<String> =
        when (val aud = requestMappings.getClaims(tokenRequest)["aud"]) {
            is String -> listOf(aud)
            is Collection<*> -> aud.filterIsInstance<String>()
            else -> emptyList()
        }

    override fun addClaims(tokenRequest: TokenRequest): Map<String, Any> =
        requestMappings.getClaims(tokenRequest)

    override fun tokenExpiry(): Long = tokenExpiry

    /**
     * Claims of the first matching mapping, or an empty map when none matches.
     * For `client_credentials` requests, a `sub` claim equal to the literal placeholder `${clientId}`
     * is replaced by the request's client id.
     */
    private fun Set<RequestMapping>.getClaims(tokenRequest: TokenRequest): Map<String, Any> {
        val claims = firstOrNull { it.isMatch(tokenRequest) }?.claims ?: emptyMap()
        return if (tokenRequest.grantType() == GrantType.CLIENT_CREDENTIALS && claims["sub"] == "\${clientId}") {
            claims + ("sub" to tokenRequest.clientIdAsString())
        } else {
            claims
        }
    }

    /** Claim [key] of the matching mapping when it is an instance of [T], otherwise `null`. */
    private inline fun <reified T> Set<RequestMapping>.getClaimOrNull(tokenRequest: TokenRequest, key: String): T? =
        getClaims(tokenRequest)[key] as? T

    /** Type header of the first matching mapping, falling back to `JWT`. */
    private fun Set<RequestMapping>.getTypeHeader(tokenRequest: TokenRequest) =
        firstOrNull { it.isMatch(tokenRequest) }?.typeHeader ?: JOSEObjectType.JWT.type
}

/**
 * Associates a token-request parameter (and optionally one of its values) with the claims and
 * type header to use for requests that carry it.
 *
 * @param requestParam name of the request parameter to inspect.
 * @param match required parameter value; the wildcard `"*"` accepts any value, provided the
 *   parameter is present with at least one value.
 * @param claims claims associated with matching requests.
 * @param typeHeader type header value associated with matching requests (defaults to `JWT`).
 */
data class RequestMapping(
    private val requestParam: String,
    private val match: String = "*",
    val claims: Map<String, Any> = emptyMap(),
    val typeHeader: String = JOSEObjectType.JWT.type,
) {
    /**
     * Returns `true` when [requestParam] appears in the request's `queryParameters` and at least one
     * of its values equals [match], or when [match] is `"*"` and the parameter has any value.
     *
     * NOTE(review): confirm that `queryParameters` includes form-body parameters of POST token
     * requests in the Nimbus SDK version in use.
     */
    fun isMatch(tokenRequest: TokenRequest): Boolean {
        val candidates = tokenRequest.toHTTPRequest().queryParameters[requestParam] ?: return false
        val acceptsAnything = match == "*"
        return candidates.any { candidate -> acceptsAnything || candidate == match }
    }
}
