package kotlinx.coroutines.knit

import kotlinx.coroutines.*
import kotlinx.coroutines.internal.*
import kotlinx.coroutines.scheduling.*
import kotlinx.coroutines.testing.*
import kotlinx.knit.test.*
import java.util.concurrent.*
import kotlin.test.*

/**
 * Runs [block]. If it throws, dumps the failure to stdout and rethrows it; this makes failed tests easier to debug.
 *
 * @param name appended verbatim to the `--- Failed test` banner.
 * @return whatever [block] returns.
 */
private inline fun <T> outputException(name: String, block: () -> T): T =
    try {
        block()
    } catch (e: Throwable) {
        println("--- Failed test$name")
        e.printStackTrace(System.out)
        throw e
    }

/** Timeout, in milliseconds, passed to each shutdown step performed by [test] (5 sec at most to wait). */
private const val SHUTDOWN_TIMEOUT = 5000L

/**
 * Value of the `guide.tests.sout` setting (read via `systemProp`, default `false`).
 *
 * It is passed to `captureOutput` as `stdoutEnabled`. When it is `false`, [verify] prints the captured
 * output only after a verification fails.
 */
private val OUT_ENABLED = systemProp("guide.tests.sout", false)

/**
 * Runs a guide example [block] under the test harness and returns the output captured while it ran.
 *
 * Sequence of steps:
 * - Calls `DefaultScheduler.usePrivateScheduler()`, `DefaultExecutor.shutdownForTests`, `resetCoroutineId()`,
 *   and records the current threads.
 * - Runs [block] inside `withVirtualTimeSource`. The block must return [Unit].
 * - Always logs `--- shutting down`. It then shuts down `DefaultScheduler`, every dispatcher pool found by
 *   [shutdownDispatcherPools], and finally `DefaultExecutor`.
 * - Calls `checkTestThreads` only if the block and the shutdown completed without throwing.
 * - Always calls `DefaultScheduler.restore()` at the very end.
 *
 * Any failure is reported on stdout by [outputException] and rethrown.
 *
 * @param name test name, forwarded to `captureOutput` and to the failure banner.
 * @return the value produced by `captureOutput`. NOTE(review): presumably one entry per captured output
 *   line; confirm against `captureOutput`.
 * @throws IllegalArgumentException if [block] returns something other than [Unit].
 */
fun <R> test(name: String, block: () -> R): List<String> = outputException(name) {
    try {
        captureOutput(name, stdoutEnabled = OUT_ENABLED) { log ->
            DefaultScheduler.usePrivateScheduler()
            DefaultExecutor.shutdownForTests(SHUTDOWN_TIMEOUT)
            resetCoroutineId()
            val threadsBefore = currentThreads()
            try {
                withVirtualTimeSource(log) {
                    val result = block()
                    // the block stands in for the example's `main`, which must not return a value
                    require(result === Unit) { "Test 'main' shall return Unit" }
                }
            } finally {
                // the shutdown
                log.println("--- shutting down")
                DefaultScheduler.shutdown(SHUTDOWN_TIMEOUT)
                shutdownDispatcherPools(SHUTDOWN_TIMEOUT)
                DefaultExecutor.shutdownForTests(SHUTDOWN_TIMEOUT) // the last man standing -- cleanup all pending tasks
            }
            // reached only when the block and the shutdown above completed without throwing
            checkTestThreads(threadsBefore)
        }
    } finally {
        DefaultScheduler.restore()
    }
}

/**
 * Shuts down the executor behind every live [PoolThread]'s dispatcher.
 *
 * For each executor, waits up to [timeout] milliseconds for termination. Any tasks it returns from
 * `shutdownNow()` (tasks that never started) are re-enqueued on `DefaultExecutor`, which [test] shuts
 * down afterwards.
 *
 * Threads are discovered with [Thread.enumerate] into an array sized by [Thread.activeCount]. Threads
 * started after that count was sampled, or outside the current thread group, are not seen.
 *
 * @param timeout per-executor termination wait, in milliseconds.
 */
