package ai.passio.passiosdk.passiofood.recognition

import ai.passio.passiosdk.core.config.SDKFileType
import ai.passio.passiosdk.core.utils.PassioLog
import ai.passio.passiosdk.passiofood.config.passio_FoodNet_HNN
import ai.passio.passiosdk.passiofood.config.passio_FoodNet_HNN_ENv2
import ai.passio.passiosdk.passiofood.config.passio_FoodNet_SSD_HNN
import ai.passio.passiosdk.passiofood.config.passio_FoodNet_Yolo
import ai.passio.passiosdk.passiofood.config.passio_nutrition_HNN
import ai.passio.passiosdk.passiofood.file.PassioFoodFileManager
import ai.passio.passiosdk.passiofood.metadata.MetadataManager
import ai.passio.passiosdk.passiofood.tflite.TFLiteHNNKNNDetector
import ai.passio.passiosdk.passiofood.tflite.TFLiteObjectDetector
import ai.passio.passiosdk.passiofood.tflite.TFLiteYoloObjectDetector
import android.content.Context
import android.util.Size

/**
 * [PassioModelHolder] that loads its models from secured assets.
 *
 * Each model file is resolved as `<fileType name>.<version>.passiosecure`
 * (the extension comes from [getExtension]) and is loaded through the detectors'
 * `createFromSecuredAssetModel`.
 */
internal class PassioAssetSecuredModelHolder(
    labelManager: MetadataManager,
    fileManager: PassioFoodFileManager,
) : PassioModelHolder(labelManager, fileManager) {

    /**
     * Loads the model that corresponds to [fileType]. The resulting detector is stored in
     * the matching base-class property: `objectDetector` for the SSD/YOLO models, and
     * `hnnDetector` for the HNN models.
     *
     * @param context passed to `createFromSecuredAssetModel` to open the asset.
     * @param fileType which model to load.
     * @param version model version; it is embedded in the asset file name.
     * @return `true` if the detector property is non-null after loading. Also `true` for
     *   [passio_nutrition_HNN], which this holder does not load. `false` otherwise.
     * @throws IllegalArgumentException if [fileType] is not one of the handled types.
     */
    override fun initializeModel(
        context: Context,
        fileType: SDKFileType,
        version: Int
    ): Boolean = when (fileType) {
        passio_FoodNet_SSD_HNN -> {
            val modelName = modelFileName(passio_FoodNet_SSD_HNN.name, version)
            logInitStart(modelName)
            val inputSize = squareSize(PassioRecognizer.TF_OD_API_INPUT_SIZE)

            // NOTE(review): assetCatch (base class) is assumed to yield null when loading fails.
            objectDetector = assetCatch(modelName) {
                TFLiteObjectDetector(inputSize).apply {
                    createFromSecuredAssetModel(context, modelName)
                }
            }
            logInitResult(modelName, objectDetector != null)
        }

        passio_FoodNet_HNN -> {
            val modelName = modelFileName(passio_FoodNet_HNN.name, version)
            logInitStart(modelName)
            val labels = labelManager.getVisualPassioIDs()
            val inputSize = squareSize(PassioRecognizer.TF_HNN_API_INPUT_SIZE)

            // This model uses `true` as the first detector argument; ENv2 below uses `false`.
            hnnDetector = assetCatch(modelName) {
                TFLiteHNNKNNDetector(true, inputSize, labels).apply {
                    createFromSecuredAssetModel(context, modelName)
                }
            }
            logInitResult(modelName, hnnDetector != null)
        }

        // Nothing is loaded for this file type here; report success.
        passio_nutrition_HNN -> true

        passio_FoodNet_HNN_ENv2 -> {
            val modelName = modelFileName(passio_FoodNet_HNN_ENv2.name, version)
            logInitStart(modelName)
            val labels = labelManager.getVisualPassioIDs()
            val inputSize = squareSize(PassioRecognizer.TF_HNN_API_INPUT_SIZE)

            hnnDetector = assetCatch(modelName) {
                TFLiteHNNKNNDetector(false, inputSize, labels).apply {
                    createFromSecuredAssetModel(context, modelName)
                }
            }
            logInitResult(modelName, hnnDetector != null)
        }

        passio_FoodNet_Yolo -> {
            val modelName = modelFileName(passio_FoodNet_Yolo.name, version)
            logInitStart(modelName)
            val inputSize = squareSize(PassioRecognizer.TF_YOLO_INPUT_SIZE)

            objectDetector = assetCatch(modelName) {
                TFLiteYoloObjectDetector(inputSize).apply {
                    createFromSecuredAssetModel(context, modelName)
                }
            }
            logInitResult(modelName, objectDetector != null)
        }

        else -> throw IllegalArgumentException("No known file type: ${fileType.name}")
    }

    override fun getExtension(): String = "passiosecure"

    /** Log tag: the runtime simple class name of this holder. */
    private val logTag: String
        get() = this::class.java.simpleName

    /** Builds the secured asset file name `<baseName>.<version>.<extension>`. */
    private fun modelFileName(baseName: String, version: Int): String =
        "$baseName.$version.${getExtension()}"

    /** Creates a square [Size] with both sides equal to [side]. */
    private fun squareSize(side: Int): Size = Size(side, side)

    /** Logs that initialization of [modelName] is starting. */
    private fun logInitStart(modelName: String) {
        PassioLog.i(logTag, "Initializing $modelName")
    }

    /**
     * Logs the outcome of initializing [modelName]: info on success, error on failure.
     *
     * @return [success], unchanged, so callers can use this call as their result.
     */
    private fun logInitResult(modelName: String, success: Boolean): Boolean {
        if (success) {
            PassioLog.i(logTag, "$modelName init success")
        } else {
            PassioLog.e(logTag, "$modelName init failed")
        }
        return success
    }
}