package one.zoop.sdk.scanner.utility


import one.zoop.sdk.scanner.utils.EncryptionUtil
import android.app.Application
import android.content.Context
import android.graphics.Bitmap
import android.graphics.PointF
import android.net.Uri
import android.util.Log
import android.util.Size
import androidx.camera.core.ImageAnalysis
import androidx.camera.core.ImageProxy
import androidx.navigation.NavController
import io.sentry.Sentry
import one.zoop.sdk.scanner.ui.core.components.crop_handler.DrawView
import one.zoop.sdk.scanner.model.ScanMode
import one.zoop.sdk.scanner.ndk.NdkConnector
import one.zoop.sdk.scanner.utils.ScannerConfigManager
import one.zoop.sdk.scanner.viewmodel.CameraViewModel

import one.zoop.sdk.scanner.viewmodel.CornerPointStatus

import org.tensorflow.lite.DataType
import org.tensorflow.lite.Delegate
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.nnapi.NnApiDelegate
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.Rot90Op
import java.io.FileInputStream
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
import java.util.concurrent.Executors
import kotlin.math.abs

/**
 * Utility class for TensorFlow Lite: loads the encrypted document key-point model and runs
 * corner detection on bitmaps / camera frames.
 *
 * The interpreter is created eagerly when this class is constructed (decrypting
 * [ENC_MODEL_FILE_NAME] from assets); call [closeTflite] to release it.
 */
class TfliteHelper(private val application: Application) {
    companion object {
        private val TAG = TfliteHelper::class.java.simpleName
        private const val ACCURACY_THRESHOLD = 0.5f
        private const val NO_OF_THREADS = 4
        private const val ENC_MODEL_FILE_NAME = "enc_model.enc"
        private const val MODEL_FILE_NAME = "keypoint_320_256_float16_V5_1_output.tflite"
        private const val NOT_DETECTED_THRESHOLD = 10000 // ms
        private const val DETECTED_THRESHOLD = 2500 // ms

        /**
         * Number of frames taking [inferenceTime] ms that fit into [detectedTime] ms.
         * Currently unused. A non-positive [inferenceTime] is treated as 1 ms to avoid dividing by zero.
         */
        private fun getDetectedThreshold(inferenceTime: Long, detectedTime: Int): Int =
            (detectedTime / inferenceTime.coerceAtLeast(1L)).toInt()

        /**
         * Number of frames taking [inferenceTime] ms that fit into [notDetectedTime] ms.
         * Currently unused. A non-positive [inferenceTime] is treated as 1 ms to avoid dividing by zero.
         */
        private fun getNotDetectedThreshold(inferenceTime: Int, notDetectedTime: Int): Int =
            notDetectedTime / inferenceTime.coerceAtLeast(1)
    }

    /** Creates an NNAPI delegate. Currently unused: no delegate is added to the interpreter options below. */
    private fun getTfliteDelegate(): Delegate = NnApiDelegate()

    /**
     * Memory-maps the unencrypted model [MODEL_FILE_NAME] from assets.
     * Currently unused (the interpreter is built from [loadEncModel]).
     * The asset descriptor and stream are closed once the mapping exists; a mapping stays valid
     * after its channel is closed.
     */
    private fun loadModelFile(): MappedByteBuffer =
        application.assets.openFd(MODEL_FILE_NAME).use { fileDescriptor ->
            FileInputStream(fileDescriptor.fileDescriptor).use { inputStream ->
                inputStream.channel.map(
                    FileChannel.MapMode.READ_ONLY,
                    fileDescriptor.startOffset,
                    fileDescriptor.declaredLength
                )
            }
        }

    /**
     * Decrypts [ENC_MODEL_FILE_NAME] from assets with the password provided by [NdkConnector].
     *
     * @throws IllegalStateException if decryption yields no buffer.
     */
    private fun loadEncModel(): MappedByteBuffer {
        val modelPass = NdkConnector.getModelPass()
        // Close the asset stream once decryption has produced the in-memory buffer.
        return application.assets.open(ENC_MODEL_FILE_NAME).use { inputStream ->
            checkNotNull(EncryptionUtil().decryptStreamAsByteBuffer(modelPass, inputStream)) {
                "Failed to decrypt $ENC_MODEL_FILE_NAME"
            }
        }
    }