private fun shutdownDispatcherPools(timeout: Long) {
    val threads = arrayOfNulls<Thread>(Thread.activeCount())
    val n = Thread.enumerate(threads)
    for (i in 0 until n) {
        val thread = threads[i]
        if (thread is PoolThread) {
            (thread.dispatcher.executor as ExecutorService).apply {
                shutdown()
                awaitTermination(timeout, TimeUnit.MILLISECONDS)
                shutdownNow().forEach { DefaultExecutor.enqueue(it) }
            }
        }
    }
}

/** How expected and actual output lines are normalized before they are compared. */
enum class SanitizeMode {
    /** Lines are compared verbatim. */
    NONE,

    /** Every ` <digits> ms` is replaced by ` xxx ms`. */
    ARBITRARY_TIME,

    /**
     * Worker-thread names are collapsed to stable names, and `Test` / `Test worker` becomes `main`.
     * Every `@` is removed together with the run of lowercase hex digits that follows it.
     */
    FLEXIBLE_THREAD
}

/**
 * Replacements applied in order for [SanitizeMode.ARBITRARY_TIME].
 * Compiled once instead of on every sanitized line.
 */
private val ARBITRARY_TIME_RULES: List<Pair<Regex, String>> = listOf(
    Regex(" [0-9]+ ms") to " xxx ms"
)

/**
 * Replacements applied in order for [SanitizeMode.FLEXIBLE_THREAD].
 * Compiled once instead of on every sanitized line.
 */
private val FLEXIBLE_THREAD_RULES: List<Pair<Regex, String>> = listOf(
    Regex("ForkJoinPool\\.commonPool-worker-[0-9]+") to "DefaultDispatcher",
    Regex("ForkJoinPool-[0-9]+-worker-[0-9]+") to "DefaultDispatcher",
    Regex("CommonPool-worker-[0-9]+") to "DefaultDispatcher",
    Regex("DefaultDispatcher-worker-[0-9]+") to "DefaultDispatcher",
    Regex("RxComputationThreadPool-[0-9]+") to "RxComputationThreadPool",
    Regex("Test( worker)?") to "main",
    Regex("@[0-9a-f]+") to "" // drop hex address
)

/**
 * Normalizes one output line [s] according to [mode].
 *
 * The replacements are applied in their listed order, each one to the result of the previous one.
 */
private fun sanitize(s: String, mode: SanitizeMode): String {
    val rules: List<Pair<Regex, String>> = when (mode) {
        SanitizeMode.NONE -> emptyList()
        SanitizeMode.ARBITRARY_TIME -> ARBITRARY_TIME_RULES
        SanitizeMode.FLEXIBLE_THREAD -> FLEXIBLE_THREAD_RULES
    }
    return rules.fold(s) { acc, (regex, replacement) -> regex.replace(acc, replacement) }
}

/**
 * Asserts that the first `min(size, expected.size)` lines equal [expected] line by line, after both sides
 * are sanitized with [mode]. Does not check the line counts.
 */
private fun List<String>.verifyCommonLines(expected: Array<out String>, mode: SanitizeMode = SanitizeMode.NONE) {
    val n = minOf(size, expected.size)
    for (i in 0 until n) {
        val exp = sanitize(expected[i], mode)
        val act = sanitize(get(i), mode)
        assertEquals(exp, act, "Line ${i + 1}")
    }
}

/**
 * Fails with [IllegalStateException] unless this list has exactly `expected.size` lines.
 * When there are too many lines, the message names the first unexpected one.
 */
private fun List<String>.checkEqualNumberOfLines(expected: Array<out String>) {
    when {
        size > expected.size ->
            error("Expected ${expected.size} lines, but found $size. Unexpected line '${get(expected.size)}'")
        size < expected.size ->
            error("Expected ${expected.size} lines, but found $size")
    }
}

/**
 * Asserts that each of the first `min(size, expected.size)` lines starts with the corresponding
 * [expected] line. Both sides are sanitized with [SanitizeMode.FLEXIBLE_THREAD] first.
 * Does not check the line counts and prints nothing on failure.
 */
