package ai.passio.passiosdk.core.utils

import android.graphics.Bitmap
import android.graphics.Canvas
import android.graphics.ColorMatrix
import android.graphics.ColorMatrixColorFilter
import android.graphics.Matrix
import android.graphics.Paint
import android.graphics.Rect
import android.graphics.RectF
import androidx.camera.core.ImageProxy
import kotlin.math.abs
import kotlin.math.max
import kotlin.math.min

// Upper clamp for the fixed-point R/G/B values produced by the YUV -> RGB conversion in
// ImageUtils. Those values are scaled by 1024 (10 fractional bits), and
// 262143 == (256 shl 10) - 1, so each channel still fits in 8 bits once the fractional bits are
// dropped while packing the ARGB int.
private const val kMaxChannelValue = 262143

/**
 * Image helpers for the camera pipeline. The object provides:
 * - YUV420 -> ARGB conversion, both in Kotlin and via native code.
 * - YUV crop and rotate, delegated to native code.
 * - RGB -> NV21 encoding.
 * - A few [Matrix] and [Rect] utilities.
 *
 * The `external` functions are bound through JNI. Their names and parameter lists must stay in
 * sync with the native library, which is not visible from this file.
 */
internal object ImageUtils {

    /**
     * Converts a YUV420 image with three separate planes into packed ARGB_8888 pixels, in pure Kotlin.
     *
     * Chroma is shared by each 2x2 luma block: pixel (col, row) reads the chroma sample at
     * row `row / 2`, column `col / 2`.
     *
     * @param yData luma plane; row `j` starts at offset `j * yRowStride`.
     * @param uData U plane; chroma row `k` starts at offset `k * uvRowStride`.
     * @param vData V plane, indexed exactly like [uData].
     * @param width image width in pixels.
     * @param height image height in pixels.
     * @param yRowStride bytes between the starts of consecutive luma rows.
     * @param uvRowStride bytes between the starts of consecutive chroma rows.
     * @param uvPixelStride bytes between consecutive chroma samples within a row.
     * @param out destination, filled row-major. It must hold at least `width * height` ints.
     * @throws IndexOutOfBoundsException if a plane or [out] is too small for the given geometry.
     */
    fun convertYUV420ToARGB8888(
        yData: ByteArray,
        uData: ByteArray,
        vData: ByteArray,
        width: Int,
        height: Int,
        yRowStride: Int,
        uvRowStride: Int,
        uvPixelStride: Int,
        out: IntArray
    ) {
        var outIndex = 0
        for (row in 0 until height) {
            val yRowStart = yRowStride * row
            val uvRowStart = uvRowStride * (row shr 1)

            for (col in 0 until width) {
                val uvOffset = uvRowStart + (col shr 1) * uvPixelStride
                out[outIndex++] = yuvToArgb(
                    0xff and yData[yRowStart + col].toInt(),
                    0xff and uData[uvOffset].toInt(),
                    0xff and vData[uvOffset].toInt()
                )
            }
        }
    }

    /**
     * Same contract as [convertYUV420ToARGB8888], but the conversion runs in native code.
     * Always passes `halfSize = false` to the native function.
     */
    internal fun convertYUV420ToARGB8888CPP(
        yData: ByteArray,
        uData: ByteArray,
        vData: ByteArray,
        width: Int,
        height: Int,
        yRowStride: Int,
        uvRowStride: Int,
        uvPixelStride: Int,
        output: IntArray
    ) {
        nativeConvertYUV420ToARGB8888(
            yData = yData,
            uData = uData,
            vData = vData,
            width = width,
            height = height,
            output = output,
            yRowStride = yRowStride,
            uvRowStride = uvRowStride,
            uvPixelStride = uvPixelStride,
            halfSize = false
        )
    }

    // JNI entry point. The signature must match the native implementation.
    private external fun nativeConvertYUV420ToARGB8888(
        yData: ByteArray,
        uData: ByteArray,
        vData: ByteArray,
        width: Int,
        height: Int,
        output: IntArray,
        yRowStride: Int,
        uvRowStride: Int,
        uvPixelStride: Int,
        halfSize: Boolean
    )

