package ai.passio.passiosdk.core.ocr

import ai.passio.passiosdk.core.aes.CryptoHandler
import ai.passio.passiosdk.core.file.PassioFileManager.Companion.getTempDir
import ai.passio.passiosdk.core.utils.FileUtil
import ai.passio.passiosdk.core.utils.PassioLog
import android.content.Context
import android.util.Pair
import java.io.ByteArrayInputStream
import java.io.File
import java.io.IOException
import java.io.ObjectInput
import java.io.ObjectInputStream

/** Cutoff handed to the native search calls of [KNNMatcher] for cleaned queries of 6+ characters. */
private const val SEARCH_CONFIDENCE_THRESHOLD = 0.05F

/** Cutoff handed to the native OCR matcher by [KNNMatcher.matchText]. */
private const val PACKAGED_FOOD_THRESHOLD = 0.35F

/**
 * Text matcher whose model lives on the native side and is addressed through [nativePtr].
 *
 * Each secondary constructor obtains the model bytes from a different source (an asset, a model
 * file, a raw byte array, or an encrypted file/asset decrypted with [CryptoHandler]) and hands
 * them to one of the native loaders:
 * - most sources are first read as a Java-serialized [JavaSerializer.OcrVectorizeModel] and fall
 *   back to the flatbuffer loader if that fails;
 * - the `(context, assetPath)` constructor goes straight to the flatbuffer loader;
 * - `.passiosecure` assets are only read as a Java-serialized model.
 *
 * While [nativePtr] is still `-1` (for example when a `.passiosecure` asset holds an unknown
 * class), [isInitialized] returns `false` and every query returns an empty list.
 *
 * NOTE(review): no native release call is visible in this class, so the native model appears to
 * stay allocated for as long as the process lives — confirm on the native side.
 */
internal class KNNMatcher {

    /** Handle returned by the native loaders; `-1` means no model has been loaded. */
    private var nativePtr: Long = -1L

    /** Matches every character outside `a`–`z`. */
    private val alphaRegex: Regex = Regex("[^a-z]")

    /** Matches runs of whitespace; used to split queries into words. */
    private val whitespaceRegex = Regex("\\s+")

    /**
     * Loads a flatbuffer model stored in the app's assets.
     *
     * @param context used to reach the app's assets.
     * @param assetPath path of the model inside the assets.
     */
    constructor(
        context: Context,
        assetPath: String
    ) {
        val buffer = FileUtil.loadFileFromAssets(context.assets, assetPath)
        flatbufferDeserialize(buffer.array())
    }

    /**
     * Loads an unencrypted model file: `.passio2` files are unzipped first, `.passio` files are
     * read as-is.
     *
     * @throws IllegalStateException if the extension is unknown or a `.passio2` archive could not
     * be unzipped.
     */
    constructor(file: File) {
        val vectorizeBytes = when (file.extension) {
            "passio2" -> checkNotNull(FileUtil.unzipToArray(file)) {
                "Could not unzip ocr file: ${file.path}"
            }
            "passio" -> FileUtil.loadFileFromSystem(file)
            else -> error("No known ocr file extension: ${file.extension}")
        }
        deserialize(vectorizeBytes)
    }

    /** Loads a model from raw, unencrypted [bytes] (Java-serialized or flatbuffer). */
    constructor(bytes: ByteArray) {
        deserialize(bytes)
    }

    /**
     * Loads a model file encrypted with [CryptoHandler].
     *
     * @param vectorizeFile encrypted model file.
     * @param key1 first decryption key, passed to [CryptoHandler] as its UTF-8 bytes.
     * @param key2 second decryption key, passed to [CryptoHandler] as its UTF-8 bytes.
     */
    constructor(
        vectorizeFile: File,
        key1: String,
        key2: String
    ) {
        val modelBytes = FileUtil.loadFileFromSystem(vectorizeFile)
        val decryptedBytes = CryptoHandler.instance.decrypt(
            modelBytes,
            key1.toByteArray(),
            key2.toByteArray()
        )
        deserialize(decryptedBytes)
    }

    /**
     * Loads an encrypted model from the app's assets.
     *
     * - `*.passiosecure`: the whole asset is decrypted in memory and read as a Java-serialized
     *   model.
     * - `*.passiosecure2`: the asset is decrypted to a temporary `.passio2` archive, which is
     *   unzipped, deleted, and then loaded like any other model.
     *
     * @param key1 first decryption key, passed to [CryptoHandler] as its UTF-8 bytes.
     * @param key2 second decryption key, passed to [CryptoHandler] as its UTF-8 bytes.
     * @throws IllegalArgumentException if [fileName] has any other extension.
     * @throws IllegalStateException if a decrypted `.passiosecure2` archive could not be unzipped.
     */
    constructor(
        context: Context,
        fileName: String,
        key1: String,
        key2: String
    ) {
        when (val ext = fileName.substringAfterLast(".")) {
            "passiosecure" -> {
                // Read to EOF: available() is only an estimate and a single read(buffer) may
                // return fewer bytes than requested, which would truncate the model.
                val vectorizeBytes = context.assets.open(fileName).use { it.readBytes() }
                val decryptedVectorize = CryptoHandler.instance.decrypt(
                    vectorizeBytes,
                    key1.toByteArray(),
                    key2.toByteArray()
                )
                initialize2(decryptedVectorize)
            }
            "passiosecure2" -> loadSecureArchiveAsset(context, fileName, key1, key2)
            else -> throw IllegalArgumentException("No known extension: $ext")
        }
    }