private fun List<String>.verifyCommonLinePrefixes(expected: Array<out String>) {
    val n = minOf(size, expected.size)
    for (i in 0 until n) {
        val exp = sanitize(expected[i], SanitizeMode.FLEXIBLE_THREAD)
        val act = sanitize(get(i), SanitizeMode.FLEXIBLE_THREAD)
        // an actual line shorter than the expected one is compared whole, so it fails as it should
        assertEquals(exp, act.take(exp.length), "Line ${i + 1}")
    }
}

/** Asserts that the output equals [expected] exactly, line by line and in line count. */
fun List<String>.verifyLines(vararg expected: String) = verify {
    verifyCommonLines(expected)
    checkEqualNumberOfLines(expected)
}

/** Asserts that the output begins with the [expected] lines (verbatim); further lines are allowed. */
fun List<String>.verifyLinesStartWith(vararg expected: String) = verify {
    verifyCommonLines(expected)
    assertTrue(expected.size <= size, "Number of lines")
}

/** Like [verifyLines], but ` <digits> ms` durations on both sides are treated as equal. */
fun List<String>.verifyLinesArbitraryTime(vararg expected: String) = verify {
    verifyCommonLines(expected, SanitizeMode.ARBITRARY_TIME)
    checkEqualNumberOfLines(expected)
}

/** Like [verifyLines], but both sides are normalized with [SanitizeMode.FLEXIBLE_THREAD]. */
fun List<String>.verifyLinesFlexibleThread(vararg expected: String) = verify {
    verifyCommonLines(expected, SanitizeMode.FLEXIBLE_THREAD)
    checkEqualNumberOfLines(expected)
}

/**
 * Order-insensitive [verifyLinesStart]: both the output and [expected] are sorted, then compared by
 * prefix, and the line counts must match.
 *
 * On failure, the output is printed once, in its original order.
 */
fun List<String>.verifyLinesStartUnordered(vararg expected: String) = verify {
    val actualSorted = sorted()
    val expectedSorted = expected.sorted().toTypedArray()
    actualSorted.verifyCommonLinePrefixes(expectedSorted)
    actualSorted.checkEqualNumberOfLines(expectedSorted)
}

/**
 * Removes stack-frame blocks from the output.
 *
 * A block starts at a line beginning with `\tat`. It swallows following lines that start with a tab or
 * with `Caused by: `, and ends at (and keeps) the first line that is neither.
 */
private fun List<String>.withoutStackFrames(): List<String> {
    val kept = ArrayList<String>()
    var insideFrames = false
    for (line in this) {
        when {
            !insideFrames && line.startsWith("\tat") -> insideFrames = true
            insideFrames && !line.startsWith("\t") && !line.startsWith("Caused by: ") -> insideFrames = false
        }
        if (!insideFrames) kept.add(line)
    }
    return kept
}

/**
 * Asserts that the output, with stack frames removed, matches [expected] line by line.
 * Both sides are normalized with [SanitizeMode.FLEXIBLE_THREAD].
 *
 * NOTE(review): unlike the other verifiers, this does not check the line count. Extra or missing
 * trailing lines go unreported. Confirm whether that is intentional.
 */
fun List<String>.verifyExceptions(vararg expected: String) = verify {
    withoutStackFrames().verifyCommonLines(expected, SanitizeMode.FLEXIBLE_THREAD)
}

/**
 * Asserts that each output line starts with the corresponding [expected] line, after
 * [SanitizeMode.FLEXIBLE_THREAD] normalization. The line counts must match exactly.
 */
fun List<String>.verifyLinesStart(vararg expected: String) = verify {
    verifyCommonLinePrefixes(expected)
    checkEqualNumberOfLines(expected)
}

/**
 * Runs [verification].
 *
 * If it throws and the output was not already echoed ([OUT_ENABLED] is `false`), this list is first
 * printed to stdout. The original throwable is always rethrown.
 */
private inline fun List<String>.verify(verification: () -> Unit) {
    try {
        verification()
    } catch (t: Throwable) {
        if (!OUT_ENABLED) {
            println("Printing [delayed] test output")
            forEach { println(it) }
        }
        throw t
    }
}