package kotlinx.coroutines.flow

import kotlinx.coroutines.testing.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.testing.flow.*
import kotlin.coroutines.*
import kotlin.reflect.*
import kotlin.test.*

/**
 * Tests that emitting from a coroutine other than the collecting one (inside `withContext`, or from a
 * `launch`ed child) is rejected with [IllegalStateException], while constructs designed for concurrent
 * emission (`flowOn`, `channelFlow`, `produce` + `consumeEach`) keep working.
 *
 * Most contract tests run through [runParametrizedTest], so each scenario is exercised both against the
 * [flow] builder and against a hand-written [AbstractFlow].
 */
class FlowInvariantsTest : TestBase() {

    /**
     * Runs [testBody] twice. The first run uses a flow factory backed by the [flow] builder. The second
     * run, after [reset], uses a factory backed by [abstractFlow].
     *
     * Each run's outcome is validated by [check] against [expectedException]. `null` means the run must
     * complete without throwing.
     */
    private fun <T> runParametrizedTest(
        expectedException: KClass<out Throwable>? = null,
        testBody: suspend (flowFactory: (suspend FlowCollector<T>.() -> Unit) -> Flow<T>) -> Unit
    ) = runTest {
        val r1 = runCatching { testBody { flow(it) } }.exceptionOrNull()
        check(r1, expectedException)
        // Restart the expect()/finish() sequence so the second run can reuse the same step numbers.
        reset()

        val r2 = runCatching { testBody { abstractFlow(it) } }.exceptionOrNull()
        check(r2, expectedException)
    }

    /** Wraps [block] in a minimal [AbstractFlow] whose `collectSafely` simply runs [block] on the collector. */
    private fun <T> abstractFlow(block: suspend FlowCollector<T>.() -> Unit): Flow<T> =
        object : AbstractFlow<T>() {
            override suspend fun collectSafely(collector: FlowCollector<T>) {
                collector.block()
            }
        }

    /**
     * Validates the outcome of one parametrized run.
     *
     * - No exception expected: any [exception] is rethrown unchanged.
     * - An exception was expected but none was thrown: the test fails.
     * - An exception of the wrong type was thrown: the test fails. The failure message reports the
     *   actual exception, and the exception is attached as the cause so its stack trace is preserved.
     */
    private fun check(exception: Throwable?, expectedException: KClass<out Throwable>?) {
        when {
            expectedException == null -> if (exception != null) throw exception
            exception == null -> fail("Expected $expectedException, but test completed successfully")
            !expectedException.isInstance(exception) ->
                fail("Expected $expectedException, but got $exception", exception)
        }
    }