    /**
     * Converts one YUV sample (each component 0..255) to an opaque 0xAARRGGBB pixel.
     * The coefficients match the usual BT.601 limited-range conversion.
     */
    private fun yuvToArgb(yValue: Int, uValue: Int, vValue: Int): Int {
        // Remove the luma offset (floored at 0) and centre the chroma around 0.
        val y = (yValue - 16).coerceAtLeast(0)
        val u = uValue - 128
        val v = vValue - 128

        // Fixed-point form (coefficients scaled by 1024) of:
        //   R = 1.164 * Y + 1.596 * V
        //   G = 1.164 * Y - 0.813 * V - 0.391 * U
        //   B = 1.164 * Y + 2.018 * U
        // Integer math is used because some Android devices lack hardware floating point.
        val y1192 = 1192 * y
        val r = (y1192 + 1634 * v).coerceIn(0, kMaxChannelValue)
        val g = (y1192 - 833 * v - 400 * u).coerceIn(0, kMaxChannelValue)
        val b = (y1192 + 2066 * u).coerceIn(0, kMaxChannelValue)

        // Drop the 10 fractional bits of each channel, place it in its byte, and set alpha to 0xFF.
        return -0x1000000 or (r shl 6 and 0xff0000) or (g shr 2 and 0xff00) or (b shr 10 and 0xff)
    }

    /**
     * Builds a matrix that maps a `srcWidth x srcHeight` frame onto a `dstWidth x dstHeight` frame.
     *
     * @param applyRotation clockwise rotation in degrees. Only multiples of 90 are applied; any
     *   other non-zero value is ignored and the matrix only scales.
     * @param maintainAspectRatio when true, scales uniformly by the larger of the two factors.
     *   The destination is then covered completely and the overflow falls outside it. When false,
     *   the axes are scaled independently.
     */
    internal fun getTransformationMatrix(
        srcWidth: Int,
        srcHeight: Int,
        dstWidth: Int,
        dstHeight: Int,
        applyRotation: Int,
        maintainAspectRatio: Boolean
    ): Matrix {
        val matrix = Matrix()

        // The centring translations below must be applied as a pair. Skipping only the
        // pre-rotation translation for unsupported angles would shift the output by dst/2.
        val rotate = applyRotation != 0 && applyRotation % 90 == 0
        if (rotate) {
            // Move the source centre to the origin, then rotate around it.
            matrix.postTranslate(-srcWidth / 2f, -srcHeight / 2f)
            matrix.postRotate(applyRotation.toFloat())
        }

        // +-90 / +-270 degree rotations swap the source dimensions.
        val transpose = (abs(applyRotation) + 90) % 180 == 0
        val inWidth = if (transpose) srcHeight else srcWidth
        val inHeight = if (transpose) srcWidth else srcHeight

        if (inWidth != dstWidth || inHeight != dstHeight) {
            val scaleFactorX = dstWidth / inWidth.toFloat()
            val scaleFactorY = dstHeight / inHeight.toFloat()

            if (maintainAspectRatio) {
                val scaleFactor = max(scaleFactorX, scaleFactorY)
                matrix.postScale(scaleFactor, scaleFactor)
            } else {
                matrix.postScale(scaleFactorX, scaleFactorY)
            }
        }

        if (rotate) {
            // Move the origin back to the centre of the destination.
            matrix.postTranslate(dstWidth / 2f, dstHeight / 2f)
        }

        return matrix
    }

    /**
     * Clamps [rect] to `[0, width] x [0, height]` and truncates the edges to integers.
     * A rect lying fully outside the bounds can come back inverted (left > right).
     */
    internal fun normalizeRectF(rect: RectF, width: Int, height: Int): Rect = Rect(
        rect.left.coerceAtLeast(0f).toInt(),
        rect.top.coerceAtLeast(0f).toInt(),
        rect.right.coerceAtMost(width.toFloat()).toInt(),
        rect.bottom.coerceAtMost(height.toFloat()).toInt()
    )

