package ai.cheq.sst.android.core.models

import ai.cheq.sst.android.core.Config
import ai.cheq.sst.android.core.ContextProvider
import ai.cheq.sst.android.core.Utils
import ai.cheq.sst.android.core.exceptions.ConflictingModelException
import com.fasterxml.jackson.core.JsonGenerator
import com.fasterxml.jackson.core.JsonParser
import com.fasterxml.jackson.core.JsonToken
import com.fasterxml.jackson.databind.DeserializationContext
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.SerializerProvider
import com.fasterxml.jackson.databind.annotation.JsonSerialize
import com.fasterxml.jackson.databind.deser.std.StdDeserializer
import com.fasterxml.jackson.databind.ser.std.StdSerializer

/**
 * The Models class provides a way to manage and access models that collect and expose data.
 * Custom models can be added to the collection as long as they do not conflict with base models.
 * The following base models are always included and cannot be overridden:
 *   * [AppModel]
 *   * [LibraryModel]
 *
 * @sample samples.core.Models.Usage.defaultModels
 * @sample samples.core.Models.Usage.requiredModels
 * @sample samples.core.Models.Usage.customModels
 */
class Models private constructor() {
    // Candidate base models; default()/required() pick from these (required() also mutates
    // `enabled` on DEFAULT models, so they are created per Models instance here).
    private val baseModels = ModelSet(AppModel(), LibraryModel(), DeviceModel.default())
    private val modelSet = ModelSet()
    private var finalized = false

    companion object {
        /**
         * Creates a new models collection with the default models.
         *
         * The following models are included and cannot be overridden:
         *   * [AppModel]
         *   * [LibraryModel]
         *
         * The following models are also included by default:
         *   * [DeviceModel]
         *
         * @return A new [Models] instance containing the default models.
         */
        @JvmStatic
        @JvmName("defaultModels")
        fun default(): Models {
            return Models().default()
        }

        /**
         * Creates a new models collection with no default models enabled.
         *
         * The following models are included and cannot be overridden:
         *   * [AppModel]
         *   * [LibraryModel]
         *
         * Default base models are still registered, but disabled.
         *
         * @return A new [Models] instance containing only the required models as enabled models.
         */
        @JvmStatic
        @JvmName("requiredModels")
        fun required(): Models {
            return Models().required()
        }
    }

    /**
     * Adds custom models to the collection.
     *
     * A model replaces an existing one only when it has the same key and the same class as that
     * existing model, and the existing model is not required.
     *
     * @param T The type of the model to add.
     * @param model The model(s) to add.
     * @return This instance for chaining.
     *
     * @throws ConflictingModelException If a conflicting model is attempted to be added.
     * @throws IllegalStateException If called after the collection has been finalized.
     */
    @Throws(ConflictingModelException::class)
    fun <T : Model<*>> add(vararg model: T): Models {
        check(!finalized) { "Cannot add models after initialization." }
        model.forEach {
            errorIfInvalidModelOverride(it)
            modelSet.add(it)
        }
        return this
    }

    /**
     * Retrieves a model of the specified class, whether or not it is enabled.
     *
     * @param T The type of the model to retrieve.
     * @param clazz The class, which is type [T], of the model to retrieve.
     * @return The model of the specified class, or `null` if not found.
     */
    fun <T : Model<*>> get(clazz: Class<T>): T? =
        modelSet.get(clazz)

    /**
     * Retrieves a model of the specified type, whether or not it is enabled.
     *
     * @param T The type of the model to retrieve.
     * @return The model of the specified type, or `null` if not found.
     */
    inline fun <reified T : Model<*>> get(): T? = get(T::class.java)

    /**
     * Retrieves all enabled models in the collection.
     *
     * @return A list of all enabled models.
     */
    fun all(): List<Model<*>> = modelSet.enabled()

    /**
     * Collects data from every registered model, enabled or not. Models that return `null` are
     * omitted; disabled models are recorded but excluded from [Data.toMap] and serialization.
     */
    internal suspend fun collect(config: Config, contextProvider: ContextProvider): Data {
        val data = Data()
        val modelContext = getModelContext(config, contextProvider)
        modelSet.all().forEach {
            val modelData = it.get(modelContext)
            if (modelData != null) {
                data.add(it.identifier.key, it.dataClass, it.enabled, modelData)
            }
        }
        return data
    }