    /**
     * Decrypts the `.passiosecure2` asset [fileName] into a temporary `.passio2` archive inside
     * the temp dir, unzips it, deletes it, and loads the unzipped bytes.
     *
     * @throws IllegalStateException if the decrypted archive could not be unzipped.
     */
    private fun loadSecureArchiveAsset(
        context: Context,
        fileName: String,
        key1: String,
        key2: String
    ) {
        // NOTE(review): assumes fileName has at least two dot-separated parts;
        // "name.v1.passiosecure2" becomes "name.v1.passio2".
        val fileNameComponents = fileName.split(".")
        val newFileName = fileNameComponents[0] + "." + fileNameComponents[1] + ".passio2"
        val rootPath = context.getTempDir()
        File(rootPath).mkdirs()

        val outPath = rootPath + File.separator + newFileName
        val passio2File = File(outPath)
        val modelBytes = try {
            context.assets.open(fileName).use { inputStream ->
                CryptoHandler.instance.decrypt(
                    inputStream,
                    outPath,
                    key1.toByteArray(),
                    key2.toByteArray()
                )
            }
            checkNotNull(FileUtil.unzipToArray(passio2File)) {
                "Could not unzip decrypted ocr file: $outPath"
            }
        } finally {
            // Never leave the decrypted archive behind, even when decryption or unzipping fails.
            passio2File.delete()
        }
        deserialize(modelBytes)
    }

    /** Loads [modelBytes] as a Java-serialized model, falling back to the flatbuffer loader. */
    private fun deserialize(modelBytes: ByteArray) {
        if (!javaDeserialize(modelBytes)) {
            flatbufferDeserialize(modelBytes)
        }
    }

    /**
     * Tries to read [vectorizeBytes] as a Java-serialized [JavaSerializer.OcrVectorizeModel] and,
     * on success, builds the native model from its vocab, matrix, weights and upcs via
     * [nativeInit].
     *
     * @return `true` if [nativeInit] was called, `false` if the bytes could not be read as such a
     * model (the failure is logged).
     */
    private fun javaDeserialize(vectorizeBytes: ByteArray): Boolean {
        val tag = this::class.java.simpleName
        return try {
            // NOTE(review): this is Java deserialization of bytes that can come from model files
            // on disk; readObject() on untrusted input is a known attack surface. Consider a class
            // allow-list (ObjectInputFilter) if these files are not integrity-checked upstream.
            ObjectInputStream(ByteArrayInputStream(vectorizeBytes)).use { objectInput ->
                val model = objectInput.readObject() as JavaSerializer.OcrVectorizeModel
                val transformedVocab: List<Pair<String, Int>> =
                    model.vocab.map { vocabItem -> Pair(vocabItem.ngram, vocabItem.index) }
                // Each entry becomes [r, c, v] with r and c converted to Float
                // (presumably row/column/value — confirm against the native side).
                val transformedMatrix: List<Array<Float>> = model.matrix.map { entry ->
                    arrayOf(entry.r.toFloat(), entry.c.toFloat(), entry.v)
                }
                nativePtr = nativeInit(
                    model.upcs.size,
                    model.weights.size,
                    model.weights.toList(),
                    transformedVocab,
                    transformedMatrix,
                    model.upcs.toList()
                )
            }
            true
        } catch (ioe: IOException) {
            // e.g. StreamCorruptedException when the bytes are not a Java serialization stream.
            PassioLog.w(tag, ioe.message ?: "")
            false
        } catch (cnfe: ClassNotFoundException) {
            PassioLog.w(tag, cnfe.message ?: "")
            false
        } catch (e: Exception) {
            // Anything else, e.g. a ClassCastException for an unexpected serialized type.
            PassioLog.w(tag, e.stackTraceToString())
            false
        }
    }

    /** Loads [vectorizeBytes] with the native flatbuffer loader and stores the returned handle. */
    private fun flatbufferDeserialize(vectorizeBytes: ByteArray) {
        nativePtr = nativeInitFlatbuffer(vectorizeBytes)
    }

    /**
     * Reads [vectorizeBytes] as a Java-serialized [JavaSerializer.OcrVectorizeModel] and passes
     * the model object itself to [nativeInit2].
     *
     * A [ClassNotFoundException] is logged and leaves [nativePtr] unchanged; other exceptions
     * (I/O errors, including bytes that are not a Java serialization stream, or a
     * ClassCastException) propagate to the caller.
     */
    private fun initialize2(vectorizeBytes: ByteArray) {
        try {
            ObjectInputStream(ByteArrayInputStream(vectorizeBytes)).use { objectInput ->
                val model = objectInput.readObject() as JavaSerializer.OcrVectorizeModel
                nativePtr = nativeInit2(model, model.upcs.size, model.weights.size)
            }
        } catch (e: ClassNotFoundException) {
            PassioLog.e(this::class.java.simpleName, "OcrVectorizeModel class incorrect!")
        }
    }