    /**
     * Copies each plane's bytes into `yuvBytes[i]`, reusing the existing array when its size matches.
     *
     * Each plane buffer is rewound first, so the whole plane is copied regardless of the buffer's
     * prior position. [yuvBytes] must have at least `planes.size` slots.
     */
    internal fun imageProxyToYUV(
        planes: Array<ImageProxy.PlaneProxy>,
        yuvBytes: Array<ByteArray?>
    ) {
        planes.forEachIndexed { index, plane ->
            val buffer = plane.buffer
            buffer.rewind()
            // After rewind() the readable bytes are [0, limit). Sizing by remaining() rather than
            // capacity() avoids a BufferUnderflowException when limit < capacity.
            val size = buffer.remaining()
            val bytes = yuvBytes[index]?.takeIf { it.size == size }
                ?: ByteArray(size).also { yuvBytes[index] = it }
            buffer.get(bytes)
        }
    }

    /**
     * Returns a new bitmap where every colour channel becomes `channel * contrast + brightness`.
     * Alpha is unchanged and [bitmap] itself is not modified.
     *
     * The result keeps [bitmap]'s config. If that config is not a public one (null), ARGB_8888 is used.
     */
    internal fun changeBitmapContrast(bitmap: Bitmap, contrast: Float, brightness: Float): Bitmap {
        val colorMatrix = ColorMatrix(
            floatArrayOf(
                contrast, 0f, 0f, 0f, brightness,
                0f, contrast, 0f, 0f, brightness,
                0f, 0f, contrast, 0f, brightness,
                0f, 0f, 0f, 1f, 0f
            )
        )

        val config = bitmap.config ?: Bitmap.Config.ARGB_8888
        val result = Bitmap.createBitmap(bitmap.width, bitmap.height, config)
        val paint = Paint().apply { colorFilter = ColorMatrixColorFilter(colorMatrix) }
        Canvas(result).drawBitmap(bitmap, 0f, 0f, paint)

        return result
    }

    /**
     * Encodes the top-left `inputWidth x inputHeight` pixels of [scaled] as NV21.
     *
     * @return a pair of (Y plane, interleaved VU plane). The Y plane has one byte per pixel. The VU
     *   plane has one V,U byte pair per 2x2 block, rounded up for odd dimensions.
     */
    internal fun getNV21(
        inputWidth: Int,
        inputHeight: Int,
        scaled: Bitmap
    ): Pair<ByteArray, ByteArray> {
        val argb = IntArray(inputWidth * inputHeight)
        scaled.getPixels(argb, 0, inputWidth, 0, 0, inputWidth, inputHeight)
        val yPlane = ByteArray(inputWidth * inputHeight)
        // encodeYUV420SP writes a V,U pair for every even row and every even column, including
        // the last row/column when the dimension is odd, so both dimensions round up.
        val uvPlane = ByteArray(2 * ((inputWidth + 1) / 2) * ((inputHeight + 1) / 2))
        encodeYUV420SP(yPlane, uvPlane, argb, inputWidth, inputHeight)
        return yPlane to uvPlane
    }

    /**
     * Writes NV21 data for the row-major [argb] pixels.
     * - [yuv420spYPlane] receives one luma byte per pixel.
     * - [yuv420spUVPlane] receives a V byte then a U byte for every pixel on an even row and an
     *   even column, i.e. one pair per 2x2 block sampled at its top-left pixel.
     */
    private fun encodeYUV420SP(
        yuv420spYPlane: ByteArray,
        yuv420spUVPlane: ByteArray,
        argb: IntArray,
        width: Int,
        height: Int
    ) {
        var yIndex = 0
        var uvIndex = 0

        for (row in 0 until height) {
            for (col in 0 until width) {
                val pixel = argb[row * width + col]
                val red = (pixel shr 16) and 0xff
                val green = (pixel shr 8) and 0xff
                val blue = pixel and 0xff

                // Integer RGB -> YUV: luma offset 16, chroma centred at 128.
                val y = ((66 * red + 129 * green + 25 * blue + 128) shr 8) + 16
                yuv420spYPlane[yIndex++] = clampToByte(y)

                if (row % 2 == 0 && col % 2 == 0) {
                    val u = ((-38 * red - 74 * green + 112 * blue + 128) shr 8) + 128
                    val v = ((112 * red - 94 * green - 18 * blue + 128) shr 8) + 128
                    // NV21 stores V before U.
                    yuv420spUVPlane[uvIndex++] = clampToByte(v)
                    yuv420spUVPlane[uvIndex++] = clampToByte(u)
                }
            }
        }
    }

