package org.somda.dsl.rendering.jaxb

import jakarta.xml.bind.JAXBContext
import jakarta.xml.bind.JAXBException
import jakarta.xml.bind.Marshaller
import org.glassfish.jaxb.runtime.marshaller.NamespacePrefixMapper
import java.io.ByteArrayInputStream
import java.io.IOException
import java.io.OutputStream
import java.net.URL
import java.nio.charset.StandardCharsets
import javax.xml.XMLConstants
import javax.xml.parsers.DocumentBuilderFactory
import javax.xml.transform.stream.StreamSource
import javax.xml.validation.Schema
import javax.xml.validation.SchemaFactory

/**
 * Marshals objects to XML with a JAXB context and validates the output against a set of XML schemas.
 *
 * All collaborators are prepared once, at construction time:
 * - a namespace prefix mapper built from `namespaceMappings`. It is extended with the `xsi` prefix for the
 *   XML Schema instance namespace, and that entry replaces any caller-supplied prefix for that namespace;
 * - a [JAXBContext] created from all `contextPackages`, joined with `:`;
 * - a single validation [Schema] that `xsd:import`s every schema listed in `schemaPaths`.
 *
 * @param contextPackages context packages handed to [JAXBContext.newInstance].
 * @param schemaPaths class-path resource paths of the XSD files used for validation.
 * @param namespaceMappings preferred namespace-to-prefix assignments for the marshalled output.
 * @throws RuntimeException if the JAXB context cannot be created; the original [JAXBException] is attached as cause.
 * @throws IOException if a schema resource cannot be found on the class path.
 */
internal class DslJaxbMarshalling(
    contextPackages: List<JaxbContext>,
    schemaPaths: List<XmlSchemaPath>,
    namespaceMappings: Map<Namespace, NamespacePrefix>
) {
    private val namespacePrefixMapper: NamespacePrefixMapper
    private val jaxbContext: JAXBContext
    private val schema: Schema

    init {
        // Append internal mappings (xsi prefix); this overrides a caller-supplied prefix for the same namespace.
        val namespaceMappingsExtended = namespaceMappings.toMutableMap().apply {
            put(Namespace(XMLConstants.W3C_XML_SCHEMA_INSTANCE_NS_URI), NamespacePrefix("xsi"))
        }
        namespacePrefixMapper = NamespacePrefixMapperConverter.convert(namespaceMappingsExtended)

        val joinedContextPackages = contextPackages.joinToString(":") { it.value }
        jaxbContext = try {
            JAXBContext.newInstance(joinedContextPackages)
        } catch (e: JAXBException) {
            // Keep the JAXB exception as cause; it carries the actual reason (e.g. a missing ObjectFactory).
            throw RuntimeException("JAXB context could not be created for '$joinedContextPackages'", e)
        }

        schema = generateTopLevelSchema(schemaPaths)
    }

    /**
     * Writes [objectToMarshal] as formatted XML to [outputStream] and validates it against the composed schema.
     *
     * A new [Marshaller] is created on every call; only the [JAXBContext] is reused between calls.
     *
     * @throws JAXBException if marshalling fails, including validation errors reported by the marshaller.
     */
    fun marshal(objectToMarshal: Any, outputStream: OutputStream) {
        val marshaller = jaxbContext.createMarshaller()
        marshaller.setProperty(JAXB_MARSHALLER_PROPERTY_KEY, namespacePrefixMapper)
        marshaller.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, true)
        marshaller.schema = schema
        marshaller.marshal(objectToMarshal, outputStream)
    }

    /**
     * Builds an in-memory top-level schema with one `xsd:import` per entry of [schemaPaths].
     *
     * The result is a single [Schema] that covers all of the imported namespaces. Each import uses the
     * imported schema's `targetNamespace` as `namespace` and its resolved class-path URL as `schemaLocation`.
     *
     * @throws IOException if a schema resource cannot be found on the class path.
     */
    private fun generateTopLevelSchema(schemaPaths: List<XmlSchemaPath>): Schema {
        val classLoader = javaClass.classLoader
        val topLevelSchema = buildString {
            append("<xsd:schema xmlns:xsd=\"http://www.w3.org/2001/XMLSchema\" elementFormDefault=\"qualified\">")
            for (path in schemaPaths) {
                val schemaUrl = classLoader.getResource(path.value) ?: throw IOException(
                    "Could not find schema for resource while loading in ${DslJaxbMarshalling::class.java.simpleName}: $path",
                )
                val targetNamespace = resolveTargetNamespace(schemaUrl)
                // Both values end up inside double-quoted attributes. Escape them so that an '&' in a file-system
                // path, or an entity already decoded by the DOM parser, cannot produce malformed schema XML.
                append("<xsd:import namespace=\"")
                append(escapeXmlAttribute(targetNamespace))
                append("\" schemaLocation=\"")
                append(escapeXmlAttribute(schemaUrl.toString()))
                append("\"/>")
            }
            append("</xsd:schema>")
        }
        // we *do* need external schema processing here (we build our schema with multiple xsd:import directives),
        // which causes an XXE warning by Sonarlint; but the schemas imported are fully under our control,
        // read from local files in our resource directory, and thus can be trusted.
        val schemaFactory = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI)
        return schemaFactory.newSchema(
            StreamSource(ByteArrayInputStream(topLevelSchema.toByteArray(StandardCharsets.UTF_8)))
        )
    }

    /**
     * Parses the schema at [url] and returns the `targetNamespace` attribute of its root element.
     *
     * If the attribute is absent, the result is an empty string (DOM `getAttribute` semantics).
     */
    private fun resolveTargetNamespace(url: URL): String {
        url.openStream().use { inputStream ->
            val factory = DocumentBuilderFactory.newInstance()
            // #218 prevent XXE attacks
            factory.setFeature(SAX_FEATURE_EXTERNAL_GENERAL_ENTITIES, false)
            factory.setFeature(SAX_FEATURE_EXTERNAL_PARAMETER_ENTITIES, false)
            val builder = factory.newDocumentBuilder()
            val document = builder.parse(inputStream)
            return document.documentElement.getAttribute("targetNamespace")
        }
    }

    private companion object {
        const val SAX_FEATURE_EXTERNAL_GENERAL_ENTITIES = "http://xml.org/sax/features/external-general-entities"
        const val SAX_FEATURE_EXTERNAL_PARAMETER_ENTITIES = "http://xml.org/sax/features/external-parameter-entities"
        const val JAXB_MARSHALLER_PROPERTY_KEY: String = "org.glassfish.jaxb.namespacePrefixMapper"

        /**
         * Escapes the characters that may not appear verbatim inside a double-quoted XML attribute value.
         *
         * `&` is replaced first so that the entities introduced afterwards are not escaped twice.
         */
        fun escapeXmlAttribute(value: String): String =
            value.replace("&", "&amp;")
                .replace("<", "&lt;")
                .replace("\"", "&quot;")
    }
}
