Не подтверждена Коммит bbf42373 создал по автору Aleksei Tirman's avatar Aleksei Tirman Зафиксировано автором GitHub
Просмотр файлов

KTOR-8312 Make the clearToken behavior consistent (#4735)

владелец fc0fc4f7
......@@ -4,82 +4,103 @@
package io.ktor.client.plugins.auth.providers
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
import kotlin.concurrent.Volatile
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext
internal class AuthTokenHolder<T>(
private val loadTokens: suspend () -> T?
) {
private val refreshTokensDeferred = atomic<CompletableDeferred<T?>?>(null)
private val loadTokensDeferred = atomic<CompletableDeferred<T?>?>(null)
internal class AuthTokenHolder<T>(private val loadTokens: suspend () -> T?) {
internal fun clearToken() {
loadTokensDeferred.value = null
refreshTokensDeferred.value = null
}
@Volatile private var value: T? = null
internal suspend fun loadToken(): T? {
var deferred: CompletableDeferred<T?>?
lateinit var newDeferred: CompletableDeferred<T?>
while (true) {
deferred = loadTokensDeferred.value
val newValue = deferred ?: CompletableDeferred()
if (loadTokensDeferred.compareAndSet(deferred, newValue)) {
newDeferred = newValue
break
}
}
@Volatile private var isLoadRequest = false
// if there's already a pending loadTokens(), just wait for it to complete
if (deferred != null) {
return deferred.await()
}
private val mutex = Mutex()
/**
* Exist only for testing
*/
internal fun get(): T? = value
try {
val newTokens = loadTokens()
/**
* Returns a cached value if any. Otherwise, computes a value using [loadTokens] and caches it.
* Only one [loadToken] call can be executed at a time. The other calls are suspended and have no effect on the cached value.
*/
internal suspend fun loadToken(): T? {
if (value != null) return value // Hot path
val prevValue = value
// [loadTokensDeferred.value] could be null by now (if clearToken() was called while
// suspended), which is why we are using [newDeferred] to complete the suspending callback.
newDeferred.complete(newTokens)
return if (coroutineContext[SetTokenContext] != null) { // Already locked by setToken
value = loadTokens()
value
} else {
mutex.withLock {
isLoadRequest = true
try {
if (prevValue == value) { // Raced first
value = loadTokens()
}
} finally {
isLoadRequest = false
}
return newTokens
} catch (cause: Throwable) {
newDeferred.completeExceptionally(cause)
loadTokensDeferred.compareAndSet(newDeferred, null)
throw cause
value
}
}
}
private class SetTokenContext : CoroutineContext.Element {
override val key: CoroutineContext.Key<*>
get() = SetTokenContext
companion object : CoroutineContext.Key<SetTokenContext>
}
private val setTokenMarker = SetTokenContext()
/**
* Replaces the current cached value with one computed with [block].
* Only one [loadToken] or [setToken] call can be executed at a time,
* although the resumed [setToken] call recomputes the value cached by [loadToken].
*/
internal suspend fun setToken(block: suspend () -> T?): T? {
var deferred: CompletableDeferred<T?>?
lateinit var newDeferred: CompletableDeferred<T?>
while (true) {
deferred = refreshTokensDeferred.value
val newValue = deferred ?: CompletableDeferred()
if (refreshTokensDeferred.compareAndSet(deferred, newValue)) {
newDeferred = newValue
break
val prevValue = value
val lockedByLoad = isLoadRequest
return mutex.withLock {
if (prevValue == value || lockedByLoad) { // Raced first
val newValue = withContext(coroutineContext + setTokenMarker) {
block()
}
if (newValue != null) {
value = newValue
}
}
value
}
}
try {
val newToken = if (deferred == null) {
val newTokens = block()
// [refreshTokensDeferred.value] could be null by now (if clearToken() was called while
// suspended), which is why we are using [newDeferred] to complete the suspending callback.
newDeferred.complete(newTokens)
refreshTokensDeferred.value = null
newTokens
} else {
deferred.await()
/**
* Resets the cached value.
*/
@OptIn(DelicateCoroutinesApi::class)
internal fun clearToken() {
if (mutex.tryLock()) {
value = null
mutex.unlock()
} else {
GlobalScope.launch {
mutex.withLock {
value = null
}
}
loadTokensDeferred.value = CompletableDeferred(newToken)
return newToken
} catch (cause: Throwable) {
newDeferred.completeExceptionally(cause)
refreshTokensDeferred.compareAndSet(newDeferred, null)
throw cause
}
}
}
......@@ -7,39 +7,41 @@ package io.ktor.client.plugins.auth
import io.ktor.client.plugins.auth.providers.*
import kotlinx.coroutines.*
import kotlinx.coroutines.test.runTest
import kotlin.coroutines.CoroutineContext
import kotlin.test.*
class AuthTokenHolderTest {
@Test
@OptIn(DelicateCoroutinesApi::class)
fun testSetTokenCalledOnce() = runTest {
val holder = AuthTokenHolder<BearerTokens> { TODO() }
fun testOnlyOneSetTokenCallComputesBlock() = runTest {
val holder = AuthTokenHolder<Int> { fail() }
val monitor = Job()
var firstExecuted = false
var secondExecuted = false
var firstCalled = false
val first = GlobalScope.launch(Dispatchers.Unconfined) {
holder.setToken {
firstExecuted = true
monitor.join()
BearerTokens("1", "2")
firstCalled = true
delay(100)
1
}
}
var secondCalled = false
val second = GlobalScope.launch(Dispatchers.Unconfined) {
delay(50)
holder.setToken {
secondExecuted = true
BearerTokens("1", "2")
secondCalled = true
2
}
}
monitor.complete()
first.join()
second.join()
assertTrue(firstExecuted)
assertFalse(secondExecuted)
val token = holder.loadToken()
assertEquals(token, 1)
assertTrue { firstCalled }
assertFalse { secondCalled }
}
@Test
......@@ -77,7 +79,7 @@ class AuthTokenHolderTest {
}
monitor.join()
BearerTokens("1", "2")
1
}
val first = GlobalScope.async(Dispatchers.Unconfined) {
......@@ -90,9 +92,10 @@ class AuthTokenHolderTest {
}
monitor.complete()
assertNotNull(first.await())
assertNotNull(second.await())
assertEquals(1, first.await())
second.await()
assertTrue(clearTokenCalled)
assertNull(holder.get())
}
@Test
......@@ -101,7 +104,7 @@ class AuthTokenHolderTest {
val monitor = Job()
var clearTokenCalled = false
val holder = AuthTokenHolder<BearerTokens> {
val holder = AuthTokenHolder<Int> {
fail("loadTokens argument function shouldn't be invoked")
}
......@@ -112,7 +115,7 @@ class AuthTokenHolderTest {
delay(10)
}
monitor.join()
BearerTokens("1", "2")
1
}
}
......@@ -122,9 +125,10 @@ class AuthTokenHolderTest {
}
monitor.complete()
assertNotNull(first.await())
assertNotNull(second.await())
assertEquals(1, first.await())
second.await()
assertTrue(clearTokenCalled)
assertNull(holder.get())
}
@Test
......@@ -149,4 +153,139 @@ class AuthTokenHolderTest {
assertFailsWith<IllegalStateException> { holder.setToken { throw IllegalStateException("First call") } }
assertEquals("token", holder.setToken { "token" })
}
internal class MyContext(val value: Int) : CoroutineContext.Element {
override val key: CoroutineContext.Key<*>
get() = MyContext
companion object : CoroutineContext.Key<MyContext>
}
@OptIn(DelicateCoroutinesApi::class)
@Test
fun firstLoadTokenCallComputesBlockAndSetsValue() = runTest {
val holder = AuthTokenHolder {
coroutineScope {
val context = coroutineContext[MyContext]
assertNotNull(context)
context.value
}
}
val first = GlobalScope.async(Dispatchers.Unconfined) {
delay(50)
withContext(MyContext(1)) {
holder.loadToken()
}
}
val second = GlobalScope.async(Dispatchers.Unconfined) {
withContext(MyContext(2)) {
holder.loadToken()
}
}
assertEquals(2, first.await())
assertEquals(2, second.await())
assertEquals(2, holder.get())
assertEquals(2, holder.loadToken())
}
@OptIn(DelicateCoroutinesApi::class)
@Test
fun firstSetTokenCallComputesBlockAndSetsValue() = runTest {
val holder = AuthTokenHolder<Int> {
fail()
}
val first = GlobalScope.async(Dispatchers.Unconfined) {
delay(50)
holder.setToken {
1
}
}
val second = GlobalScope.async(Dispatchers.Unconfined) {
holder.setToken {
delay(100)
2
}
}
assertEquals(2, first.await())
assertEquals(2, second.await())
assertEquals(2, holder.get())
assertEquals(2, holder.loadToken())
}
@Test
@OptIn(DelicateCoroutinesApi::class)
fun testClearCoroutineResetsCachedValue() = runTest {
val holder = AuthTokenHolder {
delay(200)
1
}
val loadToken = GlobalScope.async(Dispatchers.Unconfined) {
holder.loadToken()
}
val setToken = GlobalScope.async(Dispatchers.Unconfined) {
delay(50)
holder.setToken {
delay(100)
2
}
}
val clear = GlobalScope.async(Dispatchers.Unconfined) {
delay(100)
holder.clearToken()
}
assertEquals(1, loadToken.await())
assertEquals(2, setToken.await())
clear.await()
assertNull(holder.get())
}
@Test
@OptIn(DelicateCoroutinesApi::class)
fun lockedSetTokenByLoadTokenSetsValue() = runTest {
val holder = AuthTokenHolder {
delay(200)
1
}
val loadToken = GlobalScope.async(Dispatchers.Unconfined) {
holder.loadToken()
}
val setToken = GlobalScope.async(Dispatchers.Unconfined) {
delay(100)
holder.setToken {
2
}
}
assertEquals(1, loadToken.await())
assertEquals(2, setToken.await())
assertEquals(2, holder.loadToken())
}
@Test
@OptIn(DelicateCoroutinesApi::class)
fun loadTokensCanBeCalledInSetTokenBlock() = runTest {
val holder = AuthTokenHolder {
1
}
val setToken = GlobalScope.async(Dispatchers.Unconfined) {
holder.setToken {
1 + holder.loadToken()!!
}
}
assertEquals(2, setToken.await())
assertEquals(2, holder.loadToken())
}
}
Поддерживает Markdown
0% или .
You are about to add 0 people to the discussion. Proceed with caution.
Сначала завершите редактирование этого сообщения!
Пожалуйста, зарегистрируйтесь или чтобы прокомментировать