package ai.passio.passiosdk.core.tflite

import android.graphics.RectF
import java.util.PriorityQueue

/**
 * IoU threshold used by [NonMaxSuppression]: within one class label, a detection whose
 * intersection-over-union with the currently selected most-confident detection is at or above
 * this value is discarded; detections below it survive to the next selection round.
 */
private const val NMS_THRESHOLD = 0.3f

/**
 * Greedy, per-class non-maximum suppression (NMS) over detector output.
 *
 * For every class label, detections of that class are visited in descending confidence order.
 * A detection is kept only if its intersection-over-union (IoU) with every detection already
 * kept for the same class is below the threshold. Detections of different classes never
 * suppress each other.
 */
internal object NonMaxSuppression {

    /**
     * Runs non-maximum suppression on [results], independently for each class in [labels].
     *
     * @param labels class identifiers to process; each is matched against `Recognition.classId`.
     *   Results whose class is not listed are dropped. A label listed more than once is
     *   processed only once, so its detections are not duplicated in the output.
     * @param results raw detections, in any order.
     * @param iouThreshold IoU at or above which a detection is suppressed by a more confident
     *   detection of the same class. Defaults to [NMS_THRESHOLD].
     * @return the surviving detections, grouped by label in the order labels first appear in
     *   [labels] and, within a label, sorted by descending confidence (ties keep input order).
     */
    fun calculate(
        labels: List<String>,
        results: List<Recognition>,
        iouThreshold: Float = NMS_THRESHOLD
    ): List<Recognition> {
        // Bucket the detections once instead of rescanning every result for every label.
        val byClass = results.groupBy { it.classId }
        return labels.distinct().flatMap { label ->
            suppress(byClass[label].orEmpty(), iouThreshold)
        }
    }

    /**
     * Greedy NMS within a single class.
     *
     * Each candidate is compared only against detections already kept, never against itself.
     * The loop therefore always terminates, even for degenerate or inverted boxes.
     */
    private fun suppress(candidates: List<Recognition>, iouThreshold: Float): List<Recognition> {
        val kept = mutableListOf<Recognition>()
        // sortedByDescending is stable, so equal confidences keep their input order.
        for (candidate in candidates.sortedByDescending { it.confidence }) {
            val overlapsKept = kept.any { boxIOU(it.boundingBox, candidate.boundingBox) >= iouThreshold }
            if (!overlapsKept) {
                kept.add(candidate)
            }
        }
        return kept
    }

    /**
     * Intersection-over-union of two axis-aligned boxes. The result is in [0, 1] for boxes with
     * finite coordinates.
     *
     * Returns 0 when the union has no positive area (degenerate or inverted boxes) rather than
     * dividing by zero and producing NaN.
     */
    private fun boxIOU(a: RectF, b: RectF): Float {
        val intersection = boxIntersection(a, b)
        val union = boxArea(a) + boxArea(b) - intersection
        return if (union > 0f) intersection / union else 0f
    }

    /** Area of the overlap between [a] and [b], or 0 if they do not overlap. */
    private fun boxIntersection(a: RectF, b: RectF): Float {
        val width = minOf(a.right, b.right) - maxOf(a.left, b.left)
        val height = minOf(a.bottom, b.bottom) - maxOf(a.top, b.top)
        return if (width <= 0f || height <= 0f) 0f else width * height
    }

    /** Area of [box], treating an inverted extent (right < left or bottom < top) as zero. */
    private fun boxArea(box: RectF): Float =
        maxOf(0f, box.right - box.left) * maxOf(0f, box.bottom - box.top)
}
