package k.parallels

import k.common.*
import java.lang.Thread.*
import java.util.concurrent.atomic.AtomicReference

/** Delay, in milliseconds, that [waitFor] sleeps between successive checks of its condition. */
const val SMALL_PAUSE = 100L

/**
 * Polls [code] until it returns `true`.
 *
 * The condition is evaluated once immediately. While it is false, the calling thread sleeps
 * [SMALL_PAUSE] milliseconds and then consults a `Timer` created at the start of the call.
 *
 * @param timeOut forwarded to `Timer.checkTimeOut` after every pause. Presumably the maximum
 *   time to wait before that call throws; TODO confirm against `Timer`.
 * @param message forwarded to `Timer.checkTimeOut` together with [timeOut].
 * @param code the condition to poll; it is invoked on the calling thread.
 */
fun waitFor(timeOut : Duration, message : String, code : () -> Boolean) {
    val stopwatch = Timer()

    while (true) {
        if (code())
            return

        sleep(SMALL_PAUSE)
        stopwatch.checkTimeOut(timeOut, message)
    }
}

/**
 * Parallel execution: runs [code] for every element of this iterable, each on its own
 * virtual thread, and blocks until all of them have finished.
 *
 * @param code the transformation applied to each element; invocations run concurrently.
 * @return the values produced by [code], in the same order as the elements of this iterable.
 * @throws Throwable the first failure raised by any invocation of [code], rethrown after every
 *   thread has finished. Failures from other invocations are attached to it as suppressed exceptions.
 */
infix fun <T, R> Iterable<T>.parallel(code : (item : T) -> R) : MutableList<R> {
    val items = toList()
    val error = AtomicReference<Throwable?>(null)
    // One slot per element. Each thread writes only its own index, so no lock is needed.
    // Thread.join() below guarantees every write is visible to this thread afterwards.
    val slots = arrayOfNulls<Any>(items.size)

    items.mapIndexed { index, item ->
        startVirtualThread {
            try {
                slots[index] = code(item)
            }
            catch (e : Throwable) {
                // Keep the first failure as the primary one, but do not lose the others.
                if (!error.compareAndSet(null, e)) {
                    val first = error.get()
                    if (first !== e)
                        first?.addSuppressed(e)
                }
            }
        }
    }.forEach { it.join() }

    error.get()?.let {
        throw it
    }

    // Every slot was filled by a successful invocation of code, so each value really is an R.
    @Suppress("UNCHECKED_CAST")
    val results : MutableList<R> = slots.mapTo(ArrayList(slots.size)) { it as R }
    return results
}

/**
 * Runs [code] once for each integer from 1 up to and including this value.
 * It does this by delegating to the `Iterable.parallel` extension.
 *
 * The range is empty when this value is less than 1, so [code] is then never invoked.
 *
 * @param code the action to run for each number in the range.
 */
infix fun Int.parallel(code : (item : Int) -> Unit) {
    val numbers = IntRange(start = 1, endInclusive = this)
    numbers.parallel(code)
}

/**
 * Runs [code] for every element of this iterable, each on its own virtual thread,
 * and blocks until all of them have finished.
 *
 * @param code the action applied to each element; invocations run concurrently.
 * @throws Throwable the first failure raised by any invocation of [code], rethrown after every
 *   thread has finished. Failures from other invocations are attached to it as suppressed exceptions.
 */
infix fun <T> Iterable<T>.parallel(code : (item : T) -> Unit) {
    val error = AtomicReference<Throwable?>(null)

    map {
        startVirtualThread {
            try {
                code(it)
            }
            catch (e : Throwable) {
                // Keep the first failure as the primary one, but do not lose the others.
                if (!error.compareAndSet(null, e)) {
                    val first = error.get()
                    if (first !== e)
                        first?.addSuppressed(e)
                }
            }
        }
    }.forEach { it.join() }

    error.get()?.let {
        throw it
    }
}

/**
 * Parallel execution with pool: processes every element of this iterable on `pool.count`
 * worker virtual threads. Each worker reads `pool.resource` once and reuses it for every
 * element it handles.
 *
 * Workers take elements from a shared iterator inside a [Sync] block, presumably a
 * mutual-exclusion lock (TODO confirm). Each element is therefore handed to a single worker.
 * Nothing is started when `pool.count` is zero or negative.
 *
 * @param pool supplies the worker count and, via `pool.resource`, the resource each worker uses.
 * @param code invoked as `code(item, resource)`; runs concurrently across workers.
 * @throws Throwable the first failure from any worker, including one raised while obtaining its
 *   resource. It is rethrown after all workers have finished, and later failures are attached as
 *   suppressed exceptions. A failed worker stops taking elements, but the others keep draining
 *   the iterator.
 */
fun <I, T> Iterable<I>.parallel(pool : Pool<T>, code : (item : I, resource : T) -> Unit) {
    val iterator = iterator()
    val threads = mutableListOf<Thread>()
    val error = AtomicReference<Throwable?>(null)
    val sync = Sync()

    repeat(pool.count) {
        threads.add(startVirtualThread {
            try {
                val resource = pool.resource

                while (true) {
                    val item = sync {
                        if (!iterator.hasNext())
                            return@startVirtualThread

                        iterator.next()
                    }

                    code(item, resource)
                }
            }
            catch (e : Throwable) {
                // Previously an exception here only killed this worker and the call still
                // returned normally. Record it so it is propagated like in the other parallel helpers.
                if (!error.compareAndSet(null, e)) {
                    val first = error.get()
                    if (first !== e)
                        first?.addSuppressed(e)
                }
            }
        })
    }

    threads.forEach { it.join() }

    error.get()?.let {
        throw it
    }
}