    private val interpreter: Interpreter = Interpreter(
        loadEncModel(),
        Interpreter.Options().setNumThreads(NO_OF_THREADS)
    )

    /** Model input size, read from input tensor 0 whose axis order is {1, height, width, 3}. */
    private val tfInputSize: Size = interpreter.getInputTensor(0).shape().let { inputShape ->
        Log.d(TAG, inputShape.contentToString())
        Size(inputShape[2], inputShape[1])
    }

    /**
     * Single executor reused whenever an analyzer is re-installed after a capture.
     * It replaces the previous behavior of creating (and never shutting down) a new thread per capture.
     */
    private val analyzerRestartExecutor by lazy { Executors.newSingleThreadExecutor() }

    /** Releases the TFLite interpreter. Do not call [runInference] afterwards. */
    fun closeTflite() {
        interpreter.close()
    }

    /**
     * Converts [bitmap] into a direct, native-ordered float buffer of shape
     * 1 x height x width x 3. Each pixel's R, G, B channel (0..255) is stored unnormalized as a float.
     */
    fun convertBitmapToByteBuffer(bitmap: Bitmap): ByteBuffer {
        val inputShape =
            intArrayOf(1, bitmap.height, bitmap.width, 3) // 1 x height x width x 3 (RGB channels)
        val inputSize = inputShape[1] * inputShape[2] * inputShape[3]
        val byteBuffer = ByteBuffer.allocateDirect(4 * inputSize)
        byteBuffer.order(ByteOrder.nativeOrder())

        val pixels = IntArray(bitmap.width * bitmap.height)
        bitmap.getPixels(pixels, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)

        for (pixel in pixels) {
            byteBuffer.putFloat((pixel shr 16 and 0xFF).toFloat())
            byteBuffer.putFloat((pixel shr 8 and 0xFF).toFloat())
            byteBuffer.putFloat((pixel and 0xFF).toFloat())
        }
        byteBuffer.rewind()
        return byteBuffer
    }

    /** Deep-copies [points] so the receiver cannot observe later in-place mutation of the originals. */
    private fun copyPointsArray(points: Array<PointF>): Array<PointF> =
        Array(points.size) { i -> PointF(points[i].x, points[i].y) }

    /** True when both coordinates of the two points differ by at most [threshold]. */
    private fun approxEquals(
        currentCorners: PointF,
        other: PointF,
        threshold: Float = 0.01f
    ): Boolean =
        abs(currentCorners.x - other.x) <= threshold && abs(currentCorners.y - other.y) <= threshold

    /** True when the arrays have equal size and every pair of points at the same index is within [threshold]. */
    private fun areArraysEqualWithThreshold(
        arr1: Array<PointF>, arr2: Array<PointF>, threshold: Float
    ): Boolean =
        arr1.size == arr2.size && arr1.indices.all { i -> approxEquals(arr1[i], arr2[i], threshold) }

