Не подтверждена Коммит dab18c08 создал по автору Mariia Skripchenko's avatar Mariia Skripchenko Зафиксировано автором GitHub
Просмотр файлов

KTOR-5216 Parse header with multiple challenges (#3277)

Parse header with multiple challenges
владелец bfde300e
......@@ -152,6 +152,41 @@ internal fun Application.authTestServer() {
call.respond("OK")
}
}
route("multiple") {
get("header") {
val token = call.request.headers[HttpHeaders.Authorization]
if (token.isNullOrEmpty() || token.contains("Invalid")) {
call.response.header(
HttpHeaders.WWWAuthenticate,
"Basic realm=\"TestServer\", charset=UTF-8, Digest, Bearer realm=\"my-server\""
)
call.respond(HttpStatusCode.Unauthorized)
return@get
}
call.respond("OK")
}
get("headers") {
val token = call.request.headers[HttpHeaders.Authorization]
if (token.isNullOrEmpty() || token.contains("Invalid")) {
call.response.header(
HttpHeaders.WWWAuthenticate,
"Basic realm=\"TestServer\", charset=UTF-8, Digest"
)
call.response.header(
HttpHeaders.WWWAuthenticate,
"Bearer realm=\"my-server\""
)
call.respond(HttpStatusCode.Unauthorized)
return@get
}
call.respond("OK")
}
}
}
}
}
......
......@@ -54,14 +54,26 @@ public class Auth private constructor(
val candidateProviders = HashSet(plugin.providers)
while (call.response.status == HttpStatusCode.Unauthorized) {
val headerValue = call.response.headers[HttpHeaders.WWWAuthenticate]
val headerValues = call.response.headers.getAll(HttpHeaders.WWWAuthenticate)
val authHeaders = headerValues?.map { parseAuthorizationHeaders(it) }?.flatten() ?: emptyList()
val authHeader = headerValue?.let { parseAuthorizationHeader(headerValue) }
val provider = when {
authHeader == null && candidateProviders.size == 1 -> candidateProviders.first()
authHeader == null -> return@intercept call
else -> candidateProviders.find { it.isApplicable(authHeader) } ?: return@intercept call
var providerOrNull: AuthProvider? = null
var authHeader: HttpAuthHeader? = null
when {
authHeaders.isEmpty() && candidateProviders.size == 1 -> {
providerOrNull = candidateProviders.first()
}
authHeaders.isEmpty() -> return@intercept call
else -> authHeader = authHeaders.find { header ->
providerOrNull = candidateProviders.find { it.isApplicable(header) }
providerOrNull != null
}
}
val provider = providerOrNull ?: return@intercept call
if (!provider.refreshToken(call.response)) return@intercept call
candidateProviders.remove(provider)
......
......@@ -535,4 +535,73 @@ class AuthTest : ClientLoader() {
assertEquals(2, loadCount)
}
}
@Test
fun testMultipleChallengesInHeader() = clientTests {
config {
install(Auth) {
basic {
credentials { BasicAuthCredentials("Invalid", "Invalid") }
}
bearer {
loadTokens { BearerTokens("test", "test") }
}
}
}
test { client ->
val responseOneHeader = client.get("$TEST_SERVER/auth/multiple/header").bodyAsText()
assertEquals("OK", responseOneHeader)
}
}
@Test
fun testMultipleChallengesInHeaders() = clientTests {
config {
install(Auth) {
basic {
credentials { BasicAuthCredentials("Invalid", "Invalid") }
}
bearer {
loadTokens { BearerTokens("test", "test") }
}
}
}
test { client ->
val responseMultipleHeaders = client.get("$TEST_SERVER/auth/multiple/headers").bodyAsText()
assertEquals("OK", responseMultipleHeaders)
}
}
@Test
fun testMultipleChallengesInHeaderUnauthorized() = clientTests {
test { client ->
val response = client.get("$TEST_SERVER/auth/multiple/header")
assertEquals(HttpStatusCode.Unauthorized, response.status)
response.headers[HttpHeaders.WWWAuthenticate]?.also {
assertTrue { it.contains("Bearer") }
assertTrue { it.contains("Basic") }
assertTrue { it.contains("Digest") }
} ?: run {
fail("Expected WWWAuthenticate header")
}
}
}
@Test
fun testMultipleChallengesInMultipleHeadersUnauthorized() = clientTests(listOf("Js")) {
test { client ->
val response = client.get("$TEST_SERVER/auth/multiple/headers")
assertEquals(HttpStatusCode.Unauthorized, response.status)
response.headers.getAll(HttpHeaders.WWWAuthenticate)?.let {
assertEquals(2, it.size)
it.joinToString().let { header ->
assertTrue { header.contains("Basic") }
assertTrue { header.contains("Digest") }
assertTrue { header.contains("Bearer") }
}
} ?: run {
fail("Expected WWWAuthenticate header")
}
}
}
}
......@@ -1114,6 +1114,7 @@ public final class io/ktor/http/auth/HttpAuthHeader$Single : io/ktor/http/auth/H
public final class io/ktor/http/auth/HttpAuthHeaderKt {
public static final fun parseAuthorizationHeader (Ljava/lang/String;)Lio/ktor/http/auth/HttpAuthHeader;
public static final fun parseAuthorizationHeaders (Ljava/lang/String;)Ljava/util/List;
}
public final class io/ktor/http/content/ByteArrayContent : io/ktor/http/content/OutgoingContent$ByteArrayContent {
......
......@@ -8,7 +8,6 @@ import io.ktor.http.*
import io.ktor.http.parsing.*
import io.ktor.util.*
import io.ktor.utils.io.charsets.*
import kotlin.native.concurrent.*
private val TOKEN_EXTRA = setOf('!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~')
private val TOKEN68_EXTRA = setOf('-', '.', '_', '~', '+', '/')
......@@ -19,12 +18,14 @@ private val escapeRegex: Regex = "\\\\.".toRegex()
* Parses an authorization header [headerValue] into a [HttpAuthHeader].
* @return [HttpAuthHeader] or `null` if argument string is blank.
* @throws [ParseException] on invalid header
*
* @see [parseAuthorizationHeaders]
*/
public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? {
var index = 0
index = headerValue.skipSpaces(index)
var tokenStartIndex = index
val tokenStartIndex = index
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}
......@@ -32,7 +33,6 @@ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? {
// Auth scheme
val authScheme = headerValue.substring(tokenStartIndex until index)
index = headerValue.skipSpaces(index)
tokenStartIndex = index
if (authScheme.isBlank()) {
return null
......@@ -42,28 +42,114 @@ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? {
return HttpAuthHeader.Parameterized(authScheme, emptyList())
}
val token68 = matchToken68(headerValue, index)
if (token68 != null) {
return HttpAuthHeader.Single(authScheme, token68)
val token68EndIndex = matchToken68(headerValue, index)
val token68 = headerValue.substring(index until token68EndIndex).trim()
if (token68.isNotEmpty()) {
if (token68EndIndex == headerValue.length) {
return HttpAuthHeader.Single(authScheme, token68)
}
}
val parameters = matchParameters(headerValue, tokenStartIndex)
return HttpAuthHeader.Parameterized(authScheme, parameters)
val parameters = mutableMapOf<String, String>()
val endIndex = matchParameters(headerValue, index, parameters)
return if (endIndex == -1) HttpAuthHeader.Parameterized(authScheme, parameters) else
throw ParseException("Function parseAuthorizationHeader can parse only one header")
}
/**
* Parses an authorization header [headerValue] into a list of [HttpAuthHeader].
* @return a list of [HttpAuthHeader]
* @throws [ParseException] on invalid header
*/
@InternalAPI
public fun parseAuthorizationHeaders(headerValue: String): List<HttpAuthHeader> {
var index = 0
val headers = mutableListOf<HttpAuthHeader>()
while (index != -1) {
index = parseAuthorizationHeader(headerValue, index, headers)
}
return headers
}
private fun matchParameters(headerValue: String, startIndex: Int): Map<String, String> {
val result = mutableMapOf<String, String>()
private fun parseAuthorizationHeader(
headerValue: String,
startIndex: Int,
headers: MutableList<HttpAuthHeader>
): Int {
var index = headerValue.skipSpaces(startIndex)
// Auth scheme
val schemeStartIndex = index
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}
val authScheme = headerValue.substring(schemeStartIndex until index)
if (authScheme.isBlank()) {
throw ParseException("Invalid authScheme value: it should be token, can't be blank")
}
index = headerValue.skipSpaces(index)
nextChallengeIndex(headers, HttpAuthHeader.Parameterized(authScheme, emptyList()), index, headerValue)?.let {
return it
}
val token68EndIndex = matchToken68(headerValue, index)
val token68 = headerValue.substring(index until token68EndIndex).trim()
if (token68.isNotEmpty()) {
nextChallengeIndex(headers, HttpAuthHeader.Single(authScheme, token68), token68EndIndex, headerValue)?.let {
return it
}
}
val parameters = mutableMapOf<String, String>()
val nextIndexChallenge = matchParameters(headerValue, index, parameters)
headers.add(HttpAuthHeader.Parameterized(authScheme, parameters))
return nextIndexChallenge
}
/**
* Check for the ending of the current challenge in a header
* @return -1 if at the end of the header
* @return null if the challenge is not ended
* @return a positive number - the index of the beginning of the next challenge
*/
private fun nextChallengeIndex(
headers: MutableList<HttpAuthHeader>,
header: HttpAuthHeader,
index: Int,
headerValue: String
): Int? {
if (index == headerValue.length || headerValue[index] == ',') {
headers.add(header)
return when {
index == headerValue.length -> -1
headerValue[index] == ',' -> index + 1
else -> error("") // unreachable code
}
}
return null
}
private fun matchParameters(headerValue: String, startIndex: Int, parameters: MutableMap<String, String>): Int {
var index = startIndex
while (index > 0 && index < headerValue.length) {
index = matchParameter(headerValue, index, result)
index = headerValue.skipDelimiter(index, ',')
val nextIndex = matchParameter(headerValue, index, parameters)
if (nextIndex == index) {
return index
} else {
index = headerValue.skipDelimiter(nextIndex, ',')
}
}
return result
return index
}
private fun matchParameter(headerValue: String, startIndex: Int, parameters: MutableMap<String, String>): Int {
private fun matchParameter(
headerValue: String,
startIndex: Int,
parameters: MutableMap<String, String>
): Int {
val keyStart = headerValue.skipSpaces(startIndex)
var index = keyStart
......@@ -71,15 +157,15 @@ private fun matchParameter(headerValue: String, startIndex: Int, parameters: Mut
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}
val key = headerValue.substring(keyStart until index)
// Take '='
// Check if new challenge
index = headerValue.skipSpaces(index)
if (index >= headerValue.length || headerValue[index] != '=') {
throw ParseException("Expected `=` after parameter key '$key': $headerValue")
if (index == headerValue.length || headerValue[index] != '=') {
return startIndex
}
// Take '='
index++
index = headerValue.skipSpaces(index)
......@@ -116,8 +202,8 @@ private fun matchParameter(headerValue: String, startIndex: Int, parameters: Mut
return index
}
private fun matchToken68(headerValue: String, startIndex: Int): String? {
var index = startIndex
private fun matchToken68(headerValue: String, startIndex: Int): Int {
var index = headerValue.skipSpaces(startIndex)
while (index < headerValue.length && headerValue[index].isToken68()) {
index++
......@@ -127,12 +213,7 @@ private fun matchToken68(headerValue: String, startIndex: Int): String? {
index++
}
val onlySpaceRemaining = (index until headerValue.length).all { headerValue[it] == ' ' }
if (onlySpaceRemaining) {
return headerValue.substring(startIndex until index)
}
return null
return headerValue.skipSpaces(index)
}
/**
......@@ -355,13 +436,11 @@ private fun String.unescaped() = replace(escapeRegex) { it.value.takeLast(1) }
private fun String.skipDelimiter(startIndex: Int, delimiter: Char): Int {
var index = skipSpaces(startIndex)
while (index < length && this[index] != delimiter) {
index++
}
if (index == length) return -1
index++
if (this[index] != delimiter)
throw ParseException("Expected delimiter $delimiter at position $index, but found ${this[index]}")
index++
return skipSpaces(index)
}
......
......@@ -5,23 +5,28 @@
package io.ktor.tests.auth
import io.ktor.http.auth.*
import io.ktor.util.*
import kotlin.random.*
import kotlin.test.*
class AuthorizeHeaderParserTest {
@Test fun empty() {
@Test
fun empty() {
testParserParameterized("Basic", emptyMap(), "Basic")
}
@Test fun emptyWithTrailingSpaces() {
@Test
fun emptyWithTrailingSpaces() {
testParserParameterized("Basic", emptyMap(), "Basic ")
}
@Test fun singleSimple() {
@Test
fun singleSimple() {
testParserSingle("Basic", "abc==", "Basic abc==")
}
@Test fun testParameterizedSimple() {
@Test
fun testParameterizedSimple() {
testParserParameterized("Basic", mapOf("a" to "1"), "Basic a=1")
testParserParameterized("Basic", mapOf("a" to "1"), "Basic a =1")
testParserParameterized("Basic", mapOf("a" to "1"), "Basic a = 1")
......@@ -30,7 +35,8 @@ class AuthorizeHeaderParserTest {
testParserParameterized("Basic", mapOf("a" to "1"), "Basic a=1 ")
}
@Test fun testParameterizedSimpleTwoParams() {
@Test
fun testParameterizedSimpleTwoParams() {
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1, b=2")
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1,b=2")
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1 ,b=2")
......@@ -38,19 +44,53 @@ class AuthorizeHeaderParserTest {
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1 , b=2 ")
}
@Test fun testParameterizedQuoted() {
@Test
fun testParameterizedQuoted() {
testParserParameterized("Basic", mapOf("a" to "1 2"), "Basic a=\"1 2\"")
}
@Test fun testParameterizedQuotedEscaped() {
@Test
fun testParameterizedQuotedEscaped() {
testParserParameterized("Basic", mapOf("a" to "1 \" 2"), "Basic a=\"1 \\\" 2\"")
testParserParameterized("Basic", mapOf("a" to "1 A 2"), "Basic a=\"1 \\A 2\"")
}
@Test fun testParameterizedQuotedEscapedInTheMiddle() {
@Test
fun testParameterizedQuotedEscapedInTheMiddle() {
testParserParameterized("Basic", mapOf("a" to "1 \" 2", "b" to "2"), "Basic a=\"1 \\\" 2\", b= 2")
}
@Test
fun testMultipleChallengesParameters() {
val expected = listOf(
HttpAuthHeader.Parameterized("Digest", emptyMap()),
HttpAuthHeader.Parameterized("Bearer", mapOf("1" to "2", "3" to "4")),
HttpAuthHeader.Parameterized("Basic", emptyMap()),
)
testParserMultipleChallenges(expected, "Digest, Bearer 1 = 2, 3=4, Basic ")
}
@Test
fun testMultipleChallengesSingle() {
val expected = listOf(
HttpAuthHeader.Single("Bearer", "abc=="),
HttpAuthHeader.Parameterized("Bearer", mapOf("abc" to "def")),
HttpAuthHeader.Single("Basic", "def==="),
HttpAuthHeader.Parameterized("Digest", emptyMap())
)
testParserMultipleChallenges(expected, "Bearer abc==, Bearer abc=def, Basic def===, Digest")
}
@Test
fun testMultipleChallengesAllHeaders() {
val expected = listOf(
HttpAuthHeader.Parameterized("Basic", emptyMap()),
HttpAuthHeader.Parameterized("Bearer", mapOf("abc" to "def")),
HttpAuthHeader.Single("Digest", "abc==")
)
testParserMultipleChallenges(expected, "Basic, Bearer abc=def,Digest abc==")
}
private fun testParserSingle(scheme: String, value: String, headerValue: String) {
val actual = parseAuthorizationHeader(headerValue)!!
......@@ -75,11 +115,32 @@ class AuthorizeHeaderParserTest {
}
}
@OptIn(InternalAPI::class)
private fun testParserMultipleChallenges(expected: List<HttpAuthHeader>, headerValue: String) {
val actual = parseAuthorizationHeaders(headerValue)
assertEquals(expected.size, actual.size)
(expected zip actual).forEach { (expectedHeader, actualHeader) ->
if (expectedHeader is HttpAuthHeader.Single) {
assertIs<HttpAuthHeader.Single>(actualHeader)
assertEquals(expectedHeader.blob, actualHeader.blob)
}
if (expectedHeader is HttpAuthHeader.Parameterized) {
assertIs<HttpAuthHeader.Parameterized>(actualHeader)
assertEquals(
expectedHeader.parameters.associateBy({ it.name }, { it.value }),
actualHeader.parameters.associateBy({ it.name }, { it.value })
)
}
}
}
private fun Random.nextString(
length: Int,
possible: Iterable<Char> = ('a'..'z') + ('A'..'Z') + ('0'..'9')
) = possible.toList().let { possibleElements ->
(0..length - 1).map { nextFrom(possibleElements) }.joinToString("")
(0 until length).map { nextFrom(possibleElements) }.joinToString("")
}
private fun Random.nextString(length: Int, possible: String) = nextString(length, possible.toList())
......
Поддерживает Markdown
0% или .
You are about to add 0 people to the discussion. Proceed with caution.
Сначала завершите редактирование этого сообщения!
Пожалуйста, зарегистрируйтесь или чтобы прокомментировать