package ai.passio.passiosdk.passiofood.tflite

import ai.passio.passiosdk.core.bayes.BayesFilter
import ai.passio.passiosdk.core.tflite.TFLiteRecognizer
import ai.passio.passiosdk.core.utils.DimenUtil
import ai.passio.passiosdk.passiofood.ClassificationCandidate
import ai.passio.passiosdk.passiofood.PassioSDK
import ai.passio.passiosdk.passiofood.recognition.filter.KalmanFilter
import android.graphics.Bitmap
import android.util.Log
import android.util.Size
import kotlin.math.ln

/**
 * Memory value (100 MiB) passed as the fourth constructor argument of [TFLiteRecognizer].
 * NOTE(review): judging by the name this is presumably an XNNPACK memory setting — confirm
 * against TFLiteRecognizer.
 */
private const val HNN_XNNPACK_MEMT = 100 * 1024 * 1024L

/**
 * Classifier for a TFLite model with two output tensors: index 0 is written into [knnResults]
 * and index 1 into [hnnResults]. Only the index-1 scores are used to build results.
 *
 * When the first element of `args` is `true`, the scores are passed through a [KalmanFilter]
 * (see [filteredCandidates]); otherwise the single highest-scoring label is returned.
 */
internal class TFLiteHNNKNNDetector(
    isImageNormalized: Boolean,
    inputSize: Size,
    labels: List<String>
) : TFLiteRecognizer<List<ClassificationCandidate>>(
    isImageNormalized,
    inputSize,
    labels,
    HNN_XNNPACK_MEMT
) {

    /** Receives model output index 1: one score per label, used to build the results. */
    private val hnnResults: Array<FloatArray> by lazy {
        DimenUtil.init2DFloatArray(1, labels.size)
    }

    /**
     * Receives model output index 0; it is filled by inference but not read by this class.
     * NOTE(review): allocated as (1, labels.size) — confirm this matches the model's output shape.
     */
    private val knnResults: Array<FloatArray> by lazy {
        DimenUtil.init2DFloatArray(1, labels.size)
    }

    /**
     * Stepped with the HNN scores on every filtered call, so filtered results presumably depend
     * on previous frames. NOTE(review): no synchronization here — confirm callers serialize calls.
     */
    private val kalmanFilter = KalmanFilter(labels.size)

    /**
     * Runs the model on [bitmap] and converts the HNN scores into candidates.
     *
     * @param bitmap image to classify; assumed to already be `inputSize` — TODO confirm callers scale it.
     * @param args optional arguments; if the first element is the Boolean `true`, the Kalman-filtered
     *   path is used. A missing, empty, or non-Boolean first element selects the unfiltered path.
     * @return the filtered top candidates, a single background candidate, or the single argmax label.
     */
    override fun recognizeImage(bitmap: Bitmap, args: List<Any>?): List<ClassificationCandidate> {
        fillInputBuffer(bitmap)

        val inputArray: Array<Any> = arrayOf(imgData)
        val outputMap = mapOf(
            0 to knnResults,
            1 to hnnResults
        )
        tfLite.runForMultipleInputsOutputs(inputArray, outputMap)

        // firstOrNull(): an empty args list must not throw; it simply disables filtering.
        val shouldRunFilter = (args?.firstOrNull() as? Boolean) == true
        return if (shouldRunFilter) filteredCandidates() else topCandidate()
    }

    /**
     * Copies [bitmap]'s pixels into `imgData` in row-major order, writing R, G, B per pixel.
     * They are written either as raw bytes (quantized models) or as `(channel - mean) / std`
     * floats.
     */
    private fun fillInputBuffer(bitmap: Bitmap) {
        bitmap.getPixels(intValues, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)

        imgData.rewind()
        val width = inputSize.width
        // getPixels lays pixels out row by row, so rows must be the outer loop and the
        // stride must be the row width; otherwise non-square inputs are read transposed.
        for (y in 0 until inputSize.height) {
            for (x in 0 until width) {
                val pixelValue = intValues[y * width + x]
                if (isQuantized) {
                    imgData.put((pixelValue shr 16 and 0xFF).toByte())
                    imgData.put((pixelValue shr 8 and 0xFF).toByte())
                    imgData.put((pixelValue and 0xFF).toByte())
                } else {
                    imgData.putFloat(((pixelValue shr 16 and 0xFF) - mean) / std)
                    imgData.putFloat(((pixelValue shr 8 and 0xFF) - mean) / std)
                    imgData.putFloat(((pixelValue and 0xFF) - mean) / std)
                }
            }
        }
    }

    /**
     * Feeds the current HNN scores into [kalmanFilter] and queries `top(3, 0.0005f)`.
     * If the reported divergence is below [MAX_DIVERGENCE], the filter's candidates are
     * returned; otherwise a single [PassioSDK.BKG_PASSIO_ID] candidate with confidence 1.0.
     */
    private fun filteredCandidates(): List<ClassificationCandidate> {
        kalmanFilter.step(hnnResults[0])
        val filterResult = kalmanFilter.top(3, 0.0005f)
        // `first` is really the filter's divergence; `second` holds its top entries.
        return if (filterResult.first < MAX_DIVERGENCE) {
            filterResult.second.map { ClassificationCandidate(labels[it.index], it.confidence) }
        } else {
            listOf(ClassificationCandidate(PassioSDK.BKG_PASSIO_ID, 1.0f))
        }
    }

    /**
     * Returns the single label with the highest HNN score (the first one on ties).
     *
     * @throws IllegalStateException if the score array is empty (no labels).
     */
    private fun topCandidate(): List<ClassificationCandidate> {
        val scores = hnnResults[0]
        val topIndex = checkNotNull(scores.indices.maxByOrNull { scores[it] }) {
            "HNN output is empty; no labels to classify"
        }
        return listOf(ClassificationCandidate(labels[topIndex], scores[topIndex]))
    }

    private companion object {
        /** Filtered results are reported only when divergence is strictly below -ln(0.5) (~0.693). */
        val MAX_DIVERGENCE = -ln(0.5)
    }
}