package ai.passio.passiosdk.passiofood.tflite

import ai.passio.passiosdk.core.tflite.TFLiteRecognizer
import ai.passio.passiosdk.core.utils.DimenUtil
import ai.passio.passiosdk.passiofood.ObjectDetectionCandidate
import android.graphics.Bitmap
import android.graphics.RectF
import android.util.Size

/**
 * Maximum number of detections reported per image. The detector's output
 * buffers are allocated with exactly this many slots.
 */
private const val NUM_DETECTIONS: Int = 5

/**
 * Object detector backed by a TensorFlow Lite model with four outputs:
 * box locations, class indices, scores and the number of detections.
 *
 * Results are capped at [NUM_DETECTIONS]. Label lookup is currently disabled,
 * so every returned candidate carries an empty label.
 *
 * The output buffers below are instance fields that are overwritten on every
 * call, so a single instance should not run [recognizeImage] concurrently.
 *
 * @param inputSize model input resolution, forwarded to [TFLiteRecognizer].
 */
internal class TFLiteObjectDetector(
    inputSize: Size
) : TFLiteRecognizer<List<ObjectDetectionCandidate>>(
    true,
    inputSize,
    listOf()
) {

    // Output 0: bounding boxes, read as outputLocations[0][i][0..3].
    private val outputLocations: Array<Array<FloatArray>> =
        DimenUtil.init3DFloatArray(1, NUM_DETECTIONS, 4)

    // Output 1: class indices (currently unused, see the disabled label lookup below).
    private val outputClasses: Array<FloatArray> = DimenUtil.init2DFloatArray(1, NUM_DETECTIONS)

    // Output 2: per-detection confidence scores.
    private val outputScores: Array<FloatArray> = DimenUtil.init2DFloatArray(1, NUM_DETECTIONS)

    // Output 3: number of detections reported by the model, encoded as a float.
    private val numDetections: FloatArray = FloatArray(1)

    /**
     * Runs the detector on [bitmap] and returns at most [NUM_DETECTIONS] candidates.
     *
     * @param bitmap input image. Its pixels are copied into `intValues` and read back
     *   with an `inputSize.width` row stride, so it is assumed to already be scaled to
     *   `inputSize` — TODO confirm with callers.
     * @param args ignored.
     * @return one candidate per detection reported by the model, each with an empty
     *   label, the model score, and a box scaled from normalized coordinates to
     *   `inputSize` pixels.
     */
    override fun recognizeImage(bitmap: Bitmap, args: List<Any>?): List<ObjectDetectionCandidate> {
        fillInputBuffer(bitmap)

        val inputArray: Array<Any> = arrayOf(imgData)
        val outputMap = mapOf(
            0 to outputLocations,
            1 to outputClasses,
            2 to outputScores,
            3 to numDetections
        )

        tfLite.runForMultipleInputsOutputs(inputArray, outputMap)

        // The count comes from the model; clamp it so it can never index past the
        // NUM_DETECTIONS-sized output buffers (or go negative).
        val numberObjectsDetected = numDetections[0].toInt().coerceIn(0, NUM_DETECTIONS)
        val recognitions = ArrayList<ObjectDetectionCandidate>(numberObjectsDetected)
        for (i in 0 until numberObjectsDetected) {
            // Boxes are read as normalized [top, left, bottom, right] (the usual TFLite
            // SSD layout — NOTE(review): confirm against the model) and converted to a
            // RectF(left, top, right, bottom) in input-pixel coordinates.
            val box = outputLocations[0][i]
            val detectionRect = RectF(
                box[1] * inputSize.width,
                box[0] * inputSize.height,
                box[3] * inputSize.width,
                box[2] * inputSize.height
            )

            // Label resolution and background filtering are disabled; kept for reference.
//            val label = labels[(outputClasses[0][i] + labelOffset).toInt()]
//            if (label == PassioRecognizer.BKG_PASSIO_ID) {
//                continue
//            }

            recognitions.add(
                ObjectDetectionCandidate(
                    "",
                    outputScores[0][i],
                    detectionRect
                )
            )
        }

        return recognitions
    }

    /**
     * Copies [bitmap]'s pixels into `imgData` in row-major R, G, B order.
     * Each channel is written as a raw byte when the model is quantized, otherwise
     * as `(channel - mean) / std`.
     */
    private fun fillInputBuffer(bitmap: Bitmap) {
        bitmap.getPixels(intValues, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)

        imgData.rewind()
        // Rows in the outer loop, columns in the inner loop: the index y * width + x is
        // then row-major for any input shape, not just square ones.
        for (y in 0 until inputSize.height) {
            for (x in 0 until inputSize.width) {
                val pixelValue = intValues[y * inputSize.width + x]
                val red = pixelValue shr 16 and 0xFF
                val green = pixelValue shr 8 and 0xFF
                val blue = pixelValue and 0xFF
                if (isQuantized) {
                    imgData.put(red.toByte())
                    imgData.put(green.toByte())
                    imgData.put(blue.toByte())
                } else {
                    imgData.putFloat((red - mean) / std)
                    imgData.putFloat((green - mean) / std)
                    imgData.putFloat((blue - mean) / std)
                }
            }
        }
    }

}