    /**
     * CameraX analyzer that detects document corners and triggers auto-capture.
     *
     * Inference runs synchronously inside [analyze]. The frame is closed only when it finishes, so
     * (per the CameraX `ImageAnalysis` contract) no new frame is analyzed before the previous one is done.
     * Results are applied on the UI thread via [DrawView.post]:
     * - Corners that stay stable (within 50px of the drawn corners) with score > 1 for
     *   [DETECTED_THRESHOLD] ms of accumulated inference time trigger [CameraViewModel.takePhoto].
     * - Low or absent detections accumulate towards [NOT_DETECTED_THRESHOLD] ms, which reports
     *   [CornerPointStatus.NOT_FOUND].
     *
     * Unchecked casts on the [runInference] result map are guarded by the surrounding try/catch.
     */
    @Suppress("UNCHECKED_CAST")
    inner class TfLiteAnalyzer(
        private val drawView: DrawView?,
        private val cameraPreviewWidth: Int,
        private val cameraPreviewHeight: Int,
        private val mAutoCaptureScreenState: Boolean = true,
        private val autoCaptureTimeOut: Boolean = false,
        private val viewModel: CameraViewModel,
        private val context: Context,
        private val screenHeight: Double,
        private val screenWidth: Double,
        private val navController: NavController
    ) : ImageAnalysis.Analyzer {
        // Remaining milliseconds before auto-capture / "not found"; mutated only on the UI thread.
        private var remDetectedTime = DETECTED_THRESHOLD
        private var remNotDetectedTime = NOT_DETECTED_THRESHOLD

        // Set to false once a capture is triggered; read on the analyzer thread, written on the UI thread.
        private var isProcessFrames: Boolean = true
        private val isProcessFramesLock = Any()

        /** Whether frames should still be processed, read under [isProcessFramesLock]. */
        private fun shouldProcessFrames(): Boolean = synchronized(isProcessFramesLock) {
            isProcessFrames && mAutoCaptureScreenState && !autoCaptureTimeOut
        }

        override fun analyze(image: ImageProxy) {
            try {
                val view = drawView ?: return
                Log.d(
                    TAG,
                    "${image.height} ${image.width} ${image.format} ${image.imageInfo.rotationDegrees}"
                )
                if (!shouldProcessFrames()) return

                val bitmap = image.toBitmap()
                // Analysis frames are assumed to be rotated by 90 degrees on all devices.
                val start = System.currentTimeMillis()
                val inferenceData = runInference(bitmap, 90)
                val inferenceTime = (System.currentTimeMillis() - start).toInt()
                val corners = inferenceData["corner_points"] as Array<PointF>
                val score = inferenceData["conf_score"] as Float
                Log.d("inferenceCorner", corners.contentToString())

                synchronized(isProcessFramesLock) {
                    if (isProcessFrames && mAutoCaptureScreenState && !autoCaptureTimeOut) {
                        // Scale model-relative corners to preview pixel coordinates.
                        for (corner in corners) {
                            corner.x *= cameraPreviewWidth
                            corner.y *= cameraPreviewHeight
                        }
                        view.post { onInferenceResult(view, corners, score, inferenceTime) }
                    }
                }
            } catch (e: Exception) {
                Sentry.captureException(e)
                Log.e(TAG, "Something went wrong ${e.message}")
            } finally {
                // Always release the frame, including on early returns; otherwise CameraX stops delivering frames.
                image.close()
            }
        }

        /** UI-thread handler: updates prompt state and timers, draws corners, and may trigger capture. */
        private fun onInferenceResult(
            view: DrawView,
            corners: Array<PointF>,
            score: Float,
            inferenceTime: Int
        ) {
            // Results queued before a capture was triggered must not act again (e.g. capture twice).
            if (!synchronized(isProcessFramesLock) { isProcessFrames }) return

            when {
                score > 0 && score <= 1 -> {
                    remDetectedTime = DETECTED_THRESHOLD
                    remNotDetectedTime -= inferenceTime
                    if (remNotDetectedTime <= 0) {
                        Log.d(TAG, "NOT_FOUND  ${corners.contentToString()}")
                        viewModel.changeAutoCapturePromptState(CornerPointStatus.NOT_FOUND)
                        return
                    }
                    //TODO : add blinking effect
                    Log.d(TAG, "LOW_CONF_SCORE  ${corners.contentToString()}")
                    viewModel.changeAutoCapturePromptState(CornerPointStatus.LOW_CONF_SCORE)
                }

                score > 1 -> {
                    // Detection started.
                    viewModel.changeAutoCapturePromptState(CornerPointStatus.CAPTURING)
                    if (areArraysEqualWithThreshold(corners, view.getCurrentDrawViewPoints(), 50F)) {
                        // Corners are stable relative to what is currently drawn.
                        Log.d(TAG, "DETECTED ${corners.contentToString()}")
                        remNotDetectedTime = NOT_DETECTED_THRESHOLD
                        remDetectedTime -= inferenceTime
                        if (remDetectedTime <= 0) {
                            capturePhoto(view, corners)
                        }
                    } else {
                        remDetectedTime = DETECTED_THRESHOLD
                        remNotDetectedTime -= inferenceTime
                    }
                }

                else -> {
                    Log.d(TAG, "FINDING ${corners.contentToString()}")
                    viewModel.changeAutoCapturePromptState(CornerPointStatus.FINDING)
                    remDetectedTime = DETECTED_THRESHOLD
                    remNotDetectedTime -= inferenceTime
                }
            }

            // Send a copy: a later frame's result must not mutate the array the view animates.
            if (!viewModel.isAutoCaptureTimeout) {
                view.setCornerPointsAnimated(copyPointsArray(corners))
            }
        }

        /** Stops frame processing and asks the view model to capture a photo with [corners]. */
        private fun capturePhoto(view: DrawView, corners: Array<PointF>) {
            viewModel.isAutoCaptureTimeout = true
            synchronized(isProcessFramesLock) { isProcessFrames = false }
            // Convert preview pixel coordinates back to fractions of the preview size.
            for (corner in corners) {
                corner.x /= cameraPreviewWidth
                corner.y /= cameraPreviewHeight
            }
            viewModel.takePhoto(
                filenameFormat = "yyyy-MM-dd-HH-mm-ss-SSS",
                outputDirectory = viewModel.getOutputDirectory(context),
                onImageCaptured = { capturedUri -> onPhotoCaptured(view, capturedUri) },
                screenHeight = screenHeight,
                screenWidth = screenWidth,
                onError = {
                },
                corners = corners,
                context = context,
            )
            Log.d("inference", "taking photo")
        }

        /** Handles a captured photo: animation, navigation, and optionally re-installing a fresh analyzer. */
        private fun onPhotoCaptured(view: DrawView, capturedUri: Any?) {
            viewModel.changeAutoCapturePromptState(CornerPointStatus.DETECTED)

            val config = ScannerConfigManager.getConfig()

            // Trigger the capture animation for batch scans and for the first two ID-card images.
            if (viewModel.isRetake.value != true &&
                ((config?.mode == ScanMode.BATCH && !config.isIdCard) ||
                    (config?.isIdCard == true && (viewModel.capturedImages.value?.size ?: 0) < 2))
            ) {
                viewModel.startCaptureAnimation(capturedUri as Uri)
            }

            when {
                // Retake: go straight back to the preview.
                viewModel.isRetake.value == true -> navigateToPreview()
                // ID card: navigate once the back image is not required or both images exist.
                config?.isIdCard == true -> {
                    if (!config.isBackImageRequired || viewModel.capturedImages.value?.size == 2) {
                        navigateToPreview()
                    }
                }
                // Single-shot modes: navigate after one capture.
                config?.mode == ScanMode.SINGLE || config?.mode == ScanMode.OCR -> navigateToPreview()
                // Otherwise (e.g. batch mode): stay and keep capturing.
            }

            // isRetake is read again on purpose: navigateToPreview() may have reset it.
            val restartAnalyzer = viewModel.isRetake.value == true ||
                config?.mode != ScanMode.SINGLE ||
                (config?.isIdCard == true && config.isBackImageRequired &&
                    viewModel.capturedImages.value?.size == 1)
            if (restartAnalyzer) {
                viewModel.isAutoCaptureTimeout = false
                viewModel.getImageAnalyzer().setAnalyzer(
                    analyzerRestartExecutor,
                    viewModel.tfliteHelper.TfLiteAnalyzer(
                        view,
                        cameraPreviewWidth,
                        cameraPreviewHeight,
                        viewModel = viewModel,
                        screenWidth = screenWidth,
                        screenHeight = screenHeight,
                        context = context,
                        navController = navController
                    )
                )
            }
        }

        private fun navigateToPreview() {
            viewModel.setRetakeValue(false)
            navController.navigate("previewScreen")
        }
    }

