/*
 * Copyright (c) 2025 Pointyware. Use of this software is governed by the Apache 2.0 license. See project root for full text.
 */

package org.pointyware.disco.training.interactors

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import org.pointyware.disco.entities.activations.Logistic
import org.pointyware.disco.entities.layers.DenseLayer
import org.pointyware.disco.entities.loss.MeanSquaredError
import org.pointyware.disco.entities.networks.ResidualSequentialNetwork
import org.pointyware.disco.entities.networks.SequentialNetwork
import org.pointyware.disco.entities.regularizers.RmsNorm
import org.pointyware.disco.training.data.ExerciseRepository
import org.pointyware.disco.training.data.Problem
import org.pointyware.disco.training.data.SpiralExerciseGenerator
import org.pointyware.disco.training.entities.Exercise
import org.pointyware.disco.training.entities.Measurement
import org.pointyware.disco.training.entities.SequentialStatistics
import org.pointyware.disco.training.entities.SequentialTrainer
import org.pointyware.disco.training.entities.Snapshot
import org.pointyware.disco.entities.math.key
import org.pointyware.disco.training.entities.optimizers.GradientDescent
import org.pointyware.disco.training.entities.optimizers.WarmRestartExponentialLearningRate
import kotlin.math.min

/**
 * Immutable snapshot of overall training progress, published through
 * [TrainingController.state].
 *
 * @property isTraining `true` between a call to start and the moment training is stopped,
 * reset, or runs out of epochs.
 * @property epochsRemaining number of epochs still to train; set via
 * [TrainingController.setEpochs] and reduced after every training step.
 * @property epochsElapsed running total of epochs requested from the trainers so far.
 * @property networks per-network training state (trainer handle, elapsed epochs, snapshot).
 */
data class TrainingState(
    val isTraining: Boolean = false,
    val epochsRemaining: Int = 0,
    val epochsElapsed: Int = 0,
    val networks: List<NetworkTrainingState> = emptyList(),
)

/**
 * Training progress of a single network.
 *
 * @property elapsedEpochs epochs accumulated for this network so far.
 * @property trainer the trainer that owns and trains the network.
 * @property snapshot most recent statistics snapshot taken from [trainer].
 */
data class NetworkTrainingState(
    val elapsedEpochs: Int = 0,
    val trainer: SequentialTrainer,
    val snapshot: Snapshot,
) {
    /** Size of the layer list of the network held by [trainer]. */
    val networkDepth: Int get() = trainer.network.layers.size
}

/**
 * A Training Controller provides an interface to manage training of a model and exposes
 * properties to observe the training state.
 */
interface TrainingController {
    /** Observable training state; emits a new [TrainingState] whenever it changes. */
    val state: StateFlow<TrainingState>

    /**
     * Sets the number of epochs to train the model. 0 means train indefinitely.
     *
     * NOTE(review): [TrainingControllerImpl] currently stops immediately when the remaining
     * epoch count is 0 instead of training indefinitely — confirm which behaviour is intended.
     *
     * @param epochs number of epochs that should still be trained.
     */
    fun setEpochs(epochs: Int)

    /**
     * Starts the training process.
     */
    fun start()

    /**
     * Stop any ongoing training. The current epoch will be completed before stopping.
     */
    fun stop()

    /**
     * Stops the training process if it is currently running, and resets the training state.
     */
    fun reset()

    /**
     * Evaluates the model against a single exercise.
     *
     * NOTE(review): the only implementation in this file is still a `TODO()` and throws
     * [NotImplementedError]; the returned value is presumably the loss for [case] — confirm.
     *
     * @param case the exercise to evaluate.
     * @return a numeric result for [case] (presumably the error/loss).
     */
    suspend fun test(case: Exercise): Double
}

// TODO: create separate statistics for each measurement to match cadence/collection characteristics?
/** Key paired with the "loss" measurement in [defaultMeasurements]. */
val lossKey = 100L.key<Float>()