    /**
     * Collects data from the single registered model of type [T], or returns `null` if no such
     * model is registered (or it produced no data).
     */
    @JvmName("collectOne")
    internal suspend inline fun <reified T : Model<D>, reified D : Model.Data> collect(
        config: Config, contextProvider: ContextProvider
    ): D? {
        val modelContext = getModelContext(config, contextProvider)
        return get<T>()?.get(modelContext)
    }

    /** Initializes every registered model, including disabled ones. */
    internal fun initialize(config: Config, contextProvider: ContextProvider) {
        val modelContext = getModelContext(config, contextProvider)
        modelSet.all().forEach {
            it.initialize(modelContext)
        }
    }

    /** Marks the collection as finalized; subsequent [add] calls fail. */
    internal fun finalize(): Models {
        finalized = true
        return this
    }

    /** Registers every base model that is not [ModelType.STANDARD]. */
    private fun default(): Models {
        val models = baseModels.all()
            .filter { it.modelType() != ModelType.STANDARD }
            .toTypedArray()
        return add(*models)
    }

    /** Registers required base models, plus default base models in a disabled state. */
    private fun required(): Models {
        val models = baseModels.all().mapNotNull {
            when (it.modelType()) {
                ModelType.REQUIRED -> it
                ModelType.DEFAULT -> {
                    it.enabled = false
                    it
                }
                else -> null
            }
        }.toTypedArray()
        return add(*models)
    }

    private fun getModelContext(config: Config, contextProvider: ContextProvider): ModelContext {
        // Required models are always listed; other models only when enabled.
        val requiredIdentifiers = modelSet.filter { it.modelType() == ModelType.REQUIRED }.map { it.identifier }
        val enabledIdentifiers =
            modelSet.filter { it.modelType() != ModelType.REQUIRED && it.enabled }.map { it.identifier }
        return ModelContext(
            contextProvider,
            config.modelScope,
            config.modelDispatcher,
            ModelIdentifiers(requiredIdentifiers, enabledIdentifiers),
            config.log
        )
    }

    /**
     * Validates that [model] can be added without corrupting [modelSet].
     *
     * Adding is allowed when no existing model shares the new model's key or class, or when the
     * existing model with the same key is the same instance as the existing model with the same
     * class (a straight replacement) and that model is not [ModelType.REQUIRED].
     *
     * @throws ConflictingModelException If [model] would override a required model, would collide
     * with an existing model on only its key or only its class, or would collide with two
     * different existing models (one by key, another by class).
     */
    private fun errorIfInvalidModelOverride(model: Model<*>) {
        val modelByKey = modelSet.getByKey(model.identifier.key)
        val modelByType = modelSet.get(model::class.java)
        if (modelByKey != null && modelByType != null && modelByKey == modelByType && modelByKey.modelType() == ModelType.REQUIRED) {
            throw ConflictingModelException("Cannot override required model ${model.identifier.key}:${model::class.java.simpleName}.")
        }
        if (modelByKey != null && modelByKey.modelType() == ModelType.REQUIRED) {
            throwConflictingModelException("Cannot override required model", modelByKey, model)
        }
        if (modelByType != null && modelByType.modelType() == ModelType.REQUIRED) {
            throwConflictingModelException("Cannot override required model", modelByType, model)
        }
        // Only a match on BOTH key and class against the SAME model is a valid replacement;
        // any partial match (or two different matches) would leave the two ModelSet maps out of sync.
        val conflicting = when {
            modelByKey == null -> modelByType
            modelByType == null -> modelByKey
            modelByKey !== modelByType -> modelByKey
            else -> null
        }
        if (conflicting != null) {
            val message = if (conflicting.modelType() == ModelType.DEFAULT) {
                "Cannot override default model"
            } else {
                "Cannot override model"
            }
            throwConflictingModelException(message, conflicting, model)
        }
    }

