package ai.passio.passiosdk.passiofood.tflite

import ai.passio.passiosdk.core.tflite.NonMaxSuppression
import ai.passio.passiosdk.core.tflite.Recognition
import ai.passio.passiosdk.core.tflite.TFLiteRecognizer
import ai.passio.passiosdk.passiofood.ObjectDetectionCandidate
import android.graphics.Bitmap
import android.graphics.RectF
import android.util.Log
import android.util.Size
import java.nio.ByteBuffer
import java.nio.ByteOrder
import kotlin.math.pow

// Confidence cut-off for raw detections: a candidate box is kept only when its
// confidence value is strictly greater than this threshold.
private const val MINIMUM_CONFIDENCE_TF_OD_API = 0.1f

/**
 * Object detector that runs a TFLite model and decodes its single output tensor
 * into [ObjectDetectionCandidate]s.
 *
 * The output is read as [outputBox] consecutive rows of `labels.size + 5` floats.
 * In each row, indices 0..3 are the box centre x, centre y, width and height
 * (normalized; multiplied by the input width here) and index 4 is the confidence.
 * The remaining `labels.size` values are not read by this class.
 *
 * @param defaultSize forwarded to [TFLiteRecognizer].
 */
internal class TFLiteYoloObjectDetector(
    defaultSize: Size
) : TFLiteRecognizer<List<ObjectDetectionCandidate>>(
    true,
    defaultSize,
    listOf("FD00001")
) {

    // Bytes per output value used to size [outData]: 1 when quantized, 4 otherwise.
    private val numBytesPerChannel = if (isQuantized) 1 else 4

    // Number of candidate boxes: three per cell over square grids with cell sizes
    // 32, 16 and 8. Only the input width is used, i.e. a square input is assumed.
    private val outputBox: Int by lazy {
        (((inputSize.width / 32).toDouble().pow(2.0) +
        (inputSize.width / 16).toDouble().pow(2.0) +
        (inputSize.width / 8).toDouble().pow(2.0)) * 3).toInt()
    }

    // Reusable native-order direct buffer that receives the model output.
    private val outData: ByteBuffer by lazy {
        ByteBuffer.allocateDirect(outputBox * (labels.size + 5) * numBytesPerChannel).apply {
            order(ByteOrder.nativeOrder())
        }
    }

    /**
     * Runs the model on [bitmap] and returns the detections that survive the
     * confidence threshold and non-max suppression.
     *
     * @param bitmap image to analyse; its pixels are copied into `intValues`.
     *   Presumably already scaled to `inputSize` - verify against callers.
     * @param args unused by this recognizer.
     * @return the candidates produced by [NonMaxSuppression.calculate].
     */
    override fun recognizeImage(bitmap: Bitmap, args: List<Any>?): List<ObjectDetectionCandidate> {
        bitmap.getPixels(intValues, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)

        // Float inputs are scaled from 0..255 down to 0..1.
        mean = 0f
        std = 255.0f
        fillInputBuffer()

        outData.rewind()
        val outputMap = mutableMapOf<Int, Any>(0 to outData)
        tfLite.runForMultipleInputsOutputs(arrayOf(imgData), outputMap)
        outData.rewind()

        val detections = decodeDetections(bitmap)
        return NonMaxSuppression.calculate(labels, detections).map {
            ObjectDetectionCandidate(
                it.classId,
                it.confidence,
                it.boundingBox
            )
        }
    }

    /** Packs the pixels held in `intValues` into `imgData` as R, G, B in row-major order. */
    private fun fillInputBuffer() {
        val width = inputSize.width
        val height = inputSize.height

        imgData.rewind()
        // Rows run over the height and columns over the width, so the flat index
        // row * width + col stays inside width * height for non-square inputs too.
        for (row in 0 until height) {
            val rowStart = row * width
            for (col in 0 until width) {
                val pixel = intValues[rowStart + col]
                val red = pixel shr 16 and 0xFF
                val green = pixel shr 8 and 0xFF
                val blue = pixel and 0xFF
                if (isQuantized) {
                    imgData.put(red.toByte())
                    imgData.put(green.toByte())
                    imgData.put(blue.toByte())
                } else {
                    imgData.putFloat((red - mean) / std)
                    imgData.putFloat((green - mean) / std)
                    imgData.putFloat((blue - mean) / std)
                }
            }
        }
    }

    /**
     * Reads [outData] and builds a [Recognition] for every box whose confidence
     * exceeds [MINIMUM_CONFIDENCE_TF_OD_API].
     *
     * Values are read by index through a float view, so no intermediate array is
     * allocated and boxes below the threshold are skipped without being scaled.
     */
    private fun decodeDetections(bitmap: Bitmap): ArrayList<Recognition> {
        // NOTE(review): the output is always read as floats although outData is
        // sized with numBytesPerChannel; a quantized output would be too small - confirm.
        val values = outData.asFloatBuffer()
        val valuesPerBox = labels.size + 5

        // NOTE(review): y and h are scaled by the input width as well, which is
        // only equivalent to scaling by height for square inputs - confirm.
        val scale = inputSize.width.toFloat()
        val maxRight = (bitmap.width - 1).toFloat()
        val maxBottom = (bitmap.height - 1).toFloat()

        val detections = ArrayList<Recognition>()
        for (box in 0 until outputBox) {
            val base = box * valuesPerBox
            val confidence = values.get(base + 4)
            if (confidence > MINIMUM_CONFIDENCE_TF_OD_API) {
                val xPos = values.get(base) * scale
                val yPos = values.get(base + 1) * scale
                val w = values.get(base + 2) * scale
                val h = values.get(base + 3) * scale

                // Convert centre/size to corners, clamped to the bitmap bounds.
                val rect = RectF(
                    0f.coerceAtLeast(xPos - w / 2),
                    0f.coerceAtLeast(yPos - h / 2),
                    maxRight.coerceAtMost(xPos + w / 2),
                    maxBottom.coerceAtMost(yPos + h / 2)
                )
                // Every detection is reported under the first label; the values
                // after the confidence are not consulted.
                detections.add(
                    Recognition(
                        labels[0],
                        confidence,
                        rect
                    )
                )
            }
        }
        return detections
    }

}