package kotlinx.coroutines.internal

import kotlinx.coroutines.*
import kotlin.coroutines.*

// Marker returned by updateThreadContext when the context has no ThreadContextElements;
// restoreThreadContext recognizes it by identity and does nothing.
@JvmField
internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS")

// Used when there are >= 2 active elements in the context.
// Stores each ThreadContextElement next to the value returned by its updateThreadContext call,
// so that restore can later hand every saved value back to the element that produced it.
@Suppress("UNCHECKED_CAST")
private class ThreadState(@JvmField val context: CoroutineContext, n: Int) {
    // values[k] is the old state that belongs to elements[k]
    private val values = arrayOfNulls<Any>(n)
    private val elements = arrayOfNulls<ThreadContextElement<Any?>>(n)
    private var i = 0 // index of the next free slot in both arrays

    // Records an element together with the old state it returned from updateThreadContext.
    fun append(element: ThreadContextElement<*>, value: Any?) {
        values[i] = value
        // unchecked: the element's state type is erased to Any? for storage
        elements[i++] = element as ThreadContextElement<Any?>
    }

    // Restores every recorded element, walking the arrays in reverse append order.
    // The `!!` relies on all n slots having been filled by append before restore is called.
    fun restore(context: CoroutineContext) {
        for (i in elements.indices.reversed()) {
            elements[i]!!.restoreThreadContext(context, values[i])
        }
    }
}

// Counts ThreadContextElements in the context
// Any? here is Int | ThreadContextElement (when count is one)
// Folding from 0 yields: 0 (none), the element itself (exactly one), or an Int >= 2 (several).
private val countAll =
    fun (countOrElement: Any?, element: CoroutineContext.Element): Any? {
        if (element is ThreadContextElement<*>) {
            // a non-Int accumulator means one element was already seen, i.e. the count so far is 1
            val inCount = countOrElement as? Int ?: 1
            return if (inCount == 0) element else inCount + 1
        }
        return countOrElement
    }

// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one
private val findOne =
    fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? {
        if (found != null) return found
        return element as? ThreadContextElement<*>
    }

// Updates state for ThreadContextElements in the context using the given ThreadState
// Each ThreadContextElement is updated and its returned old value is appended to the state.
private val updateState =
    fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
        if (element is ThreadContextElement<*>) {
            state.append(element, element.updateThreadContext(state.context))
        }
        return state
    }

// Scans the context once and returns the countAll result:
// 0, the single ThreadContextElement, or the Int number of elements.
// `!!` is safe here: the fold starts from 0 and countAll never turns a non-null accumulator into null.
internal actual fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!!
// countOrElement is pre-cached in dispatched continuation
// returns NO_THREAD_ELEMENTS if the context does not have any ThreadContextElements
// A null countOrElement is computed on the spot via threadContextElements(context).
// The returned value is the "old state" that must later be passed to restoreThreadContext.
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
    @Suppress("NAME_SHADOWING")
    val countOrElement = countOrElement ?: threadContextElements(context)
    @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
    return when {
        countOrElement === 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements
        //             ^^^ identity comparison for speed, we know zero always has the same identity
        countOrElement is Int -> {
            // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
            context.fold(ThreadState(context, countOrElement), updateState)
        }
        else -> {
            // fast path for one ThreadContextElement (no allocations, no additional context scan)
            @Suppress("UNCHECKED_CAST")
            val element = countOrElement as ThreadContextElement<Any?>
            element.updateThreadContext(context)
        }
    }
}

// Undoes updateThreadContext; oldState must be the value that updateThreadContext returned for this context.
internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
    when {
        oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements
        oldState is ThreadState -> {
            // slow path with multiple stored ThreadContextElements
            oldState.restore(context)
        }
        else -> {
            // fast path for one ThreadContextElement, but need to find it
            // the cast fails if the context holds no ThreadContextElement at all
            @Suppress("UNCHECKED_CAST")
            val element = context.fold(null, findOne) as ThreadContextElement<Any?>
            element.restoreThreadContext(context, oldState)
        }
    }
}

// top-level data class for a nicer out-of-the-box toString representation and class name
// Being a data class, two keys wrapping the same ThreadLocal are equal by value (==).
@PublishedApi
internal data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key<ThreadLocalElement<*>>

// Context element that sets the given ThreadLocal to `value` on updateThreadContext
// and puts the previously stored value back on restoreThreadContext.
internal class ThreadLocalElement<T>(
    private val value: T,
    private val threadLocal: ThreadLocal<T>
) : ThreadContextElement<T> {
    override val key: CoroutineContext.Key<*> = ThreadLocalKey(threadLocal)

    // Installs `value` into the thread-local and returns what was stored there before.
    override fun updateThreadContext(context: CoroutineContext): T {
        val oldState = threadLocal.get()
        threadLocal.set(value)
        return oldState
    }

    // Writes the previously saved value back into the thread-local.
    override fun restoreThreadContext(context: CoroutineContext, oldState: T) {
        threadLocal.set(oldState)
    }

    // this method is overridden to perform value comparison (==) on key
    override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext {
        return if (this.key == key) EmptyCoroutineContext else this
    }

    // this method is overridden to perform value comparison (==) on key
    public override operator fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? =
        @Suppress("UNCHECKED_CAST")
        if (this.key == key) this as E else null

    override fun toString(): String = "ThreadLocal(value=$value, threadLocal = $threadLocal)"
}