/** Measurements tracked by default: a single intermediate "loss" measurement under [lossKey]. */
val defaultMeasurements: List<Measurement<Float>> =
    listOf(Measurement.Intermediate("loss", lossKey))

/**
 * [TrainingController] that trains two networks side by side on the same generated spiral
 * cases: a plain [SequentialNetwork] and a [ResidualSequentialNetwork]. Training advances in
 * steps of at most [trainingStep] epochs and a new [TrainingState] is published after each step.
 *
 * NOTE(review): [trainingJob] is an unsynchronized `var`; [start] is assumed to be called from
 * a single thread (e.g. the UI thread) — confirm against callers.
 *
 * @param exerciseRepository queried once at construction for the XOR exercises.
 * @param trainingScope scope in which the training coroutine is launched.
 */
class TrainingControllerImpl(
    private val exerciseRepository: ExerciseRepository,
    private val trainingScope: CoroutineScope
): TrainingController {

    // Not read anywhere in this class yet; kept so the repository is still queried at construction.
    private val exercises: List<Exercise> = exerciseRepository.getExercises(Problem.XorProblem(0f, 1))

    private val spiralCases = SpiralExerciseGenerator(Problem.SpiralClassificationProblem(
        2f, 2f, 8f
    )).generate()

    private val hiddenWidth = 8

    // Two trainers over the same spiral cases. Each network has four Logistic dense layers created
    // as (2, hiddenWidth), (hiddenWidth, hiddenWidth), (hiddenWidth, hiddenWidth), (hiddenWidth, 1):
    // first a plain sequential network, then a residual variant configured with RmsNorm.
    private val _state = MutableStateFlow(TrainingState(
        isTraining = false,
        epochsRemaining = 0,
        networks = listOf(
            NetworkTrainingState(
                trainer = SequentialTrainer(
                    network = SequentialNetwork(listOf(
                        DenseLayer.create(2, hiddenWidth, Logistic),
                        DenseLayer.create(hiddenWidth, hiddenWidth, Logistic),
                        DenseLayer.create(hiddenWidth, hiddenWidth, Logistic),
                        DenseLayer.create(hiddenWidth, 1, Logistic)
                    )),
                    cases = spiralCases,
                    lossFunction = MeanSquaredError,
                    optimizer = GradientDescent(learningRate = WarmRestartExponentialLearningRate(0.1f)),
                    statistics = SequentialStatistics(defaultMeasurements)
                ),
                snapshot = Snapshot.empty
            ),

            NetworkTrainingState(
                trainer = SequentialTrainer(
                    network = ResidualSequentialNetwork(
                        listOf(
                            DenseLayer.create(2, hiddenWidth, Logistic),
                            DenseLayer.create(hiddenWidth, hiddenWidth, Logistic),
                            DenseLayer.create(hiddenWidth, hiddenWidth, Logistic),
                            DenseLayer.create(hiddenWidth, 1, Logistic),
                        ), 0, 2, RmsNorm(hiddenWidth)
                    ),
                    cases = spiralCases,
                    lossFunction = MeanSquaredError,
                    optimizer = GradientDescent(learningRate = WarmRestartExponentialLearningRate(0.1f)),
                    statistics = SequentialStatistics(defaultMeasurements)
                ),
                snapshot = Snapshot.empty
            )
        )
    ))

    /** Read-only view of the training state; created once rather than on every access. */
    override val state: StateFlow<TrainingState> = _state.asStateFlow()

    /**
     * Stores [epochs] as the number of epochs still to train.
     *
     * NOTE(review): the interface documents 0 as "train indefinitely", but the training loop in
     * this class stops as soon as the remaining count reaches 0 — confirm which is intended.
     */
    override fun setEpochs(epochs: Int) {
        _state.update { it.copy(epochsRemaining = epochs) }
    }

    /** Handle to the most recently launched training coroutine, if any. */
    private var trainingJob: Job? = null

    /**
     * Marks the controller as training and launches the training loop.
     *
     * Does nothing when training is already in progress, so a running job is never cancelled and
     * relaunched by a redundant call. The new job first cancels and awaits the previously
     * launched one, so at most one loop drives the trainers at a time.
     */
    override fun start() {
        var startedNow = false
        _state.update { current ->
            // The lambda may be re-evaluated under contention; the last attempt decides the flag.
            startedNow = !current.isTraining
            if (startedNow) current.copy(isTraining = true) else current
        }
        if (!startedNow) return

        // Swap the job handle synchronously on the caller so back-to-back start/stop/start calls
        // form a chain instead of racing inside fire-and-forget coroutines.
        val previousJob = trainingJob
        trainingJob = trainingScope.launch {
            previousJob?.cancelAndJoin()
            runTrainingLoop()
        }
    }

    /** Upper bound on the number of epochs trained between two state publications. */
    private val trainingStep = 100

    /**
     * Trains every network in steps until [TrainingState.isTraining] becomes `false`.
     *
     * NOTE(review): the loop itself has no cancellation check; it only reacts to cancellation if
     * `SequentialTrainer.train` suspends cooperatively — confirm, or add an `ensureActive()` call.
     */
    private suspend fun runTrainingLoop() {
        while (true) {
            // Read the state once per step so all decisions below use one consistent snapshot.
            val current = _state.value
            if (!current.isTraining) return

            // Clamp at 0 so a negative epoch count never reaches the trainer or the counters.
            val epochsToTrain = min(current.epochsRemaining, trainingStep).coerceAtLeast(0)

            val trainedNetworks = current.networks.map { network ->
                val trained = network.trainer.train(iterations = epochsToTrain)
                network.copy(
                    elapsedEpochs = network.elapsedEpochs + trained,
                    snapshot = network.trainer.statistics.createSnapshot(),
                )
            }

            publishStep(epochsToTrain, trainedNetworks)
        }
    }

    /**
     * Publishes the outcome of one training step.
     *
     * The counters are derived from the state seen inside the atomic update, not from values
     * read before the step, so a concurrent [setEpochs], [stop] or [reset] is not overwritten
     * with stale data. Training is switched off once no epochs remain.
     *
     * @param epochsTrained number of epochs requested from each trainer in this step.
     * @param trainedNetworks per-network state after the step.
     */
    private fun publishStep(epochsTrained: Int, trainedNetworks: List<NetworkTrainingState>) {
        _state.update { current ->
            val remaining = (current.epochsRemaining - epochsTrained).coerceAtLeast(0)
            current.copy(
                isTraining = current.isTraining && remaining > 0,
                epochsRemaining = remaining,
                epochsElapsed = current.epochsElapsed + epochsTrained,
                networks = trainedNetworks,
            )
        }
    }

    /**
     * Clears [TrainingState.isTraining]; the training loop exits after the step that is
     * currently running (up to [trainingStep] epochs) has been published.
     */
    override fun stop() {
        _state.update { it.copy(isTraining = false) }
    }

    /** Not implemented yet; always throws [NotImplementedError]. */
    override suspend fun test(case: Exercise): Double {
//        // Test the trained network on a single case
//        val output = trainer.network.predict(case.input)
//        val error = trainer.lossFunction.compute(output, case.output)
//        println("Input: ${case.input} -> Error: %.4f Output: $output".format(error))
//        return error
        TODO("Ensure that trainers are working on the same problem set")
    }

    /**
     * Clears [TrainingState.isTraining] and sets [TrainingState.epochsRemaining] to 0.
     *
     * NOTE(review): `epochsElapsed` and the per-network trainers/snapshots are left untouched,
     * although the interface says the training state is reset — confirm whether they should be
     * rebuilt here as well.
     */
    override fun reset() {
        _state.update {
            it.copy(
                isTraining = false,
                epochsRemaining = 0,
            )
        }
    }
}
