Не подтверждена Коммит e4fd2518 создал по автору Osip Fatkullin's avatar Osip Fatkullin Зафиксировано автором GitHub
Просмотр файлов

KTOR-6970 Darwin, Java, JS: Propagate Sec-WebSocket-Protocol header (#4633)

владелец 15f09214
...@@ -73,7 +73,7 @@ internal class JsClientEngine( ...@@ -73,7 +73,7 @@ internal class JsClientEngine(
headers: Headers headers: Headers
): WebSocket { ): WebSocket {
val protocolHeaderNames = headers.names().filter { headerName -> val protocolHeaderNames = headers.names().filter { headerName ->
headerName.equals("sec-websocket-protocol", true) headerName.equals(HttpHeaders.SecWebSocketProtocol, ignoreCase = true)
} }
val protocols = protocolHeaderNames.mapNotNull { headers.getAll(it) }.flatten().toTypedArray() val protocols = protocolHeaderNames.mapNotNull { headers.getAll(it) }.flatten().toTypedArray()
return when { return when {
...@@ -108,10 +108,13 @@ internal class JsClientEngine( ...@@ -108,10 +108,13 @@ internal class JsClientEngine(
throw cause throw cause
} }
val protocol = socket.protocol.takeIf { it.isNotEmpty() }
val headers = if (protocol != null) headersOf(HttpHeaders.SecWebSocketProtocol, protocol) else Headers.Empty
return HttpResponseData( return HttpResponseData(
HttpStatusCode.SwitchingProtocols, HttpStatusCode.SwitchingProtocols,
requestTime, requestTime,
Headers.Empty, headers,
HttpProtocolVersion.HTTP_1_1, HttpProtocolVersion.HTTP_1_1,
session, session,
callContext callContext
......
...@@ -82,7 +82,7 @@ internal class JsClientEngine( ...@@ -82,7 +82,7 @@ internal class JsClientEngine(
headers: Headers headers: Headers
): WebSocket { ): WebSocket {
val protocolHeaderNames = headers.names().filter { headerName -> val protocolHeaderNames = headers.names().filter { headerName ->
headerName.equals("sec-websocket-protocol", true) headerName.equals(HttpHeaders.SecWebSocketProtocol, ignoreCase = true)
} }
val protocols = protocolHeaderNames.mapNotNull { headers.getAll(it) }.flatten().toTypedArray() val protocols = protocolHeaderNames.mapNotNull { headers.getAll(it) }.flatten().toTypedArray()
return when { return when {
...@@ -116,10 +116,13 @@ internal class JsClientEngine( ...@@ -116,10 +116,13 @@ internal class JsClientEngine(
val session = JsWebSocketSession(callContext, socket) val session = JsWebSocketSession(callContext, socket)
val protocol = socket.protocol.takeIf { it.isNotEmpty() }
val headers = if (protocol != null) headersOf(HttpHeaders.SecWebSocketProtocol, protocol) else Headers.Empty
return HttpResponseData( return HttpResponseData(
HttpStatusCode.SwitchingProtocols, HttpStatusCode.SwitchingProtocols,
requestTime, requestTime,
Headers.Empty, headers,
HttpProtocolVersion.HTTP_1_1, HttpProtocolVersion.HTTP_1_1,
session, session,
callContext callContext
......
/* /*
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. * Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/ */
package io.ktor.client.engine.darwin package io.ktor.client.engine.darwin
...@@ -7,11 +7,12 @@ package io.ktor.client.engine.darwin ...@@ -7,11 +7,12 @@ package io.ktor.client.engine.darwin
import io.ktor.client.engine.darwin.internal.* import io.ktor.client.engine.darwin.internal.*
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.util.collections.* import io.ktor.util.collections.*
import kotlinx.cinterop.* import kotlinx.cinterop.UnsafeNumber
import kotlinx.coroutines.* import kotlinx.coroutines.CompletableDeferred
import platform.Foundation.* import platform.Foundation.*
import platform.darwin.* import platform.darwin.NSObject
import kotlin.coroutines.* import kotlin.collections.set
import kotlin.coroutines.CoroutineContext
private const val HTTP_REQUESTS_INITIAL_CAPACITY = 32 private const val HTTP_REQUESTS_INITIAL_CAPACITY = 32
private const val WS_REQUESTS_INITIAL_CAPACITY = 16 private const val WS_REQUESTS_INITIAL_CAPACITY = 16
...@@ -77,7 +78,7 @@ public class KtorNSURLSessionDelegate( ...@@ -77,7 +78,7 @@ public class KtorNSURLSessionDelegate(
didOpenWithProtocol: String? didOpenWithProtocol: String?
) { ) {
val wsSession = webSocketSessions[webSocketTask] ?: return val wsSession = webSocketSessions[webSocketTask] ?: return
wsSession.didOpen() wsSession.didOpen(didOpenWithProtocol)
} }
override fun URLSession( override fun URLSession(
......
/* /*
* Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. * Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/ */
package io.ktor.client.engine.darwin.internal package io.ktor.client.engine.darwin.internal
...@@ -10,13 +10,20 @@ import io.ktor.http.* ...@@ -10,13 +10,20 @@ import io.ktor.http.*
import io.ktor.util.date.* import io.ktor.util.date.*
import io.ktor.utils.io.core.* import io.ktor.utils.io.core.*
import io.ktor.websocket.* import io.ktor.websocket.*
import kotlinx.cinterop.* import kotlinx.cinterop.ExperimentalForeignApi
import kotlinx.cinterop.UnsafeNumber
import kotlinx.cinterop.convert
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlinx.coroutines.channels.* import kotlinx.coroutines.channels.Channel
import kotlinx.io.* import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.consumeEach
import kotlinx.io.readByteArray
import platform.Foundation.* import platform.Foundation.*
import platform.darwin.* import platform.darwin.NSInteger
import kotlin.coroutines.* import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
@OptIn(UnsafeNumber::class, ExperimentalForeignApi::class) @OptIn(UnsafeNumber::class, ExperimentalForeignApi::class)
internal class DarwinWebsocketSession( internal class DarwinWebsocketSession(
...@@ -157,11 +164,13 @@ internal class DarwinWebsocketSession( ...@@ -157,11 +164,13 @@ internal class DarwinWebsocketSession(
coroutineContext.cancel() coroutineContext.cancel()
} }
fun didOpen() { fun didOpen(protocol: String?) {
val headers = if (protocol != null) headersOf(HttpHeaders.SecWebSocketProtocol, protocol) else Headers.Empty
val response = HttpResponseData( val response = HttpResponseData(
task.getStatusCode()?.let { HttpStatusCode.fromValue(it) } ?: HttpStatusCode.SwitchingProtocols, task.getStatusCode()?.let { HttpStatusCode.fromValue(it) } ?: HttpStatusCode.SwitchingProtocols,
requestTime, requestTime,
Headers.Empty, headers,
HttpProtocolVersion.HTTP_1_1, HttpProtocolVersion.HTTP_1_1,
this, this,
coroutineContext coroutineContext
...@@ -177,7 +186,7 @@ internal class DarwinWebsocketSession( ...@@ -177,7 +186,7 @@ internal class DarwinWebsocketSession(
// KTOR-7363 We want to proceed with the request if we get 401 Unauthorized status code // KTOR-7363 We want to proceed with the request if we get 401 Unauthorized status code
if (task.getStatusCode() == HttpStatusCode.Unauthorized.value) { if (task.getStatusCode() == HttpStatusCode.Unauthorized.value) {
didOpen() didOpen(protocol = null)
socketJob.complete() socketJob.complete()
return return
} }
......
/* /*
* Copyright 2014-2021 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. * Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/ */
package io.ktor.client.engine.java package io.ktor.client.engine.java
...@@ -15,14 +15,18 @@ import io.ktor.utils.io.* ...@@ -15,14 +15,18 @@ import io.ktor.utils.io.*
import io.ktor.utils.io.core.* import io.ktor.utils.io.core.*
import io.ktor.websocket.* import io.ktor.websocket.*
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlinx.coroutines.channels.* import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.future.* import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.consumeEach
import kotlinx.coroutines.future.asCompletableFuture
import kotlinx.coroutines.future.await
import java.net.http.* import java.net.http.*
import java.nio.* import java.nio.ByteBuffer
import java.time.* import java.time.Duration
import java.util.* import java.util.*
import java.util.concurrent.* import java.util.concurrent.CompletionStage
import kotlin.coroutines.* import kotlin.coroutines.CoroutineContext
import kotlin.text.String import kotlin.text.String
import kotlin.text.toByteArray import kotlin.text.toByteArray
...@@ -92,9 +96,11 @@ internal class JavaHttpWebSocket( ...@@ -92,9 +96,11 @@ internal class JavaHttpWebSocket(
FrameType.TEXT -> { FrameType.TEXT -> {
webSocket.sendText(String(frame.data), frame.fin).await() webSocket.sendText(String(frame.data), frame.fin).await()
} }
FrameType.BINARY -> { FrameType.BINARY -> {
webSocket.sendBinary(frame.buffer, frame.fin).await() webSocket.sendBinary(frame.buffer, frame.fin).await()
} }
FrameType.CLOSE -> { FrameType.CLOSE -> {
val data = buildPacket { writeFully(frame.data) } val data = buildPacket { writeFully(frame.data) }
val code = data.readShort().toInt() val code = data.readShort().toInt()
...@@ -103,9 +109,11 @@ internal class JavaHttpWebSocket( ...@@ -103,9 +109,11 @@ internal class JavaHttpWebSocket(
socketJob.complete() socketJob.complete()
return@launch return@launch
} }
FrameType.PING -> { FrameType.PING -> {
webSocket.sendPing(frame.buffer).await() webSocket.sendPing(frame.buffer).await()
} }
FrameType.PONG -> { FrameType.PONG -> {
webSocket.sendPong(frame.buffer).await() webSocket.sendPong(frame.buffer).await()
} }
...@@ -153,11 +161,15 @@ internal class JavaHttpWebSocket( ...@@ -153,11 +161,15 @@ internal class JavaHttpWebSocket(
} }
var status = HttpStatusCode.SwitchingProtocols var status = HttpStatusCode.SwitchingProtocols
var headers: Headers
try { try {
webSocket = builder.buildAsync(requestData.url.toURI(), this).await() webSocket = builder.buildAsync(requestData.url.toURI(), this).await()
val protocol = webSocket.subprotocol?.takeIf { it.isNotEmpty() }
headers = if (protocol != null) headersOf(HttpHeaders.SecWebSocketProtocol, protocol) else Headers.Empty
} catch (cause: WebSocketHandshakeException) { } catch (cause: WebSocketHandshakeException) {
if (cause.response.statusCode() == HttpStatusCode.Unauthorized.value) { if (cause.response.statusCode() == HttpStatusCode.Unauthorized.value) {
status = HttpStatusCode.Unauthorized status = HttpStatusCode.Unauthorized
headers = headersOf(cause.response.headers().map())
} else { } else {
throw cause throw cause
} }
...@@ -166,7 +178,7 @@ internal class JavaHttpWebSocket( ...@@ -166,7 +178,7 @@ internal class JavaHttpWebSocket(
return HttpResponseData( return HttpResponseData(
status, status,
requestTime, requestTime,
Headers.Empty, headers,
HttpProtocolVersion.HTTP_1_1, HttpProtocolVersion.HTTP_1_1,
this, this,
callContext callContext
...@@ -217,3 +229,11 @@ internal class JavaHttpWebSocket( ...@@ -217,3 +229,11 @@ internal class JavaHttpWebSocket(
socketJob.cancel() socketJob.cancel()
} }
} }
private fun headersOf(map: Map<String, List<String>>): Headers = object : Headers {
override val caseInsensitiveName: Boolean = true
override fun getAll(name: String): List<String>? = map[name]
override fun names(): Set<String> = map.keys
override fun entries(): Set<Map.Entry<String, List<String>>> = map.entries
override fun isEmpty(): Boolean = map.isEmpty()
}
...@@ -311,6 +311,25 @@ class WebSocketTest : ClientLoader() { ...@@ -311,6 +311,25 @@ class WebSocketTest : ClientLoader() {
} }
} }
@Test
fun testResponseContainsSecWebsocketProtocolHeader() = clientTests(except(ENGINES_WITHOUT_WS)) {
config {
install(WebSockets)
}
test { client ->
val session = client.webSocketSession("$TEST_WEBSOCKET_SERVER/websockets/sub-protocol") {
header(HttpHeaders.SecWebSocketProtocol, "test-protocol")
}
try {
assertEquals(session.call.response.headers[HttpHeaders.SecWebSocketProtocol], "test-protocol")
} finally {
session.close()
}
}
}
@Test @Test
fun testIncomingOverflow() = clientTests(except(ENGINES_WITHOUT_WS)) { fun testIncomingOverflow() = clientTests(except(ENGINES_WITHOUT_WS)) {
config { config {
......
Поддерживает Markdown
0% или .
You are about to add 0 people to the discussion. Proceed with caution.
Сначала завершите редактирование этого сообщения!
Пожалуйста, зарегистрируйтесь или чтобы прокомментировать