package ai.passio.passiosdk.core.tflite

import ai.passio.passiosdk.core.aes.CryptoHandler
import ai.passio.passiosdk.core.aes.CryptoIOBuffer
import ai.passio.passiosdk.core.file.PassioFileManager.Companion.getTempDir
import ai.passio.passiosdk.core.os.NativeUtils
import ai.passio.passiosdk.core.utils.FileUtil
import ai.passio.passiosdk.core.utils.MemoryUtil
import ai.passio.passiosdk.core.utils.PassioLog
import android.content.Context
import android.content.pm.FeatureInfo
import android.content.pm.PackageManager
import android.graphics.Bitmap
import android.util.Log
import android.util.Size
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.metadata.MetadataExtractor
import org.tensorflow.lite.support.metadata.schema.NormalizationOptions
import org.tensorflow.lite.support.metadata.schema.ProcessUnitOptions
import java.io.File
import java.io.FileInputStream
import java.io.InputStream
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.channels.FileChannel

// Pixel normalization parameters, selected by TFLiteRecognizer's `isNormalized` flag.
// Presumably applied by subclasses as (value - mean) / std — confirm in the concrete recognizers.

/** Mean used when the model expects normalized input: half of the 8-bit range (127.5). */
private const val NORM_IMAGE_MEAN = 255f / 2

/** Std used when the model expects normalized input; equal to the mean (127.5). */
private const val NORM_IMAGE_STD = NORM_IMAGE_MEAN

/** Mean used when the model consumes raw pixel values. */
private const val NONE_IMAGE_MEAN = 0.0f

/** Std used when the model consumes raw pixel values. */
private const val NONE_IMAGE_STD = 1.0f

/**
 * Base class for TensorFlow Lite image recognizers.
 *
 * A subclass calls one of the `create*` methods to build [tfLite] from an asset or an external
 * file (plain, zip-compressed, or encrypted), then implements [recognizeImage].
 *
 * @param T result type produced by [recognizeImage].
 * @param isNormalized when true the initial [mean]/[std] are 127.5/127.5, otherwise 0/1.
 * @param inputSize model input size used to size [imgData] and [intValues]; it may later be
 *   overwritten from model metadata when an external compressed or secured model is loaded.
 * @param labels class labels made available to subclasses.
 * @param xnnPackThreshold when non-null, the XNNPACK branch of [getDefaultTFLiteOptions] runs only
 *   if the available runtime memory exceeds this value; `null` means it always runs.
 * @param isQuantized when true the input buffer uses 1 byte per channel instead of 4.
 */
