Коммит 4644d060 создал по автору dmitrievanthony's avatar dmitrievanthony Зафиксировано автором Anton Dmitriev
Просмотр файлов

Add request, connect and socket timeouts to HTTP client.

владелец bebe6206
...@@ -6,15 +6,15 @@ package io.ktor.client.engine.android ...@@ -6,15 +6,15 @@ package io.ktor.client.engine.android
import io.ktor.client.call.* import io.ktor.client.call.*
import io.ktor.client.engine.* import io.ktor.client.engine.*
import io.ktor.client.features.*
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.client.utils.* import io.ktor.client.utils.*
import io.ktor.http.* import io.ktor.http.*
import io.ktor.http.content.* import io.ktor.http.content.*
import io.ktor.util.cio.*
import io.ktor.util.date.* import io.ktor.util.date.*
import kotlinx.coroutines.*
import io.ktor.utils.io.* import io.ktor.utils.io.*
import io.ktor.utils.io.jvm.javaio.* import io.ktor.utils.io.jvm.javaio.*
import kotlinx.coroutines.*
import java.io.* import java.io.*
import java.net.* import java.net.*
import javax.net.ssl.* import javax.net.ssl.*
...@@ -32,6 +32,8 @@ class AndroidClientEngine(override val config: AndroidEngineConfig) : HttpClient ...@@ -32,6 +32,8 @@ class AndroidClientEngine(override val config: AndroidEngineConfig) : HttpClient
) )
} }
override val supportedCapabilities = setOf(HttpTimeout)
override suspend fun execute(data: HttpRequestData): HttpResponseData { override suspend fun execute(data: HttpRequestData): HttpResponseData {
val callContext = callContext() val callContext = callContext()
...@@ -46,6 +48,8 @@ class AndroidClientEngine(override val config: AndroidEngineConfig) : HttpClient ...@@ -46,6 +48,8 @@ class AndroidClientEngine(override val config: AndroidEngineConfig) : HttpClient
connectTimeout = config.connectTimeout connectTimeout = config.connectTimeout
readTimeout = config.socketTimeout readTimeout = config.socketTimeout
setupTimeoutAttributes(data)
if (this is HttpsURLConnection) { if (this is HttpsURLConnection) {
config.sslManager(this) config.sslManager(this)
} }
...@@ -76,10 +80,10 @@ class AndroidClientEngine(override val config: AndroidEngineConfig) : HttpClient ...@@ -76,10 +80,10 @@ class AndroidClientEngine(override val config: AndroidEngineConfig) : HttpClient
} }
} }
connection.connect() connection.timeoutAwareConnect(data)
val statusCode = HttpStatusCode(connection.responseCode, connection.responseMessage) val statusCode = HttpStatusCode(connection.responseCode, connection.responseMessage)
val content: ByteReadChannel = connection.content(callContext) val content: ByteReadChannel = connection.content(callContext, data)
val headerFields: MutableMap<String?, MutableList<String>> = connection.headerFields val headerFields: MutableMap<String?, MutableList<String>> = connection.headerFields
val version: HttpProtocolVersion = HttpProtocolVersion.HTTP_1_1 val version: HttpProtocolVersion = HttpProtocolVersion.HTTP_1_1
...@@ -117,9 +121,3 @@ internal suspend fun OutgoingContent.writeTo( ...@@ -117,9 +121,3 @@ internal suspend fun OutgoingContent.writeTo(
else -> throw UnsupportedContentTypeException(this) else -> throw UnsupportedContentTypeException(this)
} }
} }
internal fun HttpURLConnection.content(callScope: CoroutineContext): ByteReadChannel = try {
inputStream?.buffered()
} catch (_: IOException) {
errorStream?.buffered()
}?.toByteReadChannel(context = callScope, pool = KtorDefaultPool) ?: ByteReadChannel.Empty
/*
* Copyright 2014-2019 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/
package io.ktor.client.engine.android
import io.ktor.client.features.*
import io.ktor.client.request.*
import io.ktor.util.cio.*
import io.ktor.utils.io.*
import io.ktor.utils.io.jvm.javaio.*
import kotlinx.coroutines.*
import java.io.*
import java.net.*
import kotlin.coroutines.*
/**
* Setup [HttpURLConnection] timeout configuration using [HttpTimeout.HttpTimeoutCapabilityConfiguration] as a source.
*/
internal fun HttpURLConnection.setupTimeoutAttributes(requestData: HttpRequestData) {
requestData.getCapabilityOrNull(HttpTimeout)?.let { timeoutAttributes ->
timeoutAttributes.connectTimeoutMillis?.let { connectTimeout = convertLongTimeoutToIntWithInfiniteAsZero(it) }
timeoutAttributes.socketTimeoutMillis?.let { readTimeout = convertLongTimeoutToIntWithInfiniteAsZero(it) }
setupRequestTimeoutAttributes(timeoutAttributes)
}
}
/**
* Update [HttpURLConnection] timeout configuration to support request timeout. Required to support blocking
* [HttpURLConnection.connect] call.
*/
private fun HttpURLConnection.setupRequestTimeoutAttributes(
timeoutAttributes: HttpTimeout.HttpTimeoutCapabilityConfiguration
) {
// Android performs blocking connect call, so we need to add an upper bound on the call time.
timeoutAttributes.requestTimeoutMillis?.let { requestTimeout ->
if (requestTimeout == HttpTimeout.INFINITE_TIMEOUT_MS) return@let
if (connectTimeout == 0 || connectTimeout > requestTimeout) {
connectTimeout = convertLongTimeoutToIntWithInfiniteAsZero(requestTimeout)
}
}
}
/**
* Call [HttpURLConnection.connect] catching [SocketTimeoutException] and returning [HttpSocketTimeoutException] instead
* of it. If request timeout happens earlier [HttpRequestTimeoutException] will be thrown.
*/
internal suspend fun HttpURLConnection.timeoutAwareConnect(request: HttpRequestData) {
try {
connect()
} catch (cause: Throwable) {
// Allow to throw request timeout cancellation exception instead of connect timeout exception if needed.
yield()
throw when (cause) {
is SocketTimeoutException -> HttpConnectTimeoutException(request)
else -> cause
}
}
}
/**
* Establish connection and return correspondent [ByteReadChannel].
*/
internal fun HttpURLConnection.content(callContext: CoroutineContext, request: HttpRequestData): ByteReadChannel = try {
inputStream?.buffered()
} catch (_: IOException) {
errorStream?.buffered()
}?.toByteReadChannel(
context = callContext,
pool = KtorDefaultPool
)?.let { CoroutineScope(callContext).mapEngineExceptions(it, request) } ?: ByteReadChannel.Empty
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package io.ktor.client.engine.apache package io.ktor.client.engine.apache
import io.ktor.client.engine.* import io.ktor.client.engine.*
import io.ktor.client.features.*
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.client.utils.* import io.ktor.client.utils.*
import kotlinx.coroutines.* import kotlinx.coroutines.*
...@@ -25,13 +26,15 @@ internal class ApacheEngine(override val config: ApacheEngineConfig) : HttpClien ...@@ -25,13 +26,15 @@ internal class ApacheEngine(override val config: ApacheEngineConfig) : HttpClien
) )
} }
override val supportedCapabilities = setOf(HttpTimeout)
private val engine: CloseableHttpAsyncClient = prepareClient().apply { start() } private val engine: CloseableHttpAsyncClient = prepareClient().apply { start() }
override suspend fun execute(data: HttpRequestData): HttpResponseData { override suspend fun execute(data: HttpRequestData): HttpResponseData {
val callContext = callContext() val callContext = callContext()
val apacheRequest = ApacheRequestProducer(data, config, callContext) val apacheRequest = ApacheRequestProducer(data, config, callContext)
return engine.sendRequest(apacheRequest, callContext) return engine.sendRequest(apacheRequest, callContext, data)
} }
override fun close() { override fun close() {
......
...@@ -4,21 +4,24 @@ ...@@ -4,21 +4,24 @@
package io.ktor.client.engine.apache package io.ktor.client.engine.apache
import io.ktor.client.features.*
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.http.* import io.ktor.http.*
import io.ktor.util.date.* import io.ktor.util.date.*
import kotlinx.coroutines.* import kotlinx.coroutines.*
import org.apache.http.concurrent.* import org.apache.http.concurrent.*
import org.apache.http.impl.nio.client.* import org.apache.http.impl.nio.client.*
import java.net.*
import kotlin.coroutines.* import kotlin.coroutines.*
internal suspend fun CloseableHttpAsyncClient.sendRequest( internal suspend fun CloseableHttpAsyncClient.sendRequest(
request: ApacheRequestProducer, request: ApacheRequestProducer,
callContext: CoroutineContext callContext: CoroutineContext,
requestData: HttpRequestData
): HttpResponseData = suspendCancellableCoroutine { continuation -> ): HttpResponseData = suspendCancellableCoroutine { continuation ->
val requestTime = GMTDate() val requestTime = GMTDate()
val consumer = ApacheResponseConsumerDispatching(callContext) { rawResponse, body -> val consumer = ApacheResponseConsumerDispatching(callContext, requestData) { rawResponse, body ->
val statusLine = rawResponse.statusLine val statusLine = rawResponse.statusLine
val status = HttpStatusCode(statusLine.statusCode, statusLine.reasonPhrase) val status = HttpStatusCode(statusLine.statusCode, statusLine.reasonPhrase)
...@@ -34,7 +37,15 @@ internal suspend fun CloseableHttpAsyncClient.sendRequest( ...@@ -34,7 +37,15 @@ internal suspend fun CloseableHttpAsyncClient.sendRequest(
val callback = object : FutureCallback<Unit> { val callback = object : FutureCallback<Unit> {
override fun failed(exception: Exception) { override fun failed(exception: Exception) {
callContext.cancel() val mappedCause = when {
exception is ConnectException && exception.isTimeoutException() -> HttpConnectTimeoutException(
requestData
)
exception is SocketTimeoutException -> HttpSocketTimeoutException(requestData)
else -> exception
}
callContext.cancel(CancellationException("Failed to execute request", mappedCause))
continuation.cancel(exception) continuation.cancel(exception)
} }
...@@ -46,5 +57,10 @@ internal suspend fun CloseableHttpAsyncClient.sendRequest( ...@@ -46,5 +57,10 @@ internal suspend fun CloseableHttpAsyncClient.sendRequest(
} }
} }
execute(request, consumer, callback) execute(request, consumer, callback).apply {
// We need to cancel Apache future if it's not needed anymore.
continuation.invokeOnCancellation {
cancel(true)
}
}
} }
...@@ -6,6 +6,7 @@ package io.ktor.client.engine.apache ...@@ -6,6 +6,7 @@ package io.ktor.client.engine.apache
import io.ktor.client.call.* import io.ktor.client.call.*
import io.ktor.client.engine.* import io.ktor.client.engine.*
import io.ktor.client.features.*
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.client.utils.* import io.ktor.client.utils.*
import io.ktor.http.* import io.ktor.http.*
...@@ -148,6 +149,7 @@ internal class ApacheRequestProducer( ...@@ -148,6 +149,7 @@ internal class ApacheRequestProducer(
.setConnectTimeout(connectTimeout) .setConnectTimeout(connectTimeout)
.setConnectionRequestTimeout(connectionRequestTimeout) .setConnectionRequestTimeout(connectionRequestTimeout)
.customRequest() .customRequest()
.setupTimeoutAttributes(requestData)
.build() .build()
} }
...@@ -183,9 +185,16 @@ internal class ApacheRequestProducer( ...@@ -183,9 +185,16 @@ internal class ApacheRequestProducer(
} }
private fun ByteBuffer.recycle() { private fun ByteBuffer.recycle() {
if (requestData.body is OutgoingContent.WriteChannelContent || requestData.body is OutgoingContent.ReadChannelContent) { if (requestData.body is OutgoingContent.WriteChannelContent ||
requestData.body is OutgoingContent.ReadChannelContent) {
HttpClientDefaultPool.recycle(this) HttpClientDefaultPool.recycle(this)
} }
} }
}
private fun RequestConfig.Builder.setupTimeoutAttributes(requestData: HttpRequestData): RequestConfig.Builder = also {
requestData.getCapabilityOrNull(HttpTimeout)?.let { timeoutAttributes ->
timeoutAttributes.connectTimeoutMillis?.let { setConnectTimeout(convertLongTimeoutToIntWithInfiniteAsZero(it)) }
timeoutAttributes.socketTimeoutMillis?.let { setSocketTimeout(convertLongTimeoutToIntWithInfiniteAsZero(it)) }
}
} }
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
package io.ktor.client.engine.apache package io.ktor.client.engine.apache
import io.ktor.client.features.*
import io.ktor.client.request.*
import io.ktor.utils.io.* import io.ktor.utils.io.*
import kotlinx.atomicfu.* import kotlinx.atomicfu.*
import kotlinx.coroutines.* import kotlinx.coroutines.*
...@@ -11,11 +13,13 @@ import org.apache.http.* ...@@ -11,11 +13,13 @@ import org.apache.http.*
import org.apache.http.nio.* import org.apache.http.nio.*
import org.apache.http.nio.protocol.* import org.apache.http.nio.protocol.*
import org.apache.http.protocol.* import org.apache.http.protocol.*
import java.net.*
import java.nio.* import java.nio.*
import kotlin.coroutines.* import kotlin.coroutines.*
internal class ApacheResponseConsumerDispatching( internal class ApacheResponseConsumerDispatching(
callContext: CoroutineContext, callContext: CoroutineContext,
private val requestData: HttpRequestData?,
private val block: (HttpResponse, ByteReadChannel) -> Unit private val block: (HttpResponse, ByteReadChannel) -> Unit
) : HttpAsyncResponseConsumer<Unit>, CoroutineScope { ) : HttpAsyncResponseConsumer<Unit>, CoroutineScope {
private val interestController = InterestControllerHolder() private val interestController = InterestControllerHolder()
...@@ -66,7 +70,9 @@ internal class ApacheResponseConsumerDispatching( ...@@ -66,7 +70,9 @@ internal class ApacheResponseConsumerDispatching(
// So we start its execution here, and it should suspend at [waitForDecoder] invocation. // So we start its execution here, and it should suspend at [waitForDecoder] invocation.
processLoop(Result.failure(IllegalStateException("The coroutine shouldn't be suspended at this point yet."))) processLoop(Result.failure(IllegalStateException("The coroutine shouldn't be suspended at this point yet.")))
check(decoderWaiter != null) { "Writer coroutine should suspend until decoder available." } check(!coroutineContext.isActive || decoderWaiter != null) {
"Writer coroutine should suspend until decoder available."
}
job.invokeOnCompletion { cause -> job.invokeOnCompletion { cause ->
jobCompletionCause.value = cause jobCompletionCause.value = cause
...@@ -114,9 +120,15 @@ internal class ApacheResponseConsumerDispatching( ...@@ -114,9 +120,15 @@ internal class ApacheResponseConsumerDispatching(
} while (dispatcher.hasTasks()) } while (dispatcher.hasTasks())
} }
override fun failed(ex: Exception) { override fun failed(cause: Exception) {
job.cancel(CancellationException("Failed to execute request", ex)) val mappedCause = when {
processLoop(Result.failure(ex)) cause is ConnectException && cause.isTimeoutException() -> HttpConnectTimeoutException(requestData!!)
cause is SocketTimeoutException -> HttpSocketTimeoutException(requestData!!)
else -> cause
}
job.cancel(CancellationException("Failed to execute request", mappedCause))
processLoop(Result.failure(cause))
} }
override fun cancel(): Boolean { override fun cancel(): Boolean {
......
/*
* Copyright 2014-2020 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/
package io.ktor.client.engine.apache
import java.net.*
/**
* Checks the message of the exception and identifies timeout exception by it.
*/
internal fun ConnectException.isTimeoutException() = message?.contains("Timeout connecting") ?: false
...@@ -56,7 +56,7 @@ class ConsumerTest : CoroutineScope { ...@@ -56,7 +56,7 @@ class ConsumerTest : CoroutineScope {
@Test @Test
fun testCreating() { fun testCreating() {
ApacheResponseConsumerDispatching(coroutineContext) { response, channel -> ApacheResponseConsumerDispatching(coroutineContext, null) { response, channel ->
this.receivedResponse = response this.receivedResponse = response
this.channel = channel this.channel = channel
}.responseCompleted(BasicHttpContext()) }.responseCompleted(BasicHttpContext())
...@@ -64,7 +64,7 @@ class ConsumerTest : CoroutineScope { ...@@ -64,7 +64,7 @@ class ConsumerTest : CoroutineScope {
@Test @Test
fun smokeTest() { fun smokeTest() {
val consumer = ApacheResponseConsumerDispatching(coroutineContext) { response, channel -> val consumer = ApacheResponseConsumerDispatching(coroutineContext, null) { response, channel ->
this.receivedResponse = response this.receivedResponse = response
this.channel = channel this.channel = channel
} }
...@@ -87,7 +87,7 @@ class ConsumerTest : CoroutineScope { ...@@ -87,7 +87,7 @@ class ConsumerTest : CoroutineScope {
@Test @Test
fun emptyContent() { fun emptyContent() {
val consumer = ApacheResponseConsumerDispatching(coroutineContext) { response, channel -> val consumer = ApacheResponseConsumerDispatching(coroutineContext, null) { response, channel ->
this.receivedResponse = response this.receivedResponse = response
this.channel = channel this.channel = channel
} }
...@@ -110,7 +110,7 @@ class ConsumerTest : CoroutineScope { ...@@ -110,7 +110,7 @@ class ConsumerTest : CoroutineScope {
// for some response kinds (HEAD, status NoContent as so on) consumeContent is not called // for some response kinds (HEAD, status NoContent as so on) consumeContent is not called
// so we have completed immediately after response received // so we have completed immediately after response received
val consumer = ApacheResponseConsumerDispatching(coroutineContext) { response, channel -> val consumer = ApacheResponseConsumerDispatching(coroutineContext, null) { response, channel ->
this.receivedResponse = response this.receivedResponse = response
this.channel = channel this.channel = channel
} }
...@@ -126,7 +126,7 @@ class ConsumerTest : CoroutineScope { ...@@ -126,7 +126,7 @@ class ConsumerTest : CoroutineScope {
@Test @Test
fun consumeBeforeResponseReceived() { fun consumeBeforeResponseReceived() {
val consumer = ApacheResponseConsumerDispatching(coroutineContext) { response, channel -> val consumer = ApacheResponseConsumerDispatching(coroutineContext, null) { response, channel ->
this.receivedResponse = response this.receivedResponse = response
this.channel = channel this.channel = channel
} }
...@@ -153,7 +153,7 @@ class ConsumerTest : CoroutineScope { ...@@ -153,7 +153,7 @@ class ConsumerTest : CoroutineScope {
@Test @Test
fun suspendSmokeTest() { fun suspendSmokeTest() {
val consumer = ApacheResponseConsumerDispatching(coroutineContext) { response, channel -> val consumer = ApacheResponseConsumerDispatching(coroutineContext, null) { response, channel ->
this.receivedResponse = response this.receivedResponse = response
this.channel = channel this.channel = channel
} }
...@@ -198,7 +198,7 @@ class ConsumerTest : CoroutineScope { ...@@ -198,7 +198,7 @@ class ConsumerTest : CoroutineScope {
@Test @Test
fun integrationTest() { fun integrationTest() {
val consumer = ApacheResponseConsumerDispatching(coroutineContext) { response, channel -> val consumer = ApacheResponseConsumerDispatching(coroutineContext, null) { response, channel ->
this.receivedResponse = response this.receivedResponse = response
this.channel = channel this.channel = channel
} }
...@@ -258,7 +258,7 @@ class ConsumerTest : CoroutineScope { ...@@ -258,7 +258,7 @@ class ConsumerTest : CoroutineScope {
consumerCrc.cancel() consumerCrc.cancel()
consumer.consumeContent(decoder, ioControl) consumer.consumeContent(decoder, ioControl)
} else { } else {
check(decoder.isCompleted) { "Decoder expected to be completed."} check(decoder.isCompleted) { "Decoder expected to be completed." }
} }
assertEquals(producerCrc.await(), consumerCrc.await()) assertEquals(producerCrc.await(), consumerCrc.await())
...@@ -267,7 +267,7 @@ class ConsumerTest : CoroutineScope { ...@@ -267,7 +267,7 @@ class ConsumerTest : CoroutineScope {
@Test @Test
fun lastChunkReadTest() { fun lastChunkReadTest() {
val consumer = ApacheResponseConsumerDispatching(coroutineContext) { response, channel -> val consumer = ApacheResponseConsumerDispatching(coroutineContext, null) { response, channel ->
this.receivedResponse = response this.receivedResponse = response
this.channel = channel this.channel = channel
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package io.ktor.client.engine.cio package io.ktor.client.engine.cio
import io.ktor.client.engine.* import io.ktor.client.engine.*
import io.ktor.client.features.*
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.client.utils.* import io.ktor.client.utils.*
import io.ktor.http.* import io.ktor.http.*
...@@ -19,6 +20,8 @@ internal class CIOEngine(override val config: CIOEngineConfig) : HttpClientEngin ...@@ -19,6 +20,8 @@ internal class CIOEngine(override val config: CIOEngineConfig) : HttpClientEngin
override val dispatcher by lazy { Dispatchers.clientDispatcher(config.threadsCount, "ktor-cio-dispatcher") } override val dispatcher by lazy { Dispatchers.clientDispatcher(config.threadsCount, "ktor-cio-dispatcher") }
override val supportedCapabilities = setOf(HttpTimeout)
private val endpoints = ConcurrentHashMap<String, Endpoint>() private val endpoints = ConcurrentHashMap<String, Endpoint>()
@UseExperimental(InternalCoroutinesApi::class) @UseExperimental(InternalCoroutinesApi::class)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package io.ktor.client.engine.cio package io.ktor.client.engine.cio
import io.ktor.client.engine.* import io.ktor.client.engine.*
import io.ktor.client.features.*
import io.ktor.network.tls.* import io.ktor.network.tls.*
/** /**
...@@ -67,10 +68,15 @@ class EndpointConfig { ...@@ -67,10 +68,15 @@ class EndpointConfig {
*/ */
var connectTimeout: Long = 5000 var connectTimeout: Long = 5000
/**
* Socket timeout in millis.
*/
val socketTimeout: Long = HttpTimeout.INFINITE_TIMEOUT_MS
/** /**
* Maximum number of connection attempts. * Maximum number of connection attempts.
*/ */
var connectRetryAttempts: Int = 5 var connectRetryAttempts: Int = 1
/** /**
* Allow socket to close output channel immediately on writing completion (TCP connection half close). * Allow socket to close output channel immediately on writing completion (TCP connection half close).
......
...@@ -7,6 +7,7 @@ package io.ktor.client.engine.cio ...@@ -7,6 +7,7 @@ package io.ktor.client.engine.cio
import io.ktor.network.selector.* import io.ktor.network.selector.*
import io.ktor.network.sockets.* import io.ktor.network.sockets.*
import io.ktor.network.sockets.Socket import io.ktor.network.sockets.Socket
import io.ktor.network.sockets.SocketOptions
import kotlinx.coroutines.sync.* import kotlinx.coroutines.sync.*
import java.net.* import java.net.*
...@@ -16,10 +17,13 @@ internal class ConnectionFactory( ...@@ -16,10 +17,13 @@ internal class ConnectionFactory(
) { ) {
private val semaphore = Semaphore(maxConnectionsCount) private val semaphore = Semaphore(maxConnectionsCount)
suspend fun connect(address: InetSocketAddress): Socket { suspend fun connect(
address: InetSocketAddress,
configuration: SocketOptions.TCPClientSocketOptions.() -> Unit = {}
): Socket {
semaphore.acquire() semaphore.acquire()
return try { return try {
aSocket(selector).tcpNoDelay().tcp().connect(address) aSocket(selector).tcpNoDelay().tcp().connect(address, configuration)
} catch (cause: Throwable) { } catch (cause: Throwable) {
// a failure or cancellation // a failure or cancellation
semaphore.release() semaphore.release()
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
package io.ktor.client.engine.cio package io.ktor.client.engine.cio
import io.ktor.client.features.*
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.network.sockets.* import io.ktor.network.sockets.*
import io.ktor.network.sockets.Socket import io.ktor.network.sockets.Socket
...@@ -89,23 +90,12 @@ internal class Endpoint( ...@@ -89,23 +90,12 @@ internal class Endpoint(
): Job = launch(task.context + CoroutineName("DedicatedRequest")) { ): Job = launch(task.context + CoroutineName("DedicatedRequest")) {
val (request, response, callContext) = task val (request, response, callContext) = task
try { try {
val connection = connect() val connection = connect(request)
val input = connection.openReadChannel() val input = this@Endpoint.mapEngineExceptions(connection.openReadChannel(), task.request)
val output = connection.openWriteChannel() val output = this@Endpoint.mapEngineExceptions(connection.openWriteChannel(), task.request)
val requestTime = GMTDate() val requestTime = GMTDate()
val timeout = config.requestTimeout
val responseData = if (timeout == 0L) {
request.write(output.wrap(callContext, config.endpoint.allowHalfClose), callContext, overProxy)
readResponse(requestTime, request, input, output, callContext)
} else {
withTimeout(timeout) {
request.write(output.wrap(callContext, config.endpoint.allowHalfClose), callContext, overProxy)
readResponse(requestTime, request, input, output, callContext)
}
}
callContext[Job]!!.invokeOnCompletion { cause -> callContext[Job]!!.invokeOnCompletion { cause ->
try { try {
input.cancel(cause) input.cancel(cause)
...@@ -116,9 +106,25 @@ internal class Endpoint( ...@@ -116,9 +106,25 @@ internal class Endpoint(
} }
} }
val timeout = config.requestTimeout
val writeRequestAndReadResponse: suspend CoroutineScope.() -> HttpResponseData = {
request.write(output.wrap(callContext, config.endpoint.allowHalfClose), callContext, overProxy)
readResponse(requestTime, request, input, output, callContext)
}
val responseData = if (timeout == HttpTimeout.INFINITE_TIMEOUT_MS) {
writeRequestAndReadResponse()
} else {
withTimeout(timeout, writeRequestAndReadResponse)
}
response.resume(responseData) response.resume(responseData)
} catch (cause: Throwable) { } catch (cause: Throwable) {
response.resumeWithException(cause) val mappedException = when (cause.rootCause) {
is SocketTimeoutException -> HttpSocketTimeoutException(task.request)
else -> cause
}
response.resumeWithException(mappedException)
} }
} }
...@@ -136,9 +142,10 @@ internal class Endpoint( ...@@ -136,9 +142,10 @@ internal class Endpoint(
pipeline.pipelineContext.invokeOnCompletion { releaseConnection() } pipeline.pipelineContext.invokeOnCompletion { releaseConnection() }
} }
private suspend fun connect(): Socket { private suspend fun connect(requestData: HttpRequestData? = null): Socket {
val retryAttempts = config.endpoint.connectRetryAttempts val retryAttempts = config.endpoint.connectRetryAttempts
val connectTimeout = config.endpoint.connectTimeout val (connectTimeout, socketTimeout) = retrieveTimeouts(requestData)
var timeoutFails = 0
connections.incrementAndGet() connections.incrementAndGet()
...@@ -148,8 +155,23 @@ internal class Endpoint( ...@@ -148,8 +155,23 @@ internal class Endpoint(
if (address.isUnresolved) throw UnresolvedAddressException() if (address.isUnresolved) throw UnresolvedAddressException()
val connection = withTimeoutOrNull(connectTimeout) { connectionFactory.connect(address) } val connect: suspend CoroutineScope.() -> Socket = {
?: return@repeat connectionFactory.connect(address) {
this.socketTimeout = socketTimeout
}
}
val connection = when (connectTimeout) {
HttpTimeout.INFINITE_TIMEOUT_MS -> connect()
else -> {
val connection = withTimeoutOrNull(connectTimeout, connect)
if (connection == null) {
timeoutFails++
return@repeat
}
connection
}
}
if (!secure) return@connect connection if (!secure) return@connect connection
...@@ -179,9 +201,30 @@ internal class Endpoint( ...@@ -179,9 +201,30 @@ internal class Endpoint(
} }
connections.decrementAndGet() connections.decrementAndGet()
throw FailToConnectException()
throw getTimeoutException(retryAttempts, timeoutFails, requestData!!)
} }
/**
* Defines exact type of exception based on [retryAttempts] and [timeoutFails].
*/
private fun getTimeoutException(retryAttempts: Int, timeoutFails: Int, request: HttpRequestData) =
when (timeoutFails) {
retryAttempts -> HttpConnectTimeoutException(request)
else -> FailToConnectException()
}
/**
* Take timeout attributes from [config] and [HttpTimeout.HttpTimeoutCapabilityConfiguration] and returns pair of
* connect timeout and socket timeout to be applied.
*/
private fun retrieveTimeouts(requestData: HttpRequestData?): Pair<Long, Long> =
requestData?.getCapabilityOrNull(HttpTimeout)?.let { timeoutAttributes ->
val socketTimeout = timeoutAttributes.socketTimeoutMillis ?: config.endpoint.socketTimeout
val connectTimeout = timeoutAttributes.connectTimeoutMillis ?: config.endpoint.connectTimeout
return connectTimeout to socketTimeout
} ?: config.endpoint.connectTimeout to config.endpoint.socketTimeout
private fun releaseConnection() { private fun releaseConnection() {
connectionFactory.release() connectionFactory.release()
connections.decrementAndGet() connections.decrementAndGet()
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
package io.ktor.client.engine.cio package io.ktor.client.engine.cio
import io.ktor.client.features.*
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.http.* import io.ktor.http.*
import io.ktor.util.date.* import io.ktor.util.date.*
...@@ -18,10 +19,17 @@ internal data class RequestTask( ...@@ -18,10 +19,17 @@ internal data class RequestTask(
internal fun RequestTask.requiresDedicatedConnection(): Boolean = listOf(request.headers, request.body.headers).any { internal fun RequestTask.requiresDedicatedConnection(): Boolean = listOf(request.headers, request.body.headers).any {
it[HttpHeaders.Connection] == "close" || it.contains(HttpHeaders.Upgrade) it[HttpHeaders.Connection] == "close" || it.contains(HttpHeaders.Upgrade)
} || request.method !in listOf(HttpMethod.Get, HttpMethod.Head) } || request.method !in listOf(HttpMethod.Get, HttpMethod.Head) || containsCustomTimeouts()
internal data class ConnectionResponseTask( internal data class ConnectionResponseTask(
val requestTime: GMTDate, val requestTime: GMTDate,
val task: RequestTask val task: RequestTask
) )
/**
* Return true if request task contains timeout attributes specified using [HttpTimeout] feature.
*/
private fun RequestTask.containsCustomTimeouts() =
request.getCapabilityOrNull(HttpTimeout)?.let {
it.connectTimeoutMillis != null || it.socketTimeoutMillis != null
} == true
...@@ -125,6 +125,8 @@ class HttpClient( ...@@ -125,6 +125,8 @@ class HttpClient(
} }
with(userConfig) { with(userConfig) {
config.install(HttpRequestLifecycle)
if (useDefaultTransformers) { if (useDefaultTransformers) {
config.install(HttpPlainText) config.install(HttpPlainText)
config.install("DefaultTransformers") { defaultTransformers() } config.install("DefaultTransformers") { defaultTransformers() }
...@@ -161,6 +163,13 @@ class HttpClient( ...@@ -161,6 +163,13 @@ class HttpClient(
suspend fun execute(builder: HttpRequestBuilder): HttpClientCall = suspend fun execute(builder: HttpRequestBuilder): HttpClientCall =
requestPipeline.execute(builder, builder.body) as HttpClientCall requestPipeline.execute(builder, builder.body) as HttpClientCall
/**
* Check if the specified [capability] is supported by this client.
*/
fun isSupported(capability: HttpClientEngineCapability<*>): Boolean {
return engine.supportedCapabilities.contains(capability)
}
/** /**
* Returns a new [HttpClient] copying this client configuration, * Returns a new [HttpClient] copying this client configuration,
* and additionally configured by the [block] parameter. * and additionally configured by the [block] parameter.
......
...@@ -9,6 +9,7 @@ import io.ktor.client.call.* ...@@ -9,6 +9,7 @@ import io.ktor.client.call.*
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.http.* import io.ktor.http.*
import io.ktor.util.* import io.ktor.util.*
import io.ktor.util.pipeline.*
import io.ktor.utils.io.core.* import io.ktor.utils.io.core.*
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlin.coroutines.* import kotlin.coroutines.*
...@@ -31,6 +32,13 @@ interface HttpClientEngine : CoroutineScope, Closeable { ...@@ -31,6 +32,13 @@ interface HttpClientEngine : CoroutineScope, Closeable {
*/ */
val config: HttpClientEngineConfig val config: HttpClientEngineConfig
/**
* Set of supported engine extensions.
*/
@KtorExperimentalAPI
val supportedCapabilities: Set<HttpClientEngineCapability<*>>
get() = emptySet()
private val closed: Boolean private val closed: Boolean
get() = !(coroutineContext[Job]?.isActive ?: false) get() = !(coroutineContext[Job]?.isActive ?: false)
...@@ -52,16 +60,11 @@ interface HttpClientEngine : CoroutineScope, Closeable { ...@@ -52,16 +60,11 @@ interface HttpClientEngine : CoroutineScope, Closeable {
}.build() }.build()
validateHeaders(requestData) validateHeaders(requestData)
checkExtensions(requestData)
val responseData = executeWithinCallContext(requestData) val responseData = executeWithinCallContext(requestData)
val call = HttpClientCall(client, requestData, responseData) val call = HttpClientCall(client, requestData, responseData)
responseData.callContext[Job]!!.invokeOnCompletion { cause ->
@Suppress("UNCHECKED_CAST")
val childContext = requestData.executionContext as CompletableJob
if (cause == null) childContext.complete() else childContext.completeExceptionally(cause)
}
proceedWith(call) proceedWith(call)
} }
} }
...@@ -81,6 +84,12 @@ interface HttpClientEngine : CoroutineScope, Closeable { ...@@ -81,6 +84,12 @@ interface HttpClientEngine : CoroutineScope, Closeable {
}.await() }.await()
} }
private fun checkExtensions(requestData: HttpRequestData) {
for (requestedExtension in requestData.requiredCapabilities) {
require(supportedCapabilities.contains(requestedExtension)) { "Engine doesn't support $requestedExtension" }
}
}
/** /**
* Create call context with the specified [parentJob] to be used during call execution in the engine. Call context * Create call context with the specified [parentJob] to be used during call execution in the engine. Call context
* inherits [coroutineContext], but overrides job and coroutine name so that call job's parent is [parentJob] and * inherits [coroutineContext], but overrides job and coroutine name so that call job's parent is [parentJob] and
...@@ -96,7 +105,6 @@ interface HttpClientEngine : CoroutineScope, Closeable { ...@@ -96,7 +105,6 @@ interface HttpClientEngine : CoroutineScope, Closeable {
} }
} }
/** /**
* Factory of [HttpClientEngine] with a specific [T] of [HttpClientEngineConfig]. * Factory of [HttpClientEngine] with a specific [T] of [HttpClientEngineConfig].
*/ */
......
/*
* Copyright 2014-2019 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/
package io.ktor.client.engine
import io.ktor.client.features.*
import io.ktor.util.*
import kotlin.native.concurrent.*
/**
* Key required to access capabilities.
*/
@KtorExperimentalAPI
@SharedImmutable
internal val ENGINE_CAPABILITIES_KEY = AttributeKey<MutableMap<HttpClientEngineCapability<*>, Any>>("EngineCapabilities")
/**
* Default capabilities expected to be supported by engine.
*/
@KtorExperimentalAPI
@SharedImmutable
val DEFAULT_CAPABILITIES = setOf(HttpTimeout)
/**
* Capability required by request to be supported by [HttpClientEngine] with [T] representing type of the capability
* configuration.
*/
interface HttpClientEngineCapability<T>
...@@ -42,16 +42,17 @@ class HttpRedirect { ...@@ -42,16 +42,17 @@ class HttpRedirect {
override fun prepare(block: HttpRedirect.() -> Unit): HttpRedirect = HttpRedirect().apply(block) override fun prepare(block: HttpRedirect.() -> Unit): HttpRedirect = HttpRedirect().apply(block)
override fun install(feature: HttpRedirect, scope: HttpClient) { override fun install(feature: HttpRedirect, scope: HttpClient) {
scope.feature(HttpSend)!!.intercept { origin -> scope.feature(HttpSend)!!.intercept { origin, context ->
if (feature.checkHttpMethod && origin.request.method !in ALLOWED_FOR_REDIRECT) { if (feature.checkHttpMethod && origin.request.method !in ALLOWED_FOR_REDIRECT) {
return@intercept origin return@intercept origin
} }
handleCall(origin, feature.allowHttpsDowngrade) handleCall(context, origin, feature.allowHttpsDowngrade)
} }
} }
private suspend fun Sender.handleCall( private suspend fun Sender.handleCall(
context: HttpRequestBuilder,
origin: HttpClientCall, origin: HttpClientCall,
allowHttpsDowngrade: Boolean allowHttpsDowngrade: Boolean
): HttpClientCall { ): HttpClientCall {
...@@ -64,7 +65,7 @@ class HttpRedirect { ...@@ -64,7 +65,7 @@ class HttpRedirect {
val location = call.response.headers[HttpHeaders.Location] val location = call.response.headers[HttpHeaders.Location]
val requestBuilder = HttpRequestBuilder().apply { val requestBuilder = HttpRequestBuilder().apply {
takeFrom(origin.request) takeFrom(context)
url.parameters.clear() url.parameters.clear()
location?.let { url.takeFrom(it) } location?.let { url.takeFrom(it) }
......
/*
* Copyright 2014-2019 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/
package io.ktor.client.features
import io.ktor.client.*
import io.ktor.client.request.*
import io.ktor.util.*
import io.ktor.util.pipeline.*
import kotlinx.coroutines.*
/**
* Client HTTP feature that sets up [HttpRequestBuilder.executionContext] and completes it when the pipeline is fully
* processed.
*/
internal class HttpRequestLifecycle {
/**
* Companion object for feature installation.
*/
companion object Feature : HttpClientFeature<Unit, HttpRequestLifecycle> {
override val key: AttributeKey<HttpRequestLifecycle> = AttributeKey("RequestLifecycle")
override fun prepare(block: Unit.() -> Unit): HttpRequestLifecycle = HttpRequestLifecycle()
override fun install(feature: HttpRequestLifecycle, scope: HttpClient) {
scope.requestPipeline.intercept(HttpRequestPipeline.Before) {
val executionContext = Job(context.executionContext)
attachToClientEngineJob(executionContext)
try {
context.executionContext = executionContext
proceed()
} catch (cause: Throwable) {
executionContext.completeExceptionally(cause)
throw cause
} finally {
executionContext.complete()
}
}
}
}
}
/**
* Attach client engine job
*/
private fun PipelineContext<*, HttpRequestBuilder>.attachToClientEngineJob(clientEngineJob: Job) {
val handler = clientEngineJob.invokeOnCompletion { cause ->
if (cause != null) {
context.executionContext.cancel("Engine failed", cause)
} else {
(context.executionContext[Job] as CompletableJob).complete()
}
}
context.executionContext[Job]!!.invokeOnCompletion {
handler.dispose()
}
}
...@@ -14,7 +14,12 @@ import kotlinx.coroutines.* ...@@ -14,7 +14,12 @@ import kotlinx.coroutines.*
/** /**
* HttpSend pipeline interceptor function * HttpSend pipeline interceptor function
*/ */
typealias HttpSendInterceptor = suspend Sender.(HttpClientCall) -> HttpClientCall typealias HttpSendInterceptor = suspend Sender.(HttpClientCall, HttpRequestBuilder) -> HttpClientCall
/**
* HttpSend pipeline interceptor function backward compatible with previous implementation.
*/
typealias HttpSendInterceptorBackwardCompatible = suspend Sender.(HttpClientCall) -> HttpClientCall
/** /**
* This interface represents a request send pipeline interceptor chain * This interface represents a request send pipeline interceptor chain
...@@ -44,6 +49,16 @@ class HttpSend( ...@@ -44,6 +49,16 @@ class HttpSend(
interceptors += block interceptors += block
} }
/**
* Install send pipeline starter interceptor (backward compatible function).
*/
@Deprecated("Intercept with one parameter is deprecated, use both call and request builder as parameters.")
fun intercept(block: HttpSendInterceptorBackwardCompatible) {
interceptors += { call, _ ->
block(call)
}
}
/** /**
* Feature installation object * Feature installation object
*/ */
...@@ -68,7 +83,7 @@ class HttpSend( ...@@ -68,7 +83,7 @@ class HttpSend(
callChanged = false callChanged = false
passInterceptors@ for (interceptor in feature.interceptors) { passInterceptors@ for (interceptor in feature.interceptors) {
val transformed = interceptor(sender, currentCall) val transformed = interceptor(sender, currentCall, context)
if (transformed === currentCall) continue@passInterceptors if (transformed === currentCall) continue@passInterceptors
currentCall = transformed currentCall = transformed
......
/*
* Copyright 2014-2019 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/
package io.ktor.client.features
import io.ktor.client.*
import io.ktor.client.engine.*
import io.ktor.client.request.*
import io.ktor.util.*
import io.ktor.utils.io.errors.*
import kotlinx.coroutines.*
import kotlin.native.concurrent.*
/**
* Client HTTP timeout feature. There are no default values, so default timeouts will be taken from engine configuration
* or considered as infinite time if engine doesn't provide them.
*/
class HttpTimeout(
private val requestTimeoutMillis: Long?,
private val connectTimeoutMillis: Long?,
private val socketTimeoutMillis: Long?
) {
/**
* [HttpTimeout] extension configuration that is used during installation.
*/
class HttpTimeoutCapabilityConfiguration {
/**
* Creates a new instance of [HttpTimeoutCapabilityConfiguration].
*/
@InternalAPI
constructor(
requestTimeoutMillis: Long? = null,
connectTimeoutMillis: Long? = null,
socketTimeoutMillis: Long? = null
) {
this.requestTimeoutMillis = requestTimeoutMillis
this.connectTimeoutMillis = connectTimeoutMillis
this.socketTimeoutMillis = socketTimeoutMillis
}
/**
* Request timeout in milliseconds.
*/
var requestTimeoutMillis: Long?
set(value) {
field = checkTimeoutValue(value)
}
/**
* Connect timeout in milliseconds.
*/
var connectTimeoutMillis: Long?
set(value) {
field = checkTimeoutValue(value)
}
/**
* Socket timeout (read and write) in milliseconds.
*/
var socketTimeoutMillis: Long?
set(value) {
field = checkTimeoutValue(value)
}
internal fun build(): HttpTimeout = HttpTimeout(requestTimeoutMillis, connectTimeoutMillis, socketTimeoutMillis)
private fun checkTimeoutValue(value: Long?): Long? {
require(value == null || value > 0) {
"Only positive timeout values are allowed, for infinite timeout use HttpTimeout.INFINITE_TIMEOUT_MS"
}
return value
}
companion object {
@SharedImmutable
val key = AttributeKey<HttpTimeoutCapabilityConfiguration>("TimeoutConfiguration")
}
}
/**
* Utils method that return true if at least one timeout is configured (has not null value).
*/
private fun hasNotNullTimeouts() =
requestTimeoutMillis != null || connectTimeoutMillis != null || socketTimeoutMillis != null
/**
* Companion object for feature installation.
*/
companion object Feature : HttpClientFeature<HttpTimeoutCapabilityConfiguration, HttpTimeout>,
HttpClientEngineCapability<HttpTimeoutCapabilityConfiguration> {
override val key: AttributeKey<HttpTimeout> = AttributeKey("TimeoutFeature")
/**
* Infinite timeout in milliseconds.
*/
@SharedImmutable
const val INFINITE_TIMEOUT_MS = Long.MAX_VALUE
override fun prepare(block: HttpTimeoutCapabilityConfiguration.() -> Unit): HttpTimeout =
HttpTimeoutCapabilityConfiguration().apply(block).build()
override fun install(feature: HttpTimeout, scope: HttpClient) {
scope.requestPipeline.intercept(HttpRequestPipeline.Before) {
var configuration = context.getCapabilityOrNull(HttpTimeout)
if (configuration == null && feature.hasNotNullTimeouts()) {
configuration = HttpTimeoutCapabilityConfiguration()
context.setCapability(HttpTimeout, configuration)
}
configuration?.apply {
connectTimeoutMillis = connectTimeoutMillis ?: feature.connectTimeoutMillis
socketTimeoutMillis = socketTimeoutMillis ?: feature.socketTimeoutMillis
requestTimeoutMillis = requestTimeoutMillis ?: feature.requestTimeoutMillis
val requestTimeout = requestTimeoutMillis ?: feature.requestTimeoutMillis
if (requestTimeout == null || requestTimeout == INFINITE_TIMEOUT_MS) return@apply
val executionContext = context.executionContext
val killer = scope.launch {
delay(requestTimeout)
executionContext.cancel(HttpRequestTimeoutException(context))
}
context.executionContext.invokeOnCompletion {
killer.cancel()
}
}
}
}
}
}
/**
* Adds timeout boundaries to the request. Requires [HttpTimeout] feature to be installed.
*/
fun HttpRequestBuilder.timeout(block: HttpTimeout.HttpTimeoutCapabilityConfiguration.() -> Unit) =
setCapability(HttpTimeout, HttpTimeout.HttpTimeoutCapabilityConfiguration().apply(block))
/**
* This exception is thrown in case request timeout exceeded.
*/
class HttpRequestTimeoutException(request: HttpRequestBuilder) :
CancellationException(
"Request timeout has been expired [url=${request.url}, request_timeout=${request.getCapabilityOrNull(
HttpTimeout
)?.requestTimeoutMillis ?: "unknown"} ms]"
)
/**
* This exception is thrown in case connect timeout exceeded.
*/
expect class HttpConnectTimeoutException(request: HttpRequestData) : IOException
/**
* This exception is thrown in case socket timeout (read or write) exceeded.
*/
expect class HttpSocketTimeoutException(request: HttpRequestData) : IOException
/**
* Convert long timeout in milliseconds to int value. To do that we need to consider [HttpTimeout.INFINITE_TIMEOUT_MS]
* as zero and convert timeout value to [Int].
*/
@InternalAPI
fun convertLongTimeoutToIntWithInfiniteAsZero(timeout: Long): Int = when {
timeout == HttpTimeout.INFINITE_TIMEOUT_MS -> 0
timeout < Int.MIN_VALUE -> Int.MIN_VALUE
timeout > Int.MAX_VALUE -> Int.MAX_VALUE
else -> timeout.toInt()
}
@InternalAPI
fun convertLongTimeoutToLongWithInfiniteAsZero(timeout: Long): Long = when (timeout) {
HttpTimeout.INFINITE_TIMEOUT_MS -> 0L
else -> timeout
}
Поддерживает Markdown
0% или .
You are about to add 0 people to the discussion. Proceed with caution.
Сначала завершите редактирование этого сообщения!
Пожалуйста, зарегистрируйтесь или чтобы прокомментировать