    /**
     * Runs the key-point model on [image] after rotating it by -[rotationAngle] (in 90-degree steps)
     * and resizing it to the model input size.
     *
     * @return a map with:
     * - `"corner_points"`: `Array<PointF>` with 4 points, or empty when confidence is too low.
     *   Coordinates are as emitted by the model; the caller scales them by preview size.
     * - `"conf_score"`: `Float`, the best candidate's summed per-point score (0 on failure).
     * - `"inference_time"`: `Long` milliseconds (0 on failure).
     *
     * Any exception is reported to Sentry and yields the failure map instead of propagating.
     */
    @Suppress("UNCHECKED_CAST")
    fun runInference(
        image: Bitmap, rotationAngle: Int
    ): Map<String, Any> {
        try {
            val start = System.currentTimeMillis()
            Log.d("inputTensor", "w:${image.width} h:${image.height}")
            val rawImage = TensorImage(DataType.FLOAT32).also { it.load(image) }

            val tfImageProcessor = ImageProcessor.Builder()
                .add(Rot90Op(-rotationAngle / 90))
                .add(
                    ResizeOp(
                        tfInputSize.height, tfInputSize.width, ResizeOp.ResizeMethod.NEAREST_NEIGHBOR
                    )
                )
                .build()
            val tensorImage = tfImageProcessor.process(rawImage)
            Log.d("inputTensor", "w:${tensorImage.width} h:${tensorImage.height}")

            val inputTensor = arrayOf(tensorImage.buffer)
            Log.d("inputTensor", tensorImage.buffer.capacity().toString())
            val outputTensor = initOutputMap(interpreter)
            interpreter.runForMultipleInputsOutputs(inputTensor, outputTensor)

            // Output 4 holds per-point scores per candidate (model V5; it was output 5 in V3).
            val score = outputTensor[4] as Array<Array<FloatArray>>

            // Pick the first candidate row with the highest summed score.
            var maxSum = Float.NEGATIVE_INFINITY
            var bestCandidate = -1
            for (batch in score) {
                for ((j, row) in batch.withIndex()) {
                    val sum = row.sum()
                    if (sum > maxSum) {
                        maxSum = sum
                        bestCandidate = j
                    }
                }
            }
            check(bestCandidate >= 0) { "Model returned no finite candidate scores" }

            // Output 1 holds key-point coordinates (model V5; it was output 4 in V3).
            // Each point is stored with y at index 0 and x at index 1.
            val keyPoints = outputTensor[1] as Array<Array<Array<FloatArray>>>
            val candidatePoints = keyPoints[0][bestCandidate].map { point -> PointF(point[1], point[0]) }

            // Accept the corners if the summed score is high, or if more than two points individually
            // clear ACCURACY_THRESHOLD.
            // MAKE SURE PREVIEW AND INPUT IMAGE SIZES ARE THE SAME OTHERWISE CORNER POINTS WILL MISMATCH
            val confident = maxSum > 0.8 ||
                score[0][bestCandidate].count { it > ACCURACY_THRESHOLD } > 2
            val combinedList: Array<PointF> =
                if (confident) Array(4) { i -> candidatePoints[i] } else arrayOf()

            val inferenceTime = System.currentTimeMillis() - start
            Log.d(TAG, "Inference Time $inferenceTime")
            return mapOf(
                "corner_points" to combinedList,
                "conf_score" to maxSum,
                "inference_time" to inferenceTime
            )
        } catch (e: Exception) {
            Sentry.captureException(e)
            Log.e(TAG, "Something went wrong : ${e.message}")
            return mapOf(
                "corner_points" to arrayOf<PointF>(),
                "conf_score" to 0.0F,
                // Same type (Long) as the success path so callers can cast consistently.
                "inference_time" to 0L
            )
        }
    }