    /** Throws a [ConflictingModelException] describing the clash between [existingModel] and [newModel]. */
    private fun throwConflictingModelException(
        message: String,
        existingModel: Model<*>,
        newModel: Model<*>
    ): Nothing {
        throw ConflictingModelException(
            "$message ${existingModel.identifier.key}:${
                existingModel::class.java.simpleName
            } with ${newModel.identifier.key}:${newModel::class.java.simpleName}"
        )
    }

    /**
     * Collected model data, keyed by each model's data class. Each entry remembers whether its
     * model was enabled; disabled entries are hidden from [get] (unless `includeAll`), [toMap]
     * and serialization.
     */
    @JsonSerialize(using = Data.Serializer::class)
    internal class Data {
        private val models: HashMap<Class<out Model.Data>, Element> = HashMap()

        internal fun <T : Model.Data> add(key: String, clazz: Class<T>, enabled: Boolean, data: Model.Data) {
            models[clazz] = Element(key, enabled, data)
        }

        fun <T : Model.Data> get(clazz: Class<T>, includeAll: Boolean = false): T? {
            val model = models[clazz] ?: return null
            if (!includeAll && !model.enabled) {
                return null
            }
            return clazz.cast(model.data)
        }

        inline fun <reified T : Model.Data> get(includeAll: Boolean = false): T? = get(T::class.java, includeAll)

        internal fun all() = models.values.map { it.data }

        fun toMap() = models.values.filter { it.enabled }.associate {
            it.key to it.data.toMap()
        }

        /** Writes enabled entries as `{ key: data }`, ordered by data class name for stable output. */
        class Serializer : StdSerializer<Data>(Data::class.java, false) {
            override fun serialize(
                value: Data?, gen: JsonGenerator?, provider: SerializerProvider?
            ) {
                val generator = checkNotNull(gen) { "JsonGenerator must not be null" }
                val data = checkNotNull(value) { "Data must not be null" }
                generator.writeStartObject()
                data.models.filter { it.value.enabled }
                    .toSortedMap { l, r -> l.name.compareTo(r.name) }
                    .forEach { (_, element) ->
                        generator.writeFieldName(element.key)
                        generator.writeObject(element.data)
                    }
                generator.writeEndObject()
            }
        }

        /**
         * Reads a `{ key: data }` object back into [Data] for the given [models]. Keys that are
         * absent or `null` are skipped, since [Serializer] omits disabled models.
         */
        class Deserializer(private vararg val models: Model<*>) :
            StdDeserializer<Data>(Data::class.java) {
            override fun deserialize(p: JsonParser?, ctxt: DeserializationContext?): Data {
                val data = Data()
                if (p != null && p.codec != null && ctxt != null) {
                    if (p.currentToken == JsonToken.FIELD_NAME) {
                        p.nextToken()
                    }
                    val node: JsonNode? = p.codec.readTree(p)
                    if (node != null) {
                        for (model in models) {
                            val modelNode = node.get(model.identifier.key) ?: continue
                            if (modelNode.isNull) {
                                continue
                            }
                            val modelData = Utils.jsonMapper.readValue(
                                modelNode.traverse(p.codec), model.dataClass
                            ) ?: continue
                            data.add(model.identifier.key, model.dataClass, model.enabled, modelData)
                        }
                    }
                    p.nextToken()
                }
                return data
            }
        }

        private data class Element(val key: String, val enabled: Boolean, val data: Model.Data)
    }

    /**
     * Models indexed both by identifier key and by concrete class. [errorIfInvalidModelOverride]
     * keeps the two maps consistent: each stored model occupies exactly one slot in each.
     */
    private class ModelSet(vararg models: Model<*>) {
        private val keyedModels: HashMap<String, Model<*>> = HashMap()
        private val models: HashMap<Class<out Model<*>>, Model<*>> = HashMap()

        init {
            // `models` here is the constructor vararg, not the property.
            models.forEach {
                add(it)
            }
        }

        fun getByKey(key: String): Model<*>? = keyedModels[key]

        fun <T : Model<*>> get(clazz: Class<T>): T? = models[clazz]?.let { clazz.cast(it) }

        fun all() = models.values.toList()

        fun enabled() = models.values.filter { it.enabled }

        inline fun filter(predicate: (Model<*>) -> Boolean): List<Model<*>> {
            return models.values.filter(predicate)
        }

        fun <T : Model<*>> add(model: T) {
            keyedModels[model.identifier.key] = model
            models[model::class.java] = model
        }
    }
}