package ai.passio.passiosdk.passiofood.recognition.filter

import ai.passio.passiosdk.core.utils.PassioLog
import ai.passio.passiosdk.passiofood.ClassificationCandidate
import kotlin.math.ln

/**
 * Temporal smoother for per-class classification scores, implemented as `nClasses`
 * independent scalar Kalman filters.
 *
 * Each filter uses a random-walk state model: prediction keeps `x` and grows `P` by
 * [processNoise]. It uses a direct measurement model: the observation is the noisy state
 * itself, with variance [measurementNoise].
 *
 * Besides smoothing, every [step] stores in `divergence` a normalised cross-entropy of the
 * new observation under the current estimate (see [divergenceEstimate]).
 *
 * NOTE(review): `nClasses`, `stateEstimate` and `divergence` are not declared in this class.
 * They are presumably members of [Filter] — confirm there.
 *
 * @param nClasses number of classes; every observation must have exactly this length.
 * @param initialCovariance starting error covariance `P` for every class.
 * @param processNoise variance `Q` added to `P` on every prediction.
 * @param measurementNoise variance `R` used in the Kalman gain `P / (P + R)`.
 */
internal class KalmanFilter(
    nClasses: Int,
    private val initialCovariance: Float = 0.1f,
    private val processNoise: Float = 0.3f,
    private val measurementNoise: Float = 0.3f
) : Filter(nClasses) {

    /** Uniform prior: every class starts at `1 / nClasses`. */
    private val initialState = 1f / nClasses.toFloat()
    private var errorCovariance = FloatArray(nClasses) { initialCovariance }

    /** Last measurement residual `observation - predicted state` per class, written by [step]. */
    private var innovation = FloatArray(nClasses) { initialState }

    init {
        stateEstimate = FloatArray(nClasses) { initialState }
    }

    /**
     * Fuses [observation] into the per-class state estimate and updates `divergence`.
     *
     * If [observation] does not have exactly `nClasses` entries, a warning is logged. The
     * current estimate is then returned unchanged, and `divergence` is not updated.
     *
     * NOTE(review): [observation] is presumably a probability vector (e.g. softmax output).
     * This is assumed, not checked.
     *
     * @return the updated state estimate. This is the `stateEstimate` array itself, not a copy.
     */
    override fun step(observation: FloatArray): FloatArray {
        if (observation.size != nClasses) {
            PassioLog.w(this::class.java.simpleName, "Measurement array length does not match number of classes")
            return stateEstimate
        }

        // Divergence is measured against the estimate *before* this observation is fused in.
        // With a single class ln(nClasses) == 0, so normalising would yield NaN/Infinity;
        // a one-class distribution cannot diverge, so report 0 instead.
        val maxEntropy = randomEntropy()
        divergence = if (maxEntropy > 0f) divergenceEstimate(observation) / maxEntropy else 0f

        for (i in 0 until nClasses) {
            // Predict: random-walk model, so only the uncertainty grows.
            errorCovariance[i] += processNoise

            // Update: move the estimate toward the observation by the Kalman gain.
            val kalmanGain = errorCovariance[i] / (errorCovariance[i] + measurementNoise)
            innovation[i] = observation[i] - stateEstimate[i]
            stateEstimate[i] += kalmanGain * innovation[i]
            errorCovariance[i] *= 1 - kalmanGain
        }

        return stateEstimate
    }

    /**
     * Cross-entropy `-Σ stateEstimate[i] · ln(observation[i])` of [observation] under the
     * current estimate.
     *
     * Terms with weight `stateEstimate[i] == 0` are skipped (the usual `0 · ln 0 = 0`
     * convention). Otherwise a zero observation would produce `0 · ∞ = NaN` and poison the
     * result. A zero observation paired with a positive weight still yields `+Infinity`.
     */
    private fun divergenceEstimate(observation: FloatArray): Float {
        var nll = 0f
        for (i in 0 until nClasses) {
            val weight = stateEstimate[i]
            if (weight != 0f) nll -= weight * ln(observation[i])
        }
        return nll
    }

    /** Entropy of the uniform distribution over `nClasses` classes, i.e. `ln(nClasses)`. */
    private fun randomEntropy(): Float {
        return -ln(1f / nClasses)
    }

    /**
     * Restores the uniform prior, the initial covariance and the initial innovation values.
     * `divergence` is left as-is.
     */
    override fun resetFilter() {
        stateEstimate = FloatArray(nClasses) { initialState }
        errorCovariance = FloatArray(nClasses) { initialCovariance }
        innovation = FloatArray(nClasses) { initialState }
    }
}