internal abstract class TFLiteRecognizer<T>(
    isNormalized: Boolean,
    var inputSize: Size,
    protected var labels: List<String>,
    private val xnnPackThreshold: Long? = null,
    protected val isQuantized: Boolean = false
) {
    /** Interpreter built by one of the `create*` methods; accessing it before that throws. */
    protected lateinit var tfLite: Interpreter

    /**
     * Direct, native-ordered input buffer for one image of [inputSize] with 3 channels:
     * 1 byte per channel when [isQuantized], otherwise 4.
     *
     * NOTE(review): sized from the constructor's input size and NOT resized when
     * [tryMetadataRead] replaces [inputSize] with the model's declared shape — confirm that
     * external models always match the configured size.
     */
    protected val imgData: ByteBuffer = ByteBuffer
        .allocateDirect(1 * inputSize.width * inputSize.height * 3 * (if (isQuantized) 1 else 4))
        .order(ByteOrder.nativeOrder())

    /** Scratch array with one Int per input pixel; presumably filled via `Bitmap.getPixels` by subclasses. */
    protected val intValues: IntArray = IntArray(inputSize.width * inputSize.height)

    /** Normalization std; may be replaced by the model's metadata (see [tryMetadataRead]). */
    protected var std: Float = if (isNormalized) NORM_IMAGE_STD else NONE_IMAGE_STD

    /** Normalization mean; may be replaced by the model's metadata (see [tryMetadataRead]). */
    protected var mean: Float = if (isNormalized) NORM_IMAGE_MEAN else NONE_IMAGE_MEAN

    /**
     * Constructor for compressed asset model.
     *
     * Decrypts the asset [zipName] into a temporary `.passio2` zip file under the context's temp
     * dir, unzips it into memory, deletes the temporary file, and builds the interpreter.
     *
     * @throws IllegalStateException if the decrypted archive cannot be unzipped.
     */
    internal fun createFromCompressedAssetModel(
        context: Context,
        zipName: String,
    ) {
        // Keep at most the first two dot-separated components ("model.v1.zip" -> "model.v1").
        // Unlike indexing [1] directly, a name without any dot no longer throws.
        val baseName = zipName.split(".").take(2).joinToString(".")
        val rootPath = context.getTempDir()
        File(rootPath).mkdirs()
        val outPath = rootPath + File.separator + baseName + ".passio2"
        val passio2File = File(outPath)

        val modelBuffer = try {
            context.assets.open(zipName).use { inputStream ->
                // NOTE(review): key 1 is passed for BOTH key arguments, whereas the secured-model
                // loaders pass key 1 and key 2. Left unchanged because bundled assets must decrypt
                // with the keys they were encrypted with — confirm this is intentional.
                CryptoHandler.instance.decrypt(
                    inputStream,
                    outPath,
                    NativeUtils.instance.nativeGetKey1().toByteArray(),
                    NativeUtils.instance.nativeGetKey1().toByteArray()
                )
            }
            checkNotNull(FileUtil.unzipToBuffer(passio2File)) {
                "Failed to unzip decrypted model '$zipName'"
            }
        } finally {
            // The temporary file holds a decrypted copy of the model; remove it even on failure.
            passio2File.delete()
        }

        tfLite = Interpreter(modelBuffer, getDefaultTFLiteOptions())
    }

    /**
     * Constructor for compressed external model.
     *
     * Unzips [zipFile] into memory, applies any input size / normalization found in the model
     * metadata, and builds the interpreter.
     *
     * @throws IllegalStateException if the archive cannot be unzipped.
     */
    internal fun createFromCompressedExternalModel(zipFile: File) {
        val modelBuffer = checkNotNull(FileUtil.unzipToBuffer(zipFile)) {
            "Failed to unzip model '${zipFile.path}'"
        }
        modelBuffer.rewind()
        tryMetadataRead(modelBuffer)

        tfLite = Interpreter(modelBuffer, getDefaultTFLiteOptions())
    }

    /**
     * Constructor for secured external model.
     *
     * Decrypts [modelFile] in memory, applies any metadata-declared input size / normalization,
     * and builds the interpreter. The file stream is closed once decryption completes.
     */
    internal fun createFromSecuredExternalModel(modelFile: File) {
        val modelBuffer = FileInputStream(modelFile).use { stream -> decryptSecuredModel(stream) }
        modelBuffer.rewind()
        tryMetadataRead(modelBuffer)

        tfLite = Interpreter(modelBuffer, getDefaultTFLiteOptions())
    }

    /**
     * Constructor for secured asset model.
     *
     * Decrypts the asset [modelName] in memory and builds the interpreter. Unlike the external
     * variant, model metadata is not read. The asset stream is closed once decryption completes.
     */
    internal fun createFromSecuredAssetModel(
        context: Context,
        modelName: String,
    ) {
        val modelBuffer = context.assets.open(modelName).use { stream -> decryptSecuredModel(stream) }

        tfLite = Interpreter(modelBuffer, getDefaultTFLiteOptions())
    }

    /**
     * Builds the interpreter from a plain `.tflite` file by memory-mapping it read-only.
     *
     * The stream is closed after mapping; a `FileChannel` mapping stays valid after its channel
     * is closed.
     */
    fun createFromExternalModel(modelFile: File) {
        val buffer = FileInputStream(modelFile).use { stream ->
            stream.channel.map(FileChannel.MapMode.READ_ONLY, 0, modelFile.length())
        }
        tfLite = Interpreter(buffer, getDefaultTFLiteOptions())
    }

    /**
     * Builds the interpreter from a plain model asset by memory-mapping its region of the APK.
     *
     * The descriptor and stream are closed after mapping; the mapping itself stays valid.
     */
    fun createFromAssetModel(
        context: Context,
        modelPath: String
    ) {
        val buffer = context.assets.openFd(modelPath).use { descriptor ->
            FileInputStream(descriptor.fileDescriptor).use { stream ->
                stream.channel.map(
                    FileChannel.MapMode.READ_ONLY,
                    descriptor.startOffset,
                    descriptor.declaredLength
                )
            }
        }
        tfLite = Interpreter(buffer, getDefaultTFLiteOptions())
    }

    /** Returns the native key 1 bytes; fetched fresh on every call. */
    private fun key1(): ByteArray = NativeUtils.instance.nativeGetKey1().toByteArray()

    /** Returns the native key 2 bytes; fetched fresh on every call. */
    private fun key2(): ByteArray = NativeUtils.instance.nativeGetKey2().toByteArray()

    /**
     * Decrypts a secured model read from [inputStream] into memory and returns the buffer passed
     * to the interpreter. Does not close [inputStream]; callers do that.
     */
    private fun decryptSecuredModel(inputStream: InputStream): ByteBuffer {
        val size = CryptoHandler.instance.getOutputSize(inputStream.available(), key1(), key2())

        val cryptoBuffer = CryptoIOBuffer(size)
        cryptoBuffer.write(inputStream)

        // NOTE(review): the output buffer starts at offset 4; the meaning of this 4-byte offset is
        // defined by CryptoIOBuffer/CryptoHandler — TODO confirm.
        cryptoBuffer.outputBuffer.position(4)
        CryptoHandler.instance.decryptInMemory(
            cryptoBuffer.buffer,
            cryptoBuffer.outputBuffer,
            key1(),
            key2()
        )
        return cryptoBuffer.buffer
    }

    /**
     * Interpreter options: [NUM_CORES] threads, plus the (currently disabled) XNNPACK branch,
     * which runs when [xnnPackThreshold] is null or the available runtime memory exceeds it.
     */
    private fun getDefaultTFLiteOptions(): Interpreter.Options = Interpreter.Options().apply {
        numThreads = NUM_CORES
        val threshold = xnnPackThreshold
        if (threshold == null || MemoryUtil.getAvailableRuntimeMemory() > threshold) {
            // NOTE(review): setUseXNNPACK(true) is commented out, so the log below does not
            // correspond to an option change made here.
            // setUseXNNPACK(true)
            PassioLog.i(TFLiteRecognizer::class.java.simpleName, "Applied XNN Pack")
        }
    }

    /**
     * Returns the device OpenGL ES version as (major, minor), or (1, 0) when it is not reported.
     *
     * Not called from within this class; presumably kept for GPU-delegate selection — confirm
     * before removing.
     */
    private fun getOpenGLVersion(context: Context): Pair<Int, Int> {
        val packageManager: PackageManager = context.packageManager
        // The system feature entry with a null name carries the OpenGL ES version.
        val glEsVersion = packageManager.systemAvailableFeatures
            .firstOrNull { it.name == null }
            ?.reqGlEsVersion
            ?.takeIf { it != FeatureInfo.GL_ES_VERSION_UNDEFINED }
            ?: return 1 to 0
        return getMajorVersion(glEsVersion) to getMinorVersion(glEsVersion)
    }

    /** Upper 16 bits of a packed GL ES version. */
    private fun getMajorVersion(glEsVersion: Int): Int = (glEsVersion and -0x10000) shr 16

    /** Lower 16 bits of a packed GL ES version. */
    private fun getMinorVersion(glEsVersion: Int): Int = glEsVersion and 0xffff

    /**
     * Applies TFLite metadata from [modelBuffer], if present and parseable.
     *
     * Replaces [inputSize] from input tensor 0's shape (when it has at least 3 dimensions). If
     * input 0 declares a NormalizationOptions process unit with non-empty std/mean, replaces
     * [std]/[mean] with their first entries.
     *
     * @return false when the model has no metadata or needs a newer parser, true otherwise.
     */
    private fun tryMetadataRead(modelBuffer: ByteBuffer): Boolean {
        val extractor = MetadataExtractor(modelBuffer)

        if (!extractor.hasMetadata() || !extractor.isMinimumParserVersionSatisfied) {
            return false
        }

        val inputShape: IntArray? = extractor.getInputTensorShape(0)
        if (inputShape != null && inputShape.size >= 3) {
            // NOTE(review): for an NHWC tensor shape[1] is height and shape[2] is width, while
            // Size takes (width, height); only matters for non-square inputs — confirm layout.
            inputSize = Size(inputShape[1], inputShape[2])
        }

        // Pick the process unit that actually is a NormalizationOptions; options(...) reinterprets
        // any table as the passed type, so an `as?` cast alone cannot validate it.
        val normalization = extractor.getInputTensorMetadata(0)?.let { meta ->
            (0 until meta.processUnitsLength())
                .mapNotNull { index -> meta.processUnits(index) }
                .firstOrNull { unit -> unit.optionsType() == ProcessUnitOptions.NormalizationOptions }
                ?.options(NormalizationOptions()) as? NormalizationOptions
        }
        if (normalization != null && normalization.stdLength() > 0 && normalization.meanLength() > 0) {
            std = normalization.std(0)
            mean = normalization.mean(0)
        }

        return true
    }

    /**
     * Runs recognition on [bitmap].
     *
     * @param args optional extra arguments whose meaning is defined by the subclass.
     */
    abstract fun recognizeImage(bitmap: Bitmap, args: List<Any>? = null): T

    companion object {
        /** Interpreter thread count: half the available processors when more than one, else that count. */
        private val NUM_CORES: Int by lazy {
            val cores = Runtime.getRuntime().availableProcessors()
            if (cores > 1) cores / 2 else cores
        }
    }
}