    /** Clamps [value] to 0..255 and narrows it to a byte. */
    private fun clampToByte(value: Int): Byte = value.coerceIn(0, 255).toByte()

    /**
     * Crops a YUV420 image into the `out*` planes by delegating to native code.
     *
     * NOTE(review): buffer-size requirements for the output planes and the handling of
     * odd [startX]/[startY] are defined by the native implementation. Confirm them there.
     */
    internal fun cropYUV420(
        yData: ByteArray,
        uData: ByteArray,
        vData: ByteArray,
        yRowStride: Int,
        uvRowStride: Int,
        uvPixelStride: Int,
        srcWidth: Int,
        srcHeight: Int,
        startX: Int,
        startY: Int,
        dstWidth: Int,
        dstHeight: Int,
        outY: ByteArray,
        outU: ByteArray,
        outV: ByteArray
    ) {
        nativeCropYUV420(
            yData, uData, vData,
            yRowStride, uvRowStride, uvPixelStride,
            srcWidth, srcHeight, startX, startY,
            dstWidth, dstHeight,
            outY, outU, outV
        )
    }

    // JNI entry point. The signature must match the native implementation.
    private external fun nativeCropYUV420(
        yData: ByteArray,
        uData: ByteArray,
        vData: ByteArray,
        yRowStride: Int,
        uvRowStride: Int,
        uvPixelStride: Int,
        srcWidth: Int,
        srcHeight: Int,
        startX: Int,
        startY: Int,
        dstWidth: Int,
        dstHeight: Int,
        outY: ByteArray,
        outU: ByteArray,
        outV: ByteArray
    )

    /**
     * Rotates a YUV420 image into the `out*` planes by delegating to native code.
     *
     * NOTE(review): the accepted [rotation] values and the expected plane layout are defined by
     * the native implementation. Confirm them there.
     */
    internal fun rotateYUV420(
        yData: ByteArray,
        uData: ByteArray,
        vData: ByteArray,
        width: Int,
        height: Int,
        rotation: Int,
        outY: ByteArray,
        outU: ByteArray,
        outV: ByteArray
    ) {
        nativeRotateYUV420(yData, uData, vData, width, height, rotation, outY, outU, outV)
    }

    // JNI entry point. The signature must match the native implementation.
    private external fun nativeRotateYUV420(
        yData: ByteArray,
        uData: ByteArray,
        vData: ByteArray,
        width: Int,
        height: Int,
        rotation: Int,
        outY: ByteArray,
        outU: ByteArray,
        outV: ByteArray
    )

    /**
     * Returns a `dstSize x dstSize` box centred on [srcRect] as far as possible, shifted so it
     * stays inside a `srcSize x srcSize` frame.
     *
     * Expects `dstSize <= srcSize`. A larger box cannot fit and may extend outside the frame.
     */
    internal fun findBestBoundingBoxIn(srcRect: RectF, srcSize: Int, dstSize: Int): Rect {
        val left = fitSpanStart(srcRect.centerX(), dstSize, srcSize)
        val top = fitSpanStart(srcRect.centerY(), dstSize, srcSize)
        return Rect(left, top, left + dstSize, top + dstSize)
    }

    /**
     * Returns the start of a [size]-long span centred on [center], shifted to stay within
     * `[0, bound]`. The upper clamp uses `bound - size`, so odd sizes can no longer end at
     * `bound + 1`.
     */
    private fun fitSpanStart(center: Float, size: Int, bound: Int): Int {
        val halfSize = size / 2
        return if (center - halfSize < 0) 0 else min(center.toInt() - halfSize, bound - size)
    }

