Коммит 7a7144c9 создал по автору Leonid Stashevsky's avatar Leonid Stashevsky Зафиксировано автором Leonid Stashevsky
Просмотр файлов

Add client certificates support

        Closes #641
владелец 712e6119
...@@ -3,7 +3,6 @@ package io.ktor.client.engine.android ...@@ -3,7 +3,6 @@ package io.ktor.client.engine.android
import io.ktor.client.tests.* import io.ktor.client.tests.*
import org.junit.* import org.junit.*
@Ignore @Ignore
class AndroidCacheTest : CacheTest(Android) class AndroidCacheTest : CacheTest(Android)
......
...@@ -2,13 +2,10 @@ package io.ktor.client.engine.cio ...@@ -2,13 +2,10 @@ package io.ktor.client.engine.cio
import io.ktor.client.engine.* import io.ktor.client.engine.*
import io.ktor.network.tls.* import io.ktor.network.tls.*
import io.ktor.util.*
import java.security.* import java.security.*
import java.security.cert.*
import javax.net.ssl.* import javax.net.ssl.*
private val DEFAULT_RANDOM: String =
SecureRandom().algorithm.takeIf { it != "unknown" } ?: "NativePRNGNonBlocking"
/** /**
* Configuration for [CIO] client engine. * Configuration for [CIO] client engine.
*/ */
...@@ -20,7 +17,7 @@ class CIOEngineConfig : HttpClientEngineConfig() { ...@@ -20,7 +17,7 @@ class CIOEngineConfig : HttpClientEngineConfig() {
/** /**
* [https] settings. * [https] settings.
*/ */
val https: HttpsConfig = HttpsConfig() val https: TLSConfigBuilder = TLSConfigBuilder()
/** /**
* Maximum allowed connections count. * Maximum allowed connections count.
...@@ -30,7 +27,7 @@ class CIOEngineConfig : HttpClientEngineConfig() { ...@@ -30,7 +27,7 @@ class CIOEngineConfig : HttpClientEngineConfig() {
/** /**
* [https] settings. * [https] settings.
*/ */
fun https(block: HttpsConfig.() -> Unit): HttpsConfig = https.apply(block) fun https(block: TLSConfigBuilder.() -> Unit): TLSConfigBuilder = https.apply(block)
} }
/** /**
...@@ -62,25 +59,3 @@ class EndpointConfig { ...@@ -62,25 +59,3 @@ class EndpointConfig {
*/ */
var connectRetryAttempts: Int = 5 var connectRetryAttempts: Int = 5
} }
/**
* Https settings.
*/
class HttpsConfig {
/**
* Custom [X509TrustManager] to verify server authority.
*
* Use system by default.
*/
var trustManager: X509TrustManager? = null
/**
* Random nonce generation algorithm.
*/
var randomAlgorithm: String = DEFAULT_RANDOM
/**
* List of allowed [CipherSuite]s.
*/
var cipherSuites: List<CipherSuite> = CIOCipherSuites.SupportedSuites
}
...@@ -69,51 +69,50 @@ internal class ConnectionPipeline( ...@@ -69,51 +69,50 @@ internal class ConnectionPipeline(
val rawResponse = parseResponse(networkInput) val rawResponse = parseResponse(networkInput)
?: throw EOFException("Failed to parse HTTP response: unexpected EOF") ?: throw EOFException("Failed to parse HTTP response: unexpected EOF")
val callContext = task.context try {
val callContext = task.context
val method = task.request.method
val contentLength = rawResponse.headers[HttpHeaders.ContentLength]?.toString()?.toLong() ?: -1L val method = task.request.method
val transferEncoding = rawResponse.headers[HttpHeaders.TransferEncoding] val contentLength = rawResponse.headers[HttpHeaders.ContentLength]?.toString()?.toLong() ?: -1L
val chunked = transferEncoding == "chunked" val transferEncoding = rawResponse.headers[HttpHeaders.TransferEncoding]
val connectionType = ConnectionOptions.parse(rawResponse.headers[HttpHeaders.Connection]) val chunked = transferEncoding == "chunked"
val headers = CIOHeaders(rawResponse.headers) val connectionType = ConnectionOptions.parse(rawResponse.headers[HttpHeaders.Connection])
val headers = CIOHeaders(rawResponse.headers)
callContext[Job]?.invokeOnCompletion {
rawResponse.release() shouldClose = (connectionType == ConnectionOptions.Close)
}
val hasBody = (contentLength > 0 || chunked) && method != HttpMethod.Head
shouldClose = (connectionType == ConnectionOptions.Close) val responseChannel = if (hasBody) ByteChannel() else null
val hasBody = (contentLength > 0 || chunked) && method != HttpMethod.Head var skipTask: Job? = null
val responseChannel = if (hasBody) ByteChannel() else null val body: ByteReadChannel = if (responseChannel != null) {
val proxyChannel = ByteChannel()
var skipTask: Job? = null skipTask = skipCancels(responseChannel, proxyChannel)
val body: ByteReadChannel = if (responseChannel != null) { proxyChannel
val proxyChannel = ByteChannel() } else ByteReadChannel.Empty
skipTask = skipCancels(responseChannel, proxyChannel)
proxyChannel val response = CIOHttpResponse(
} else ByteReadChannel.Empty task.request, headers, requestTime,
body,
val response = CIOHttpResponse( rawResponse,
task.request, headers, requestTime, coroutineContext = callContext
body,
rawResponse,
coroutineContext = callContext
)
task.response.complete(response)
responseChannel?.use {
parseHttpBody(
contentLength,
transferEncoding,
connectionType,
networkInput,
this
) )
}
skipTask?.join() task.response.complete(response)
responseChannel?.use {
parseHttpBody(
contentLength,
transferEncoding,
connectionType,
networkInput,
this
)
}
skipTask?.join()
} finally {
rawResponse.headers.release()
}
} catch (cause: Throwable) { } catch (cause: Throwable) {
task.response.completeExceptionally(cause) task.response.completeExceptionally(cause)
} }
......
...@@ -183,13 +183,13 @@ internal class Endpoint( ...@@ -183,13 +183,13 @@ internal class Endpoint(
try { try {
with(config.https) { with(config.https) {
return@connect connection.tls( return@connect connection.tls(coroutineContext) {
coroutineContext, trustManager = this@with.trustManager
trustManager, random = this@with.random
randomAlgorithm, cipherSuites = this@with.cipherSuites
cipherSuites, serverName = this@with.serverName ?: address.hostName
address.hostName certificates += this@with.certificates
) }
} }
} catch (cause: Throwable) { } catch (cause: Throwable) {
try { try {
......
package io.ktor.client.engine.cio package io.ktor.client.engine.cio
import io.ktor.application.* import io.ktor.application.*
import io.ktor.client.*
import io.ktor.client.engine.*
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.client.response.* import io.ktor.client.response.*
import io.ktor.client.tests.utils.* import io.ktor.client.tests.utils.*
...@@ -115,14 +113,18 @@ class CIOHttpsTest : TestWithKtor() { ...@@ -115,14 +113,18 @@ class CIOHttpsTest : TestWithKtor() {
} }
@Test @Test
fun customDomainsTest() = clientTest(CIO) { fun customDomainsTest(): Unit = clientTest(CIO) {
val domains = listOf( val domains = listOf(
"https://google.com", "https://google.com",
"https://facebook.com", "https://facebook.com",
// "https://elster.de", "https://elster.de",
"https://freenode.net" "https://freenode.net"
) )
config {
expectSuccess = false
}
test { client -> test { client ->
domains.forEach { url -> domains.forEach { url ->
client.get<String>(url) client.get<String>(url)
...@@ -131,7 +133,7 @@ class CIOHttpsTest : TestWithKtor() { ...@@ -131,7 +133,7 @@ class CIOHttpsTest : TestWithKtor() {
} }
@Test @Test
fun repeatRequestTest() = clientTest(CIO) { fun repeatRequestTest(): Unit = clientTest(CIO) {
config { config {
followRedirects = false followRedirects = false
......
...@@ -11,5 +11,5 @@ class JettyEngineConfig : HttpClientEngineConfig() { ...@@ -11,5 +11,5 @@ class JettyEngineConfig : HttpClientEngineConfig() {
/** /**
* A Jetty's [SslContextFactory]. By default it trusts all the certificates. * A Jetty's [SslContextFactory]. By default it trusts all the certificates.
*/ */
var sslContextFactory: SslContextFactory = SslContextFactory(true) var sslContextFactory: SslContextFactory = SslContextFactory()
} }
...@@ -171,6 +171,8 @@ private fun parseUri(text: CharSequence, range: MutableRange): CharSequence { ...@@ -171,6 +171,8 @@ private fun parseUri(text: CharSequence, range: MutableRange): CharSequence {
private val versions = AsciiCharTree.build(listOf("HTTP/1.0", "HTTP/1.1")) private val versions = AsciiCharTree.build(listOf("HTTP/1.0", "HTTP/1.1"))
private fun parseVersion(text: CharSequence, range: MutableRange): CharSequence { private fun parseVersion(text: CharSequence, range: MutableRange): CharSequence {
skipSpaces(text, range) skipSpaces(text, range)
check(range.start < range.end) { "Failed to parse version: $text" }
val exact = versions.search(text, range.start, range.end) { ch, _ -> ch == ' ' }.singleOrNull() val exact = versions.search(text, range.start, range.end) { ch, _ -> ch == ' ' }.singleOrNull()
if (exact != null) { if (exact != null) {
range.start += exact.length range.start += exact.length
......
package io.ktor.network.tls
import io.ktor.network.tls.extensions.*
import java.security.*
internal class CertificateInfo(
val types: ByteArray,
val hashAndSign: Array<HashAndSign>,
val authorities: Set<Principal>
)
package io.ktor.network.tls
import java.security.*
internal data class EncryptionInfo(
val serverPublic: PublicKey,
val clientPublic: PublicKey,
val clientPrivate: PrivateKey
)
...@@ -5,7 +5,7 @@ import kotlinx.io.core.* ...@@ -5,7 +5,7 @@ import kotlinx.io.core.*
@Suppress("KDocMissingDocumentation") @Suppress("KDocMissingDocumentation")
@InternalAPI @InternalAPI
class TLSRecord( internal class TLSRecord(
val type: TLSRecordType = TLSRecordType.Handshake, val type: TLSRecordType = TLSRecordType.Handshake,
val version: TLSVersion = TLSVersion.TLS12, val version: TLSVersion = TLSVersion.TLS12,
val packet: ByteReadPacket = ByteReadPacket.Empty val packet: ByteReadPacket = ByteReadPacket.Empty
...@@ -13,7 +13,7 @@ class TLSRecord( ...@@ -13,7 +13,7 @@ class TLSRecord(
@Suppress("KDocMissingDocumentation") @Suppress("KDocMissingDocumentation")
@InternalAPI @InternalAPI
class TLSHandshake { internal class TLSHandshake {
var type: TLSHandshakeType = TLSHandshakeType.HelloRequest var type: TLSHandshakeType = TLSHandshakeType.HelloRequest
var packet = ByteReadPacket.Empty var packet = ByteReadPacket.Empty
} }
...@@ -128,7 +128,7 @@ private suspend fun ByteReadChannel.readTLSVersion() = ...@@ -128,7 +128,7 @@ private suspend fun ByteReadChannel.readTLSVersion() =
private fun ByteReadPacket.readTLSVersion() = private fun ByteReadPacket.readTLSVersion() =
TLSVersion.byCode(readShort().toInt() and 0xffff) TLSVersion.byCode(readShort().toInt() and 0xffff)
private fun ByteReadPacket.readTripleByteLength(): Int = (readByte().toInt() and 0xff shl 16) or internal fun ByteReadPacket.readTripleByteLength(): Int = (readByte().toInt() and 0xff shl 16) or
(readShort().toInt() and 0xffff) (readShort().toInt() and 0xffff)
internal suspend fun ByteReadChannel.readShortCompatible(): Int { internal suspend fun ByteReadChannel.readShortCompatible(): Int {
......
...@@ -4,6 +4,7 @@ import io.ktor.network.tls.extensions.* ...@@ -4,6 +4,7 @@ import io.ktor.network.tls.extensions.*
import kotlinx.coroutines.io.* import kotlinx.coroutines.io.*
import kotlinx.io.core.* import kotlinx.io.core.*
import java.security.* import java.security.*
import java.security.cert.*
import java.security.interfaces.* import java.security.interfaces.*
import java.security.spec.* import java.security.spec.*
import javax.crypto.* import javax.crypto.*
...@@ -65,6 +66,20 @@ internal fun BytePacketBuilder.writeTLSClientHello( ...@@ -65,6 +66,20 @@ internal fun BytePacketBuilder.writeTLSClientHello(
} }
} }
internal fun BytePacketBuilder.writeTLSCertificates(certificates: Array<X509Certificate>) {
val chain = buildPacket {
for (certificate in certificates) {
val certificateBytes = certificate.encoded!!
writeTripleByteLength(certificateBytes.size)
writeFully(certificateBytes)
}
}
writeTripleByteLength(chain.remaining.toInt())
writePacket(chain)
}
internal fun BytePacketBuilder.writeEncryptedPreMasterSecret( internal fun BytePacketBuilder.writeEncryptedPreMasterSecret(
preSecret: ByteArray, preSecret: ByteArray,
publicKey: PublicKey, publicKey: PublicKey,
......
...@@ -2,37 +2,47 @@ package io.ktor.network.tls ...@@ -2,37 +2,47 @@ package io.ktor.network.tls
import io.ktor.network.sockets.* import io.ktor.network.sockets.*
import kotlinx.coroutines.io.* import kotlinx.coroutines.io.*
import java.security.*
import javax.net.ssl.* import javax.net.ssl.*
import kotlin.coroutines.* import kotlin.coroutines.*
/** /**
* Make [Socket] connection secure with TLS. * Make [Socket] connection secure with TLS using [TLSConfig].
*/ */
suspend fun Socket.tls( suspend fun Socket.tls(
coroutineContext: CoroutineContext, coroutineContext: CoroutineContext, config: TLSConfig
trustManager: X509TrustManager? = null,
randomAlgorithm: String = "NativePRNGNonBlocking",
cipherSuites: List<CipherSuite> = CIOCipherSuites.SupportedSuites,
serverName: String? = null
): Socket { ): Socket {
val reader = openReadChannel() val reader = openReadChannel()
val writer = openWriteChannel() val writer = openWriteChannel()
val session = try { return try {
TLSClientSession( openTLSSession(this, reader, writer, config, coroutineContext)
reader, writer, trustManager, randomAlgorithm, cipherSuites, serverName, coroutineContext
).also { it.start() }
} catch (cause: Throwable) { } catch (cause: Throwable) {
reader.cancel(cause) reader.cancel(cause)
writer.close(cause) writer.close(cause)
close() close()
throw cause throw cause
} }
return TLSSocketImpl(session, this)
} }
private class TLSSocketImpl(val session: TLSClientSession, val delegate: Socket) : Socket by delegate { /**
override fun attachForReading(channel: ByteChannel): WriterJob = session.attachForReading(channel) * Make [Socket] connection secure with TLS.
override fun attachForWriting(channel: ByteChannel): ReaderJob = session.attachForWriting(channel) */
suspend fun Socket.tls(
coroutineContext: CoroutineContext,
trustManager: X509TrustManager? = null,
randomAlgorithm: String = "NativePRNGNonBlocking",
cipherSuites: List<CipherSuite> = CIOCipherSuites.SupportedSuites,
serverName: String? = null
): Socket = tls(coroutineContext) {
this.trustManager = trustManager
this.random = SecureRandom.getInstance(randomAlgorithm)
this.cipherSuites = cipherSuites
this.serverName = serverName
} }
/**
* Make [Socket] connection secure with TLS configured with [block].
*/
suspend fun Socket.tls(coroutineContext: CoroutineContext, block: TLSConfigBuilder.() -> Unit = {}): Socket =
tls(coroutineContext, TLSConfigBuilder().apply(block).build())
package io.ktor.network.tls package io.ktor.network.tls
import io.ktor.network.tls.SecretExchangeType.* import io.ktor.network.tls.SecretExchangeType.*
import io.ktor.network.tls.certificates.*
import io.ktor.network.tls.cipher.* import io.ktor.network.tls.cipher.*
import io.ktor.network.tls.extensions.* import io.ktor.network.tls.extensions.*
import kotlinx.coroutines.* import kotlinx.coroutines.*
...@@ -15,26 +16,17 @@ import java.security.spec.* ...@@ -15,26 +16,17 @@ import java.security.spec.*
import javax.crypto.* import javax.crypto.*
import javax.crypto.spec.* import javax.crypto.spec.*
import javax.net.ssl.* import javax.net.ssl.*
import javax.security.auth.x500.*
import kotlin.coroutines.* import kotlin.coroutines.*
private data class EncryptionInfo(
val serverPublic: PublicKey,
val clientPublic: PublicKey,
val clientPrivate: PrivateKey
)
internal class TLSClientHandshake( internal class TLSClientHandshake(
rawInput: ByteReadChannel, rawInput: ByteReadChannel,
rawOutput: ByteWriteChannel, rawOutput: ByteWriteChannel,
override val coroutineContext: CoroutineContext, private val config: TLSConfig,
private val trustManager: X509TrustManager? = null, override val coroutineContext: CoroutineContext
randomAlgorithm: String = "NativePRNGNonBlocking",
private val cipherSuites: List<CipherSuite>,
private val serverName: String? = null
) : CoroutineScope { ) : CoroutineScope {
private val digest = Digest() private val digest = Digest()
private val random = SecureRandom.getInstance(randomAlgorithm)!! private val clientSeed: ByteArray = config.random.generateClientSeed()
private val clientSeed: ByteArray = random.generateClientSeed()
@Volatile @Volatile
private lateinit var serverHello: TLSServerHello private lateinit var serverHello: TLSServerHello
...@@ -102,10 +94,11 @@ internal class TLSClientHandshake( ...@@ -102,10 +94,11 @@ internal class TLSClientHandshake(
val output: SendChannel<TLSRecord> = actor { val output: SendChannel<TLSRecord> = actor {
var useCipher = false var useCipher = false
channel.consumeEach { rawRecord -> for (rawRecord in channel) {
try { try {
val record = if (useCipher) cipher.encrypt(rawRecord) else rawRecord val record = if (useCipher) cipher.encrypt(rawRecord) else rawRecord
if (rawRecord.type == TLSRecordType.ChangeCipherSpec) useCipher = true if (rawRecord.type == TLSRecordType.ChangeCipherSpec) useCipher = true
rawOutput.writeRecord(record) rawOutput.writeRecord(record)
} catch (cause: Throwable) { } catch (cause: Throwable) {
channel.close(cause) channel.close(cause)
...@@ -121,7 +114,7 @@ internal class TLSClientHandshake( ...@@ -121,7 +114,7 @@ internal class TLSClientHandshake(
val record = input.receive() val record = input.receive()
val packet = record.packet val packet = record.packet
while (packet.remaining > 0) { while (packet.isNotEmpty) {
val handshake = packet.readTLSHandshake() val handshake = packet.readTLSHandshake()
if (handshake.type == TLSHandshakeType.HelloRequest) continue if (handshake.type == TLSHandshakeType.HelloRequest) continue
if (handshake.type != TLSHandshakeType.Finished) { if (handshake.type != TLSHandshakeType.Finished) {
...@@ -139,17 +132,19 @@ internal class TLSClientHandshake( ...@@ -139,17 +132,19 @@ internal class TLSClientHandshake(
} }
suspend fun negotiate() { suspend fun negotiate() {
sendClientHello() digest.use {
serverHello = receiveServerHello() sendClientHello()
serverHello = receiveServerHello()
verifyHello(serverHello) verifyHello(serverHello)
handleCertificatesAndKeys() handleCertificatesAndKeys()
receiveServerFinished() receiveServerFinished()
}
} }
private fun verifyHello(serverHello: TLSServerHello) { private fun verifyHello(serverHello: TLSServerHello) {
val suite = serverHello.cipherSuite val suite = serverHello.cipherSuite
check(suite in cipherSuites) { "Unsupported cipher suite ${suite.name} in SERVER_HELLO" } check(suite in config.cipherSuites) { "Unsupported cipher suite ${suite.name} in SERVER_HELLO" }
val clientExchanges = SupportedSignatureAlgorithms.filter { val clientExchanges = SupportedSignatureAlgorithms.filter {
it.hash == suite.hash && it.sign == suite.signatureAlgorithm it.hash == suite.hash && it.sign == suite.signatureAlgorithm
...@@ -172,7 +167,7 @@ internal class TLSClientHandshake( ...@@ -172,7 +167,7 @@ internal class TLSClientHandshake(
sendHandshakeRecord(TLSHandshakeType.ClientHello) { sendHandshakeRecord(TLSHandshakeType.ClientHello) {
// TODO: support session id // TODO: support session id
writeTLSClientHello( writeTLSClientHello(
TLSVersion.TLS12, cipherSuites, clientSeed, ByteArray(32), serverName TLSVersion.TLS12, config.cipherSuites, clientSeed, ByteArray(32), config.serverName
) )
} }
} }
...@@ -190,7 +185,7 @@ internal class TLSClientHandshake( ...@@ -190,7 +185,7 @@ internal class TLSClientHandshake(
private suspend fun handleCertificatesAndKeys() { private suspend fun handleCertificatesAndKeys() {
val exchangeType = serverHello.cipherSuite.exchangeType val exchangeType = serverHello.cipherSuite.exchangeType
var serverCertificate: Certificate? = null var serverCertificate: Certificate? = null
var certificateRequested = false var certificateInfo: CertificateInfo? = null
var encryptionInfo: EncryptionInfo? = null var encryptionInfo: EncryptionInfo? = null
while (true) { while (true) {
...@@ -203,7 +198,7 @@ internal class TLSClientHandshake( ...@@ -203,7 +198,7 @@ internal class TLSClientHandshake(
val x509s = certs.filterIsInstance<X509Certificate>() val x509s = certs.filterIsInstance<X509Certificate>()
if (x509s.isEmpty()) throw TLSException("Server sent no certificate") if (x509s.isEmpty()) throw TLSException("Server sent no certificate")
val manager = trustManager ?: findTrustManager() val manager = config.trustManager
manager.checkServerTrusted(x509s.toTypedArray(), exchangeType.jvmName) manager.checkServerTrusted(x509s.toTypedArray(), exchangeType.jvmName)
serverCertificate = x509s.firstOrNull { certificate -> serverCertificate = x509s.firstOrNull { certificate ->
...@@ -216,8 +211,32 @@ internal class TLSClientHandshake( ...@@ -216,8 +211,32 @@ internal class TLSClientHandshake(
} ?: throw TLSException("No suitable server certificate received: $certs") } ?: throw TLSException("No suitable server certificate received: $certs")
} }
TLSHandshakeType.CertificateRequest -> { TLSHandshakeType.CertificateRequest -> {
certificateRequested = true val typeCount = packet.readByte().toInt() and 0xFF
check(packet.remaining == 0L) val types = packet.readBytes(typeCount)
val hashAndSignCount = packet.readShort().toInt() and 0xFFFF
val hashAndSign = mutableListOf<HashAndSign>()
repeat(hashAndSignCount / 2) {
val hash = packet.readByte()
val sign = packet.readByte()
hashAndSign += HashAndSign(hash, sign)
}
val authoritiesSize = packet.readShort().toInt() and 0xFFFF
val authorities = mutableSetOf<Principal>()
var position = 0
while (position < authoritiesSize) {
val size = packet.readShort().toInt() and 0xFFFF
position += size
val authority = packet.readBytes(size)
authorities += X500Principal(authority)
}
certificateInfo = CertificateInfo(types, hashAndSign.toTypedArray(), authorities)
check(packet.isEmpty)
} }
TLSHandshakeType.ServerKeyExchange -> { TLSHandshakeType.ServerKeyExchange -> {
when (exchangeType) { when (exchangeType) {
...@@ -258,7 +277,7 @@ internal class TLSClientHandshake( ...@@ -258,7 +277,7 @@ internal class TLSClientHandshake(
handleServerDone( handleServerDone(
exchangeType, exchangeType,
serverCertificate!!, serverCertificate!!,
certificateRequested, certificateInfo,
encryptionInfo encryptionInfo
) )
return return
...@@ -271,26 +290,27 @@ internal class TLSClientHandshake( ...@@ -271,26 +290,27 @@ internal class TLSClientHandshake(
private suspend fun handleServerDone( private suspend fun handleServerDone(
exchangeType: SecretExchangeType, exchangeType: SecretExchangeType,
serverCertificate: Certificate, serverCertificate: Certificate,
certificateRequested: Boolean, certificateInfo: CertificateInfo?,
encryptionInfo: EncryptionInfo? encryptionInfo: EncryptionInfo?
) { ) {
if (certificateRequested) sendClientCertificate() val chain = certificateInfo?.let { sendClientCertificate(it) }
val preSecret: ByteArray = generatePreSecret(encryptionInfo) val preSecret: ByteArray = generatePreSecret(encryptionInfo)
sendClientKeyExchange( sendClientKeyExchange(
exchangeType, exchangeType,
serverCertificate, serverCertificate,
preSecret, preSecret,
certificateRequested,
encryptionInfo encryptionInfo
) )
masterSecret = masterSecret( masterSecret = masterSecret(
SecretKeySpec(preSecret, serverHello.cipherSuite.hash.macName), SecretKeySpec(preSecret, serverHello.cipherSuite.hash.macName),
clientSeed, serverHello.serverSeed clientSeed, serverHello.serverSeed
) )
preSecret.fill(0) preSecret.fill(0)
if (certificateRequested) sendClientCertificateVerify() certificateInfo?.let { sendClientCertificateVerify(it, chain!!) }
sendChangeCipherSpec() sendChangeCipherSpec()
sendClientFinished(masterSecret) sendClientFinished(masterSecret)
...@@ -298,7 +318,7 @@ internal class TLSClientHandshake( ...@@ -298,7 +318,7 @@ internal class TLSClientHandshake(
private fun generatePreSecret(encryptionInfo: EncryptionInfo?): ByteArray = private fun generatePreSecret(encryptionInfo: EncryptionInfo?): ByteArray =
when (serverHello.cipherSuite.exchangeType) { when (serverHello.cipherSuite.exchangeType) {
SecretExchangeType.RSA -> random.generateSeed(48)!!.also { SecretExchangeType.RSA -> config.random.generateSeed(48)!!.also {
it[0] = 0x03 it[0] = 0x03
it[1] = 0x03 it[1] = 0x03
} }
...@@ -314,17 +334,14 @@ internal class TLSClientHandshake( ...@@ -314,17 +334,14 @@ internal class TLSClientHandshake(
exchangeType: SecretExchangeType, exchangeType: SecretExchangeType,
serverCertificate: Certificate, serverCertificate: Certificate,
preSecret: ByteArray, preSecret: ByteArray,
certificateRequested: Boolean,
encryptionInfo: EncryptionInfo? encryptionInfo: EncryptionInfo?
) { ) {
val packet = when (exchangeType) { val packet = when (exchangeType) {
RSA -> buildPacket { RSA -> buildPacket {
writeEncryptedPreMasterSecret(preSecret, serverCertificate.publicKey, random) writeEncryptedPreMasterSecret(preSecret, serverCertificate.publicKey, config.random)
} }
ECDHE -> buildPacket { ECDHE -> buildPacket {
if (certificateRequested) return@buildPacket // Key exchange has already completed implicit in the certificate message.
if (encryptionInfo == null) throw TLSException("ECDHE: Encryption info should be provided") if (encryptionInfo == null) throw TLSException("ECDHE: Encryption info should be provided")
writePublicKeyUncompressed(encryptionInfo.clientPublic) writePublicKeyUncompressed(encryptionInfo.clientPublic)
} }
} }
...@@ -332,12 +349,55 @@ internal class TLSClientHandshake( ...@@ -332,12 +349,55 @@ internal class TLSClientHandshake(
sendHandshakeRecord(TLSHandshakeType.ClientKeyExchange) { writePacket(packet) } sendHandshakeRecord(TLSHandshakeType.ClientKeyExchange) { writePacket(packet) }
} }
private fun sendClientCertificate() { private suspend fun sendClientCertificate(info: CertificateInfo): CertificateAndKey? {
throw TLSException("Client certificates unsupported") val chainAndKey = config.certificates.find { candidate ->
val leaf = candidate.certificateChain.first()
val validAlgorithm = when (leaf.publicKey.algorithm) {
"RSA" -> info.types.contains(CertificateType.RSA)
"DSS" -> info.types.contains(CertificateType.DSS)
else -> false
}
if (!validAlgorithm) return@find false
val hasHashAndSignInCommon = info.hashAndSign.none {
it.name.equals(leaf.sigAlgName, ignoreCase = true)
}
if (hasHashAndSignInCommon) return@find false
info.authorities.isEmpty() || candidate.certificateChain.any { it.issuerDN in info.authorities }
}
sendHandshakeRecord(TLSHandshakeType.Certificate) {
writeTLSCertificates(chainAndKey?.certificateChain ?: emptyArray())
}
return chainAndKey
} }
private fun sendClientCertificateVerify() { private suspend fun sendClientCertificateVerify(info: CertificateInfo, certificateAndKey: CertificateAndKey) {
throw TLSException("Client certificates unsupported") val leaf = certificateAndKey.certificateChain.first()
val hashAndSign = info.hashAndSign.firstOrNull {
it.name.equals(leaf.sigAlgName, ignoreCase = true)
} ?: return
if (hashAndSign.sign == SignatureAlgorithm.DSA) return
val sign = Signature.getInstance(certificateAndKey.certificateChain.first().sigAlgName)!!
sign.initSign(certificateAndKey.key)
sendHandshakeRecord(TLSHandshakeType.CertificateVerify) {
writeByte(hashAndSign.hash.code)
writeByte(hashAndSign.sign.code)
digest.state.preview { sign.update(it.readBytes()) }
val signBytes = sign.sign()!!
writeShort(signBytes.size.toShort())
writeFully(signBytes)
}
} }
private suspend fun sendChangeCipherSpec() { private suspend fun sendChangeCipherSpec() {
...@@ -382,17 +442,11 @@ internal class TLSClientHandshake( ...@@ -382,17 +442,11 @@ internal class TLSClientHandshake(
} }
digest.update(recordBody) digest.update(recordBody)
output.send(TLSRecord(TLSRecordType.Handshake, packet = recordBody)) val element = TLSRecord(TLSRecordType.Handshake, packet = recordBody)
output.send(element)
} }
} }
private fun findTrustManager(): X509TrustManager {
val factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
factory.init(null as KeyStore?)
val manager = factory.trustManagers
return manager.first { it is X509TrustManager } as X509TrustManager
}
private fun SecureRandom.generateClientSeed(): ByteArray { private fun SecureRandom.generateClientSeed(): ByteArray {
return generateSeed(32)!!.also { return generateSeed(32)!!.also {
......
...@@ -8,29 +8,25 @@ import kotlinx.coroutines.io.* ...@@ -8,29 +8,25 @@ import kotlinx.coroutines.io.*
import kotlinx.io.core.* import kotlinx.io.core.*
import kotlinx.io.pool.* import kotlinx.io.pool.*
import java.nio.* import java.nio.*
import javax.net.ssl.*
import kotlin.coroutines.* import kotlin.coroutines.*
internal class TLSClientSession( internal suspend fun openTLSSession(
rawInput: ByteReadChannel, socket: Socket,
rawOutput: ByteWriteChannel, input: ByteReadChannel, output: ByteWriteChannel,
trustManager: X509TrustManager?, config: TLSConfig,
randomAlgorithm: String, context: CoroutineContext
cipherSuites: List<CipherSuite>, ): Socket {
serverName: String?, val handshake = TLSClientHandshake(input, output, config, context)
override val coroutineContext: CoroutineContext handshake.negotiate()
) : CoroutineScope, AReadable, AWritable { return TLSSocket(handshake.input, handshake.output, socket, context)
private val handshaker = TLSClientHandshake( }
rawInput, rawOutput, coroutineContext,
trustManager, randomAlgorithm, cipherSuites, serverName
)
private val input = handshaker.input
private val output = handshaker.output
suspend fun start() { private class TLSSocket(
handshaker.negotiate() private val input: ReceiveChannel<TLSRecord>,
} private val output: SendChannel<TLSRecord>,
socket: Socket,
override val coroutineContext: CoroutineContext
) : CoroutineScope, Socket by socket {
override fun attachForReading(channel: ByteChannel): WriterJob = writer(coroutineContext, channel) { override fun attachForReading(channel: ByteChannel): WriterJob = writer(coroutineContext, channel) {
appDataInputLoop(this.channel) appDataInputLoop(this.channel)
......
package io.ktor.network.tls
import java.security.*
import java.security.cert.*
import javax.net.ssl.*
/**
* TLS configuration.
* @property trustManager: Custom [X509TrustManager] to verify server authority. Use system by default.
* @property random: [SecureRandom] to use in encryption.
* @property certificates: list of client certificate chains with private keys.
* @property cipherSuites: list of allowed [CipherSuite]s.
* @property serverName: custom server name for TLS server name extension.
*/
class TLSConfig(
val random: SecureRandom,
val certificates: List<CertificateAndKey>,
val trustManager: X509TrustManager,
val cipherSuites: List<CipherSuite>,
val serverName: String?
)
/**
* Client certificate chain with private key.
* @property certificateChain: client certificate chain.
* @property key: [PrivateKey] for certificate chain.
*/
class CertificateAndKey(val certificateChain: Array<X509Certificate>, val key: PrivateKey)
package io.ktor.network.tls
import java.security.*
import java.security.cert.*
import javax.net.ssl.*
/**
* [TLSConfig] builder.
*/
class TLSConfigBuilder {
/**
* List of client certificate chains with private keys.
*/
val certificates: MutableList<CertificateAndKey> = mutableListOf()
/**
* [SecureRandom] to use in encryption.
*/
var random: SecureRandom? = null
/**
* Custom [X509TrustManager] to verify server authority.
*
* Use system by default.
*/
var trustManager: TrustManager? = null
set(value) {
value?.let {
check(it is X509TrustManager) {
"Failed to set [trustManager]: $value. Only [X509TrustManager] supported."
}
}
field = value
}
/**
* List of allowed [CipherSuite]s.
*/
var cipherSuites: List<CipherSuite> = CIOCipherSuites.SupportedSuites
/**
* Custom server name for TLS server name extension.
* See also: https://en.wikipedia.org/wiki/Server_Name_Indication
*/
var serverName: String? = null
/**
* Create [TLSConfig].
*/
fun build(): TLSConfig = TLSConfig(
random ?: SecureRandom.getInstanceStrong(),
certificates, trustManager as? X509TrustManager ?: findTrustManager(),
cipherSuites, serverName
)
}
/**
* Add client certificate chain to use.
*/
fun TLSConfigBuilder.addCertificateChain(chain: Array<X509Certificate>, key: PrivateKey) {
certificates += CertificateAndKey(chain, key)
}
/**
* Add client certificates from [store].
*/
fun TLSConfigBuilder.addKeyStore(store: KeyStore, password: CharArray) {
val keyManagerAlgorithm = KeyManagerFactory.getDefaultAlgorithm()!!
val keyManagerFactory = KeyManagerFactory.getInstance(keyManagerAlgorithm)!!
keyManagerFactory.init(store, password)
val managers = keyManagerFactory.keyManagers.filterIsInstance<X509KeyManager>()
val aliases = store.aliases()!!
loop@ for (alias in aliases) {
val chain = store.getCertificateChain(alias)
val allX509 = chain.all { it is X509Certificate }
check(allX509) { "Fail to add key store $store. Only X509 certificate format supported." }
for (manager in managers) {
val key = manager.getPrivateKey(alias) ?: continue
val map = chain.map { it as X509Certificate }
addCertificateChain(map.toTypedArray(), key)
continue@loop
}
throw NoPrivateKeyException(alias, store)
}
}
/**
* Throws if failed to find [PrivateKey] for any alias in [KeyStore].
*/
class NoPrivateKeyException(
alias: String, store: KeyStore
) : IllegalStateException("Failed to find private key for alias $alias. Please check your key store: $store")
private fun findTrustManager(): X509TrustManager {
val factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())!!
factory.init(null as KeyStore?)
val manager = factory.trustManagers!!
return manager.filterIsInstance<X509TrustManager>().first()
}
...@@ -2,11 +2,11 @@ package io.ktor.network.tls ...@@ -2,11 +2,11 @@ package io.ktor.network.tls
import io.ktor.http.cio.internals.* import io.ktor.http.cio.internals.*
import kotlinx.io.core.* import kotlinx.io.core.*
import kotlinx.io.core.Closeable
import java.security.* import java.security.*
internal class Digest : Closeable { internal fun Digest(): Digest = Digest(BytePacketBuilder())
private val state = BytePacketBuilder()
internal inline class Digest(val state: BytePacketBuilder) : Closeable {
fun update(packet: ByteReadPacket) = synchronized(this) { fun update(packet: ByteReadPacket) = synchronized(this) {
if (packet.isEmpty) return if (packet.isEmpty) return
......
package io.ktor.network.tls.certificates
import io.ktor.network.tls.*
/**
* Type of client certificate.
* see also https://tools.ietf.org/html/rfc5246#section-7.4.4
*
* @property code numeric algorithm codes
*/
@Suppress("KDocMissingDocumentation")
internal object CertificateType {
val RSA: Byte = 1
val DSS: Byte = 2
val RSA_FIXED_DH: Byte = 3
val DSS_FIXED_DH: Byte = 4
val RSA_EPHEMERAL_DH_RESERVED: Byte = 5
val DSS_EPHEMERAL_DH_RESERVED: Byte = 6
val FORTEZZA_DMS_RESERVED: Byte = 20
}
...@@ -429,4 +429,4 @@ private fun BytePacketBuilder.writeDerInt(value: Int) { ...@@ -429,4 +429,4 @@ private fun BytePacketBuilder.writeDerInt(value: Int) {
// else -> append(' ') // else -> append(' ')
// } // }
// } // }
//} //}
\ No newline at end of file
Поддерживает Markdown
0% или .
You are about to add 0 people to the discussion. Proceed with caution.
Сначала завершите редактирование этого сообщения!
Пожалуйста, зарегистрируйтесь или чтобы прокомментировать