Не подтверждена Коммит 9ba9388c создал по автору Bruce Hamilton's avatar Bruce Hamilton Зафиксировано автором GitHub
Просмотр файлов

KTOR-8210 Fix copy() and Source multipart processing (#4686)

* KTOR-8210 Use peek over copy; avoid assumptions of Source contents fully buffered

* Revert "KTOR-8210 Use peek over copy; avoid assumptions of Source contents fully buffered"

This reverts commit 37bf04a03d5678e12f405c8325717e0138a96a54.

* KTOR-8210 Fix copy() and multipart source content processing

* fixup! KTOR-8210 Fix copy() and multipart source content processing

* Increase Jetty test timeout to resolve flakiness
владелец a309ce8a
...@@ -9,6 +9,7 @@ import io.ktor.http.content.* ...@@ -9,6 +9,7 @@ import io.ktor.http.content.*
import io.ktor.utils.io.* import io.ktor.utils.io.*
import io.ktor.utils.io.core.* import io.ktor.utils.io.core.*
import kotlinx.io.* import kotlinx.io.*
import kotlinx.io.Buffer
import kotlin.contracts.* import kotlin.contracts.*
/** /**
...@@ -49,8 +50,10 @@ public fun formData(vararg values: FormPart<*>): List<PartData> { ...@@ -49,8 +50,10 @@ public fun formData(vararg values: FormPart<*>): List<PartData> {
PartData.BinaryItem({ ByteReadPacket(value) }, {}, partHeaders.build()) PartData.BinaryItem({ ByteReadPacket(value) }, {}, partHeaders.build())
} }
is Source -> { is Source -> {
partHeaders.append(HttpHeaders.ContentLength, value.remaining.toString()) if (value is Buffer) {
PartData.BinaryItem({ value.copy() }, { value.close() }, partHeaders.build()) partHeaders.append(HttpHeaders.ContentLength, value.remaining.toString())
}
PartData.BinaryItem({ value.peek() }, { value.close() }, partHeaders.build())
} }
is InputProvider -> { is InputProvider -> {
val size = value.size val size = value.size
......
...@@ -3,17 +3,22 @@ ...@@ -3,17 +3,22 @@
*/ */
import io.ktor.client.request.forms.* import io.ktor.client.request.forms.*
import io.ktor.http.Headers
import io.ktor.http.HttpHeaders
import io.ktor.test.dispatcher.* import io.ktor.test.dispatcher.*
import io.ktor.utils.io.* import io.ktor.utils.io.*
import io.ktor.utils.io.charsets.* import io.ktor.utils.io.charsets.*
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlinx.coroutines.test.runTest
import kotlinx.io.* import kotlinx.io.*
import kotlinx.io.files.Path
import kotlin.random.Random
import kotlin.test.* import kotlin.test.*
class MultiPartFormDataContentTest { class MultiPartFormDataContentTest {
@Test @Test
fun testMultiPartFormDataContentHasCorrectPrefix() = testSuspend { fun testMultiPartFormDataContentHasCorrectPrefix() = runTest {
val formData = MultiPartFormDataContent( val formData = MultiPartFormDataContent(
formData { formData {
append("Hello", "World") append("Hello", "World")
...@@ -33,7 +38,7 @@ class MultiPartFormDataContentTest { ...@@ -33,7 +38,7 @@ class MultiPartFormDataContentTest {
} }
@Test @Test
fun testEmptyByteReadChannel() = testSuspend { fun testEmptyByteReadChannel() = runTest {
val data = MultiPartFormDataContent( val data = MultiPartFormDataContent(
formData { formData {
append("channel", ChannelProvider { ByteReadChannel.Empty }) append("channel", ChannelProvider { ByteReadChannel.Empty })
...@@ -55,7 +60,7 @@ class MultiPartFormDataContentTest { ...@@ -55,7 +60,7 @@ class MultiPartFormDataContentTest {
} }
@Test @Test
fun testByteReadChannelWithString() = testSuspend { fun testByteReadChannelWithString() = runTest {
val content = "body" val content = "body"
val data = MultiPartFormDataContent( val data = MultiPartFormDataContent(
formData { formData {
...@@ -79,7 +84,7 @@ class MultiPartFormDataContentTest { ...@@ -79,7 +84,7 @@ class MultiPartFormDataContentTest {
} }
@Test @Test
fun testNumberQuoted() = testSuspend { fun testNumberQuoted() = runTest {
val data = MultiPartFormDataContent( val data = MultiPartFormDataContent(
formData { formData {
append("not_a_forty_two", 1337) append("not_a_forty_two", 1337)
...@@ -102,7 +107,7 @@ class MultiPartFormDataContentTest { ...@@ -102,7 +107,7 @@ class MultiPartFormDataContentTest {
} }
@Test @Test
fun testBooleanQuoted() = testSuspend { fun testBooleanQuoted() = runTest {
val data = MultiPartFormDataContent( val data = MultiPartFormDataContent(
formData { formData {
append("is_forty_two", false) append("is_forty_two", false)
...@@ -125,7 +130,7 @@ class MultiPartFormDataContentTest { ...@@ -125,7 +130,7 @@ class MultiPartFormDataContentTest {
} }
@Test @Test
fun testStringsList() = testSuspend { fun testStringsList() = runTest {
val data = MultiPartFormDataContent( val data = MultiPartFormDataContent(
formData { formData {
append("platforms[]", listOf("windows", "linux", "osx")) append("platforms[]", listOf("windows", "linux", "osx"))
...@@ -158,7 +163,7 @@ class MultiPartFormDataContentTest { ...@@ -158,7 +163,7 @@ class MultiPartFormDataContentTest {
} }
@Test @Test
fun testStringsArray() = testSuspend { fun testStringsArray() = runTest {
val data = MultiPartFormDataContent( val data = MultiPartFormDataContent(
formData { formData {
append("platforms[]", arrayOf("windows", "linux", "osx")) append("platforms[]", arrayOf("windows", "linux", "osx"))
...@@ -191,7 +196,7 @@ class MultiPartFormDataContentTest { ...@@ -191,7 +196,7 @@ class MultiPartFormDataContentTest {
} }
@Test @Test
fun testStringsListBadKey() = testSuspend { fun testStringsListBadKey() = runTest {
val attempt = { val attempt = {
MultiPartFormDataContent( MultiPartFormDataContent(
formData { formData {
...@@ -206,7 +211,7 @@ class MultiPartFormDataContentTest { ...@@ -206,7 +211,7 @@ class MultiPartFormDataContentTest {
} }
@Test @Test
fun testByteReadChannelOverBufferSize() = testSuspend { fun testByteReadChannelOverBufferSize() = runTest {
val body = ByteArray(4089) { 'k'.code.toByte() } val body = ByteArray(4089) { 'k'.code.toByte() }
val data = MultiPartFormDataContent( val data = MultiPartFormDataContent(
formData { formData {
...@@ -228,6 +233,36 @@ class MultiPartFormDataContentTest { ...@@ -228,6 +233,36 @@ class MultiPartFormDataContentTest {
) )
} }
@Test
fun testFileContentFromSource() = runTest {
val expected = "This content should appear in the multipart body."
val fileSource = try {
with(kotlinx.io.files.SystemFileSystem) {
val file = Path(kotlinx.io.files.SystemTemporaryDirectory, "temp${Random.nextInt(1000, 9999)}.txt")
sink(file).buffered().use { it.writeString(expected) }
source(file).buffered()
}
} catch (_: Throwable) {
// filesystem is not supported for web platforms (yet)
return@runTest
}
val data = MultiPartFormDataContent(
formData {
append(
key = "key",
value = fileSource,
headers = Headers.build {
append(HttpHeaders.ContentType, "text/plain")
append(HttpHeaders.ContentDisposition, "filename=\"file.txt\"")
},
)
}
)
assertTrue("File contents should be present in the multipart body.") {
data.readString().contains(expected)
}
}
private suspend fun MultiPartFormDataContent.readString(charset: Charset = Charsets.UTF_8): String { private suspend fun MultiPartFormDataContent.readString(charset: Charset = Charsets.UTF_8): String {
val bytes = readBytes() val bytes = readBytes()
return bytes.decodeToString(0, 0 + bytes.size) return bytes.decodeToString(0, 0 + bytes.size)
......
...@@ -49,8 +49,15 @@ public fun Source.readAvailable(out: kotlinx.io.Buffer): Int { ...@@ -49,8 +49,15 @@ public fun Source.readAvailable(out: kotlinx.io.Buffer): Int {
return result.toInt() return result.toInt()
} }
/**
* Returns a copy of the current buffer attached to this Source.
*/
@Deprecated(
"Use peek() or buffer.copy() instead, depending on your use case.",
ReplaceWith("peek()", "kotlinx.io.Source")
)
@OptIn(InternalIoApi::class) @OptIn(InternalIoApi::class)
public fun Source.copy(): Source = buffer.copy() public fun Source.copy(): Source = peek()
@OptIn(InternalIoApi::class) @OptIn(InternalIoApi::class)
public fun Source.readShortLittleEndian(): Short { public fun Source.readShortLittleEndian(): Short {
......
...@@ -18,7 +18,7 @@ public fun Source.inputStream(): InputStream = asInputStream() ...@@ -18,7 +18,7 @@ public fun Source.inputStream(): InputStream = asInputStream()
@OptIn(InternalIoApi::class) @OptIn(InternalIoApi::class)
public fun OutputStream.writePacket(packet: Source) { public fun OutputStream.writePacket(packet: Source) {
packet.buffer.copyTo(this) packet.transferTo(this.asSink())
} }
public fun OutputStream.writePacket(block: Sink.() -> Unit) { public fun OutputStream.writePacket(block: Sink.() -> Unit) {
......
...@@ -27,7 +27,7 @@ class JettyIdleTimeoutTest : EngineTestBase<JettyApplicationEngine, JettyApplica ...@@ -27,7 +27,7 @@ class JettyIdleTimeoutTest : EngineTestBase<JettyApplicationEngine, JettyApplica
override fun configure(configuration: JettyApplicationEngineBase.Configuration) { override fun configure(configuration: JettyApplicationEngineBase.Configuration) {
super.configure(configuration) super.configure(configuration)
configuration.idleTimeout = 10.milliseconds configuration.idleTimeout = 100.milliseconds
} }
@Test @Test
......
...@@ -27,7 +27,7 @@ class JettyIdleTimeoutTest : EngineTestBase<JettyApplicationEngine, JettyApplica ...@@ -27,7 +27,7 @@ class JettyIdleTimeoutTest : EngineTestBase<JettyApplicationEngine, JettyApplica
override fun configure(configuration: JettyApplicationEngineBase.Configuration) { override fun configure(configuration: JettyApplicationEngineBase.Configuration) {
super.configure(configuration) super.configure(configuration)
configuration.idleTimeout = 10.milliseconds configuration.idleTimeout = 100.milliseconds
} }
@Test @Test
......
...@@ -67,7 +67,7 @@ class HighLoadHttpGenerator( ...@@ -67,7 +67,7 @@ class HighLoadHttpGenerator(
private val request = RequestResponseBuilder().apply(builder).build() private val request = RequestResponseBuilder().apply(builder).build()
private val requestByteBuffer = ByteBuffer.allocateDirect(request.remaining.toInt())!!.apply { private val requestByteBuffer = ByteBuffer.allocateDirect(request.remaining.toInt())!!.apply {
request.copy().readFully(this) request.peek().readFully(this)
clear() clear()
} }
......
...@@ -120,7 +120,7 @@ abstract class SustainabilityTestSuite<TEngine : ApplicationEngine, TConfigurati ...@@ -120,7 +120,7 @@ abstract class SustainabilityTestSuite<TEngine : ApplicationEngine, TConfigurati
emptyLine() emptyLine()
}.build().use { request -> }.build().use { request ->
repeat(repeatCount) { repeat(repeatCount) {
getOutputStream().writePacket(request.copy()) getOutputStream().writePacket(request.peek())
getOutputStream().write(body) getOutputStream().write(body)
getOutputStream().flush() getOutputStream().flush()
} }
......
...@@ -7,8 +7,9 @@ package io.ktor.websocket.internals ...@@ -7,8 +7,9 @@ package io.ktor.websocket.internals
import io.ktor.utils.io.core.* import io.ktor.utils.io.core.*
import kotlinx.io.* import kotlinx.io.*
@OptIn(InternalIoApi::class)
internal fun Source.endsWith(data: ByteArray): Boolean { internal fun Source.endsWith(data: ByteArray): Boolean {
copy().apply { buffer.copy().apply {
discard(remaining - data.size) discard(remaining - data.size)
return readByteArray().contentEquals(data) return readByteArray().contentEquals(data)
} }
......
...@@ -104,7 +104,7 @@ public fun String.decodeBase64Bytes(): ByteArray = buildPacket { ...@@ -104,7 +104,7 @@ public fun String.decodeBase64Bytes(): ByteArray = buildPacket {
public fun Source.decodeBase64Bytes(): Input = buildPacket { public fun Source.decodeBase64Bytes(): Input = buildPacket {
val data = ByteArray(4) val data = ByteArray(4)
while (remaining > 0) { while (!exhausted()) {
val read = readAvailable(data) val read = readAvailable(data)
val chunk = data.foldIndexed(0) { index, result, current -> val chunk = data.foldIndexed(0) { index, result, current ->
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
package io.ktor.util package io.ktor.util
import io.ktor.utils.io.* import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import io.ktor.utils.io.pool.* import io.ktor.utils.io.pool.*
import kotlinx.coroutines.* import kotlinx.coroutines.*
...@@ -64,8 +63,8 @@ public fun ByteReadChannel.copyToBoth(first: ByteWriteChannel, second: ByteWrite ...@@ -64,8 +63,8 @@ public fun ByteReadChannel.copyToBoth(first: ByteWriteChannel, second: ByteWrite
while (!isClosedForRead && (!first.isClosedForWrite || !second.isClosedForWrite)) { while (!isClosedForRead && (!first.isClosedForWrite || !second.isClosedForWrite)) {
readRemaining(CHUNK_BUFFER_SIZE).use { readRemaining(CHUNK_BUFFER_SIZE).use {
try { try {
first.writePacket(it.copy()) first.writePacket(it.peek())
second.writePacket(it.copy()) second.writePacket(it.peek())
} catch (cause: Throwable) { } catch (cause: Throwable) {
this@copyToBoth.cancel(cause) this@copyToBoth.cancel(cause)
first.close(cause) first.close(cause)
......
Поддерживает Markdown
0% или .
You are about to add 0 people to the discussion. Proceed with caution.
Сначала завершите редактирование этого сообщения!
Пожалуйста, зарегистрируйтесь или чтобы прокомментировать