    /** Emitting from inside `withContext(NonCancellable)` must be rejected. */
    @Test
    fun testWithContextContract() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        flow {
            withContext(NonCancellable) {
                emit(1)
            }
        }.collect {
            expectUnreached()
        }
    }

    /** Emitting after switching to a different dispatcher via `withContext` must be rejected. */
    @Test
    fun testWithDispatcherContractViolated() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        flow {
            withContext(NamedDispatchers("foo")) {
                emit(1)
            }
        }.collect {
            expectUnreached()
        }
    }

    /** Even changing only the [CoroutineName] around `emit` must be rejected. */
    @Test
    fun testWithNameContractViolated() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        flow {
            withContext(CoroutineName("foo")) {
                emit(1)
            }
        }.collect {
            expectUnreached()
        }
    }

    /**
     * Outer `flowOn` and consumer dispatchers must not override the dispatcher chosen by the `flowOn`
     * closest to the emitting block.
     */
    @Test
    fun testWithContextDoesNotChangeExecution() = runTest {
        val flow = flow {
            emit(NamedDispatchers.name())
        }.flowOn(NamedDispatchers("original"))

        var result = "unknown"
        withContext(NamedDispatchers("misc")) {
            flow
                .flowOn(NamedDispatchers("upstream"))
                .launchIn(this + NamedDispatchers("consumer")) {
                    onEach {
                        result = it
                    }
                }.join()
        }

        assertEquals("original", result)
    }

    /** Emitting from a `launch`ed child (same dispatcher, different Job) must be rejected. */
    @Test
    fun testScopedJob() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        flow { emit(1) }.buffer(EmptyCoroutineContext, flow).collect {
            expect(1)
        }

        finish(2)
    }

    /** Emitting from a `launch`ed child on another dispatcher must be rejected as well. */
    @Test
    fun testScopedJobWithViolation() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        flow { emit(1) }.buffer(Dispatchers.Unconfined, flow).collect {
            expect(1)
        }

        finish(2)
    }

    /**
     * A naive `merge` that emits from a `launch`ed child must fail. This holds even when the child
     * wraps `emit` in its own `coroutineScope`.
     */
    @Test
    fun testMergeViolation() = runParametrizedTest<Int> { flow ->
        fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = flow {
            coroutineScope {
                launch {
                    collect { value -> emit(value) }
                }
                other.collect { value -> emit(value) }
            }
        }

        fun Flow<Int>.trickyMerge(other: Flow<Int>): Flow<Int> = flow {
            coroutineScope {
                launch {
                    collect { value ->
                        coroutineScope { emit(value) }
                    }
                }
                other.collect { value -> emit(value) }
            }
        }

        val flowInstance = flowOf(1)
        assertFailsWith<IllegalStateException> { flowInstance.merge(flowInstance).toList() }
        assertFailsWith<IllegalStateException> { flowInstance.trickyMerge(flowInstance).toList() }
    }

    /** The same merge shapes built on [channelFlow] (using `send` instead of `emit`) are legal. */
    @Test
    fun testNoMergeViolation() = runTest {
        fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = channelFlow {
            launch {
                collect { value -> send(value) }
            }
            other.collect { value -> send(value) }
        }

        fun Flow<Int>.trickyMerge(other: Flow<Int>): Flow<Int> = channelFlow {
            coroutineScope {
                launch {
                    collect { value ->
                        coroutineScope { send(value) }
                    }
                }
                other.collect { value -> send(value) }
            }
        }

        val flow = flowOf(1)
        assertEquals(listOf(1, 1), flow.merge(flow).toList())
        assertEquals(listOf(1, 1), flow.trickyMerge(flow).toList())
    }

    /**
     * Collecting in a `produce` child and emitting from the flow's own coroutine via `consumeEach`
     * is legal.
     */
    @Test
    fun testScopedCoroutineNoViolation() = runParametrizedTest<Int> { flow ->
        fun Flow<Int>.buffer(): Flow<Int> = flow {
            coroutineScope {
                val channel = produce {
                    collect { send(it) }
                }
                channel.consumeEach {
                    emit(it)
                }
            }
        }
        assertEquals(listOf(1, 1), flowOf(1, 1).buffer().toList())
    }

    /**
     * Test-only "buffer" that deliberately emits from the wrong coroutine.
     *
     * It collects the receiver in one `launch`ed child and forwards each value through a [Channel],
     * closing the channel when the upstream completes. A second child, started with [coroutineContext],
     * re-emits the values. Because that `emit` runs inside a child coroutine, the testScopedJob tests
     * expect it to fail with [IllegalStateException].
     */
    private fun Flow<Int>.buffer(
        coroutineContext: CoroutineContext,
        flow: (suspend FlowCollector<Int>.() -> Unit) -> Flow<Int>
    ): Flow<Int> = flow {
        coroutineScope {
            val channel = Channel<Int>()
            launch {
                collect { value -> channel.send(value) }
                channel.close()
            }
            launch(coroutineContext) {
                for (i in channel) {
                    emit(i)
                }
            }
        }
    }

    @Test
    fun testEmptyCoroutineContextMap() = runTest {
        emptyContextTest {
            map {
                expect(it)
                it + 1
            }
        }
    }

    @Test
    fun testEmptyCoroutineContextTransform() = runTest {
        emptyContextTest {
            transform {
                expect(it)
                emit(it + 1)
            }
        }
    }

    @Test
    fun testEmptyCoroutineContextTransformWhile() = runTest {
        emptyContextTest {
            transformWhile {
                expect(it)
                emit(it + 1)
                true
            }
        }
    }

    /** Switching dispatchers around `emit` inside [transform] must report a flow-invariant violation. */
    @Test
    fun testEmptyCoroutineContextViolationTransform() = runTest {
        try {
            emptyContextTest {
                transform {
                    expect(it)
                    withContext(Dispatchers.Unconfined) {
                        emit(it + 1)
                    }
                }
            }
            expectUnreached()
        } catch (e: IllegalStateException) {
            // orEmpty(): a null message becomes a readable assertion failure instead of an NPE.
            assertTrue(e.message.orEmpty().contains("Flow invariant is violated"), "But had: ${e.message}")
            finish(2)
        }
    }

    /** Switching dispatchers around `emit` inside [transformWhile] must report a flow-invariant violation. */
    @Test
    fun testEmptyCoroutineContextViolationTransformWhile() = runTest {
        try {
            emptyContextTest {
                transformWhile {
                    expect(it)
                    withContext(Dispatchers.Unconfined) {
                        emit(it + 1)
                    }
                    true
                }
            }
            expectUnreached()
        } catch (e: IllegalStateException) {
            assertTrue(e.message.orEmpty().contains("Flow invariant is violated"), "But had: ${e.message}")
            finish(2)
        }
    }

    /**
     * Collects `channelFlow { send(1) }` transformed by [block] inside [withEmptyContext].
     *
     * Asserts that the last value reaching the collector is 2, then finishes the expect() sequence at
     * step 3. [block] is expected to call `expect(1)` on the upstream value and produce `value + 1`.
     * NOTE(review): [withEmptyContext] is defined in the shared test utilities. Presumably it runs its
     * block with an empty coroutine context; confirm.
     */
    private suspend fun emptyContextTest(block: Flow<Int>.() -> Flow<Int>) {
        suspend fun collector(): Int {
            var result: Int = -1
            channelFlow {
                send(1)
            }.block()
                .collect {
                    expect(it)
                    result = it
                }
            return result
        }

        val result = withEmptyContext { collector() }
        assertEquals(2, result)
        finish(3)
    }
}