    /**
     * Letterboxes this bitmap into a new `targetWidth x targetHeight` ARGB_8888 bitmap.
     *
     * The bitmap is scaled uniformly by the limiting ratio, so the whole image is visible, and
     * then centred along the other axis. Drawing uses bitmap filtering.
     *
     * @return the new bitmap and the inverse of the applied matrix. The inverse maps target
     *   coordinates back to this bitmap.
     */
    internal fun Bitmap.scaleToFit(targetWidth: Int, targetHeight: Int): Pair<Bitmap, Matrix> {
        val widthRatio = targetWidth.toFloat() / width
        val heightRatio = targetHeight.toFloat() / height
        val matrix = Matrix()

        if (widthRatio <= heightRatio) {
            // Width is the limiting side: fill it and centre vertically.
            matrix.postScale(widthRatio, widthRatio)
            matrix.postTranslate(0f, (targetHeight - height * widthRatio) / 2f)
        } else {
            // Height is the limiting side: fill it and centre horizontally.
            matrix.postScale(heightRatio, heightRatio)
            matrix.postTranslate((targetWidth - width * heightRatio) / 2f, 0f)
        }

        val targetBitmap = Bitmap.createBitmap(targetWidth, targetHeight, Bitmap.Config.ARGB_8888)
        val paint = Paint().apply { isFilterBitmap = true }
        Canvas(targetBitmap).drawBitmap(this, matrix, paint)

        val inverse = Matrix()
        matrix.invert(inverse)
        return targetBitmap to inverse
    }

    /**
     * Intended to scale a YUV image to [outWidth] x [outHeight].
     *
     * NOTE(review): this is currently a no-op. The body is empty, so the output buffers are left
     * untouched. [nativeNV21BilinearScale] takes a single `yuv` buffer rather than separate
     * planes and is not wired up here. TODO: implement or remove.
     */
    internal fun scaleNV21(
        yData: ByteArray,
        uData: ByteArray,
        vData: ByteArray,
        width: Int,
        height: Int,
        outY: ByteArray,
        outU: ByteArray,
        outV: ByteArray,
        outWidth: Int,
        outHeight: Int
    ) {
        // Intentionally empty; see KDoc.
    }

    // JNI entry point. The signature must match the native implementation. Unused in this file.
    private external fun nativeNV21BilinearScale(
        yuv: ByteArray,
        width: Int,
        height: Int,
        resultYuv: ByteArray,
        resultWidth: Int,
        resultHeight: Int
    )

    /**
     * Builds a matrix that maps [imageProxy]'s crop rect onto a `previewWidth x previewHeight`
     * preview and applies the proxy's rotation correction.
     *
     * The four crop-rect corners are mapped to the preview corners. The destination corners are
     * rotated by one vertex per 90 degrees of `rotationDegrees`.
     */
    fun getCorrectionMatrix(
        imageProxy: ImageProxy,
        previewWidth: Int,
        previewHeight: Int
    ): Matrix {
        val cropRect = imageProxy.cropRect
        val rotationDegrees = imageProxy.imageInfo.rotationDegrees

        // Crop-rect corners, clockwise from top-left, as (x, y) pairs.
        val source = floatArrayOf(
            cropRect.left.toFloat(), cropRect.top.toFloat(),
            cropRect.right.toFloat(), cropRect.top.toFloat(),
            cropRect.right.toFloat(), cropRect.bottom.toFloat(),
            cropRect.left.toFloat(), cropRect.bottom.toFloat()
        )

        // Preview corners, clockwise from top-left.
        val destination = floatArrayOf(
            0f, 0f,
            previewWidth.toFloat(), 0f,
            previewWidth.toFloat(), previewHeight.toFloat(),
            0f, previewHeight.toFloat()
        )

        // rotationDegrees is the clockwise rotation needed to correct the image. Each quarter
        // turn shifts the destination by one vertex (2 floats). The quarter-turn count is
        // normalised to 0..3 so a negative or >= 360 value cannot produce a negative index.
        val vertexSize = 2
        val quarterTurns = ((rotationDegrees / 90) % 4 + 4) % 4
        val shiftOffset = quarterTurns * vertexSize
        val shifted = FloatArray(destination.size) { index ->
            destination[(index + shiftOffset) % destination.size]
        }

        return Matrix().apply { setPolyToPoly(source, 0, shifted, 0, 4) }
    }
}