    /**
     * Allocates one output buffer per model output tensor for
     * [Interpreter.runForMultipleInputsOutputs] to populate.
     *
     * Buffer sizes come from the interpreter's reported output shapes. The nesting depth is
     * hard-coded here and presumably must match each tensor's rank:
     * - outputs 0 and 3: rank 2
     * - output 1: rank 4
     * - outputs 2 and 4: rank 3
     * - output 5: rank 1
     *
     * [runInference] reads output 1 (key-point coordinates) and output 4 (per-point scores).
     */
    private fun initOutputMap(interpreter: Interpreter): HashMap<Int, Any> {
        val outputMap = HashMap<Int, Any>()

        val output0 = interpreter.getOutputTensor(0).shape()
        outputMap[0] = Array(output0[0]) {
            FloatArray(output0[1])
        }

        // Key-point coordinates: [batch][candidate][point][(y, x)].
        val keyPointCoords = interpreter.getOutputTensor(1).shape()
        outputMap[1] = Array(keyPointCoords[0]) {
            Array(keyPointCoords[1]) {
                Array(keyPointCoords[2]) {
                    FloatArray(keyPointCoords[3])
                }
            }
        }

        val output2 = interpreter.getOutputTensor(2).shape()
        outputMap[2] = Array(output2[0]) {
            Array(output2[1]) {
                FloatArray(output2[2])
            }
        }

        val output3 = interpreter.getOutputTensor(3).shape()
        outputMap[3] = Array(output3[0]) {
            FloatArray(output3[1])
        }

        // Per-point scores: [batch][candidate][point].
        val keyPointScores = interpreter.getOutputTensor(4).shape()
        outputMap[4] = Array(keyPointScores[0]) {
            Array(keyPointScores[1]) {
                FloatArray(keyPointScores[2])
            }
        }

        val output5 = interpreter.getOutputTensor(5).shape()
        outputMap[5] = FloatArray(output5[0])
        return outputMap
    }

}