    /**
     * Runs the native OCR matcher on [text] (normalized with [cleanText]) using the
     * [PACKAGED_FOOD_THRESHOLD] cutoff.
     *
     * @return the pairs produced by the native matcher; empty if the matcher is not initialized
     * or the cleaned text is shorter than 3 characters.
     */
    fun matchText(text: String): List<Pair<String, Float>> {
        if (!isInitialized()) {
            // Avoid handing the -1 "no model" sentinel to native code.
            return emptyList()
        }
        val cleaned = cleanText(text)
        if (cleaned.length < 3) {
            return emptyList()
        }
        return nativeRunOCRMatcher(nativePtr, cleaned, PACKAGED_FOOD_THRESHOLD).toList()
    }

    /**
     * Searches for [text] after normalizing it with [cleanNgrams].
     *
     * @return the pairs produced by the native search; empty if the matcher is not initialized or
     * the cleaned text is shorter than 3 characters.
     */
    fun searchText(text: String): List<Pair<String, Float>> {
        if (!isInitialized()) {
            return emptyList()
        }
        val cleaned = cleanNgrams(text)
        if (cleaned.length < 3) {
            return emptyList()
        }
        return nativeSearchWithCutoff(nativePtr, cleaned, searchCutoff(cleaned)).toList()
    }

    /**
     * Popularity-aware search for [text] after normalizing it with [cleanText].
     *
     * @return the pairs produced by the native search; empty if the matcher is not initialized or
     * the cleaned text is shorter than 3 characters.
     */
    fun searchTextPopularity(text: String): List<Pair<String, Float>> {
        if (!isInitialized()) {
            return emptyList()
        }
        val cleaned = cleanText(text)
        if (cleaned.length < 3) {
            return emptyList()
        }
        return nativeSearchPopularityWithCutoff(nativePtr, cleaned, searchCutoff(cleaned)).toList()
    }

    /** `true` once a native loader has stored a handle other than `-1`. */
    fun isInitialized(): Boolean = nativePtr != -1L

    /**
     * Cutoff for the search calls: cleaned queries shorter than 6 characters use `0`, longer ones
     * use [SEARCH_CONFIDENCE_THRESHOLD].
     */
    private fun searchCutoff(cleanedText: String): Float =
        if (cleanedText.length < 6) 0f else SEARCH_CONFIDENCE_THRESHOLD

    /**
     * Lower-cases [text] and removes every character outside `a`–`z` (whitespace, digits,
     * punctuation and non-ASCII letters all go).
     */
    private fun cleanText(text: String): String = text.lowercase().replace(alphaRegex, "")

    /**
     * Lower-cases [text], splits it on whitespace, de-duplicates and sorts the words, strips each
     * word down to `a`–`z`, and concatenates them.
     *
     * NOTE(review): de-duplication happens before stripping, so e.g. "apple" and "apple!" both
     * survive — confirm this matches what the native vocab expects before changing it.
     */
    private fun cleanNgrams(text: String): String =
        text.lowercase()
            .split(whitespaceRegex)
            .toSortedSet()
            .joinToString(separator = "") { word -> word.replace(alphaRegex, "") }

    /** Builds the native model from an unpacked Java-serialized model; returns its handle. */
    private external fun nativeInit(
        rows: Int,
        columns: Int,
        weights: List<Float>,
        vocab: List<Pair<String, Int>>,
        matrix: List<Array<Float>>,
        upcs: List<String>
    ): Long

    /** Builds the native model directly from a deserialized model object; returns its handle. */
    private external fun nativeInit2(
        model: JavaSerializer.OcrVectorizeModel,
        rows: Int,
        columns: Int
    ): Long

    /** Builds the native model from flatbuffer-encoded bytes; returns its handle. */
    private external fun nativeInitFlatbuffer(
        vectorizeBytes: ByteArray
    ): Long

    /** Runs the native OCR matcher for already-cleaned [text] with the given [cutoff]. */
    private external fun nativeRunOCRMatcher(
        nativePtr: Long,
        text: String,
        cutoff: Float
    ): Array<Pair<String, Float>>

    /** Runs the native search for already-cleaned [text] with the given [cutoff]. */
    private external fun nativeSearchWithCutoff(
        nativePtr: Long,
        text: String,
        cutoff: Float
    ): Array<Pair<String, Float>>

    /** Runs the native popularity search for already-cleaned [text] with the given [cutoff]. */
    private external fun nativeSearchPopularityWithCutoff(
        nativePtr: Long,
        text: String,
        cutoff: Float
    ): Array<Pair<String, Float>>

}