Skip to content

Commit 3cb4d5a

Browse files
authored
Correctly override connection state using sslContextCallback (#499)
Motivation: The sslContextCallback overrides modify the SSL_CTX object itself. That's not useful: in addition to having a global effect, which is rarely what we want, certain modifications don't apply that late, and setting the credentials is one of them. The result is that the callback didn't actually do anything. Not that useful. Modifications: - Extract the cert chain setting functions from NIOSSLContext and move them to general-purpose functions. - Add support there for calling them on a SSL * as well as on a SSL_CTX * - Call them in an override function on SSLConnection - Call that! - Add a bunch of tests. Result: The callback actually does what it should.
1 parent 9fc4828 commit 3cb4d5a

File tree

8 files changed

+277
-108
lines changed

8 files changed

+277
-108
lines changed

Sources/NIOSSL/SSLConnection.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,18 @@ extension SSLConnection {
495495
return try buffers.map { try NIOSSLCertificate(bytes: $0, format: .der) }
496496
}
497497
}
498+
499+
func applyOverride(_ changes: NIOSSLContextConfigurationOverride) throws {
500+
let connection = UnsafeKeyAndChainTarget.ssl(self.ssl)
501+
if let chain = changes.certificateChain {
502+
try connection.useCertificateChain(chain)
503+
}
504+
505+
// Attempt to load the new private key and abort on failure
506+
if let pkey = changes.privateKey {
507+
try connection.usePrivateKeySource(pkey)
508+
}
509+
}
498510
}
499511

500512
extension SSLConnection.PeerCertificateChainBuffers: RandomAccessCollection {

Sources/NIOSSL/SSLContext.swift

Lines changed: 5 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -262,14 +262,8 @@ private func sslContextCallback(ssl: OpaquePointer?, arg: UnsafeMutableRawPointe
262262
case .success(let changes):
263263
do {
264264
// Attempt to load the new certificate chain and abort on failure
265-
if let chain = changes.certificateChain {
266-
try NIOSSLContext.useCertificateChain(chain, context: parentSwiftContext.sslContext)
267-
}
268-
269-
// Attempt to load the new private key and abort on failure
270-
if let pkey = changes.privateKey {
271-
try NIOSSLContext.usePrivateKeySource(pkey, context: parentSwiftContext.sslContext)
272-
}
265+
let ssl = SSLConnection.loadConnectionFromSSL(ssl)
266+
try ssl.applyOverride(changes)
273267

274268
// We must return 1 to signal a successful load of the new context
275269
return 1
@@ -450,10 +444,11 @@ public final class NIOSSLContext {
450444
)
451445
}
452446

453-
try NIOSSLContext.useCertificateChain(configuration.certificateChain, context: context)
447+
let handle = UnsafeKeyAndChainTarget.sslContext(context)
448+
try handle.useCertificateChain(configuration.certificateChain)
454449

455450
if let pkey = configuration.privateKey {
456-
try NIOSSLContext.usePrivateKeySource(pkey, context: context)
451+
try handle.usePrivateKeySource(pkey)
457452
}
458453

459454
if configuration.encodedApplicationProtocols.count > 0 {
@@ -579,101 +574,6 @@ extension NIOSSLContext {
579574
}
580575

581576
extension NIOSSLContext {
582-
fileprivate static func useCertificateChain(
583-
_ certificateChain: [NIOSSLCertificateSource],
584-
context: OpaquePointer
585-
) throws {
586-
var leaf = true
587-
for source in certificateChain {
588-
switch source {
589-
case .file(let p):
590-
NIOSSLContext.useCertificateChainFile(p, context: context)
591-
leaf = false
592-
case .certificate(let cert):
593-
if leaf {
594-
try NIOSSLContext.setLeafCertificate(cert, context: context)
595-
leaf = false
596-
} else {
597-
try NIOSSLContext.addAdditionalChainCertificate(cert, context: context)
598-
}
599-
}
600-
}
601-
}
602-
603-
private static func useCertificateChainFile(_ path: String, context: OpaquePointer) {
604-
// TODO(cory): This shouldn't be an assert but should instead be actual error handling.
605-
// assert(path.isFileURL)
606-
let result = path.withCString { (pointer) -> CInt in
607-
CNIOBoringSSL_SSL_CTX_use_certificate_chain_file(context, pointer)
608-
}
609-
610-
// TODO(cory): again, some error handling would be good.
611-
precondition(result == 1)
612-
}
613-
614-
private static func setLeafCertificate(_ cert: NIOSSLCertificate, context: OpaquePointer) throws {
615-
let rc = cert.withUnsafeMutableX509Pointer { ref in
616-
CNIOBoringSSL_SSL_CTX_use_certificate(context, ref)
617-
}
618-
guard rc == 1 else {
619-
throw NIOSSLError.failedToLoadCertificate
620-
}
621-
}
622-
623-
private static func addAdditionalChainCertificate(_ cert: NIOSSLCertificate, context: OpaquePointer) throws {
624-
let rc = cert.withUnsafeMutableX509Pointer { ref in
625-
CNIOBoringSSL_SSL_CTX_add1_chain_cert(context, ref)
626-
}
627-
guard rc == 1 else {
628-
throw NIOSSLError.failedToLoadCertificate
629-
}
630-
}
631-
632-
fileprivate static func usePrivateKeySource(_ privateKey: NIOSSLPrivateKeySource, context: OpaquePointer) throws {
633-
switch privateKey {
634-
case .file(let p):
635-
try NIOSSLContext.usePrivateKeyFile(p, context: context)
636-
case .privateKey(let key):
637-
try NIOSSLContext.setPrivateKey(key, context: context)
638-
}
639-
}
640-
641-
private static func setPrivateKey(_ key: NIOSSLPrivateKey, context: OpaquePointer) throws {
642-
switch key.representation {
643-
case .native:
644-
let rc = key.withUnsafeMutableEVPPKEYPointer { ref in
645-
CNIOBoringSSL_SSL_CTX_use_PrivateKey(context, ref)
646-
}
647-
guard 1 == rc else {
648-
throw NIOSSLError.failedToLoadPrivateKey
649-
}
650-
case .custom:
651-
CNIOBoringSSL_SSL_CTX_set_private_key_method(context, customPrivateKeyMethod)
652-
}
653-
}
654-
655-
private static func usePrivateKeyFile(_ path: String, context: OpaquePointer) throws {
656-
let pathExtension = path.split(separator: ".").last
657-
let fileType: CInt
658-
659-
switch pathExtension?.lowercased() {
660-
case .some("pem"):
661-
fileType = SSL_FILETYPE_PEM
662-
case .some("der"), .some("key"):
663-
fileType = SSL_FILETYPE_ASN1
664-
default:
665-
throw NIOSSLExtraError.unknownPrivateKeyFileType(path: path)
666-
}
667-
668-
let result = path.withCString { (pointer) -> CInt in
669-
CNIOBoringSSL_SSL_CTX_use_PrivateKey_file(context, pointer, fileType)
670-
}
671-
672-
guard result == 1 else {
673-
throw NIOSSLError.failedToLoadPrivateKey
674-
}
675-
}
676-
677577
private static func setAlpnProtocols(_ protocols: [[UInt8]], context: OpaquePointer) throws {
678578
// This copy should be done infrequently, so we don't worry too much about it.
679579
let protoBuf = protocols.reduce([UInt8](), +)
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the SwiftNIO open source project
4+
//
5+
// Copyright (c) 2024 Apple Inc. and the SwiftNIO project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
@_implementationOnly import CNIOBoringSSL
15+
16+
enum UnsafeKeyAndChainTarget {
17+
case sslContext(OpaquePointer)
18+
case ssl(OpaquePointer)
19+
20+
func useCertificateChain(
21+
_ certificateChain: [NIOSSLCertificateSource]
22+
) throws {
23+
var leaf = true
24+
for source in certificateChain {
25+
switch source {
26+
case .file(let p):
27+
self.useCertificateChainFile(p)
28+
leaf = false
29+
case .certificate(let cert):
30+
if leaf {
31+
try self.setLeafCertificate(cert)
32+
leaf = false
33+
} else {
34+
try self.addAdditionalChainCertificate(cert)
35+
}
36+
}
37+
}
38+
}
39+
40+
func useCertificateChainFile(_ path: String) {
41+
let result = path.withCString { (pointer) -> CInt in
42+
switch self {
43+
case .sslContext(let context):
44+
CNIOBoringSSL_SSL_CTX_use_certificate_chain_file(context, pointer)
45+
case .ssl(let ssl):
46+
CNIOBoringSSL_SSL_CTX_use_certificate_chain_file(ssl, pointer)
47+
}
48+
}
49+
50+
precondition(result == 1)
51+
}
52+
53+
func setLeafCertificate(_ cert: NIOSSLCertificate) throws {
54+
let rc = cert.withUnsafeMutableX509Pointer { ref in
55+
switch self {
56+
case .sslContext(let context):
57+
CNIOBoringSSL_SSL_CTX_use_certificate(context, ref)
58+
case .ssl(let ssl):
59+
CNIOBoringSSL_SSL_use_certificate(ssl, ref)
60+
}
61+
}
62+
guard rc == 1 else {
63+
throw NIOSSLError.failedToLoadCertificate
64+
}
65+
}
66+
67+
func addAdditionalChainCertificate(_ cert: NIOSSLCertificate) throws {
68+
let rc = cert.withUnsafeMutableX509Pointer { ref in
69+
switch self {
70+
case .sslContext(let context):
71+
CNIOBoringSSL_SSL_CTX_add1_chain_cert(context, ref)
72+
case .ssl(let ssl):
73+
CNIOBoringSSL_SSL_add1_chain_cert(ssl, ref)
74+
}
75+
}
76+
guard rc == 1 else {
77+
throw NIOSSLError.failedToLoadCertificate
78+
}
79+
}
80+
81+
func usePrivateKeySource(_ privateKey: NIOSSLPrivateKeySource) throws {
82+
switch privateKey {
83+
case .file(let p):
84+
try self.usePrivateKeyFile(p)
85+
case .privateKey(let key):
86+
try self.setPrivateKey(key)
87+
}
88+
}
89+
90+
func setPrivateKey(_ key: NIOSSLPrivateKey) throws {
91+
switch key.representation {
92+
case .native:
93+
let rc = key.withUnsafeMutableEVPPKEYPointer { ref in
94+
switch self {
95+
case .sslContext(let context):
96+
CNIOBoringSSL_SSL_CTX_use_PrivateKey(context, ref)
97+
case .ssl(let ssl):
98+
CNIOBoringSSL_SSL_use_PrivateKey(ssl, ref)
99+
}
100+
}
101+
guard 1 == rc else {
102+
throw NIOSSLError.failedToLoadPrivateKey
103+
}
104+
case .custom:
105+
switch self {
106+
case .sslContext(let context):
107+
CNIOBoringSSL_SSL_CTX_set_private_key_method(context, customPrivateKeyMethod)
108+
case .ssl(let ssl):
109+
CNIOBoringSSL_SSL_set_private_key_method(ssl, customPrivateKeyMethod)
110+
}
111+
}
112+
}
113+
114+
func usePrivateKeyFile(_ path: String) throws {
115+
let pathExtension = path.split(separator: ".").last
116+
let fileType: CInt
117+
118+
switch pathExtension?.lowercased() {
119+
case .some("pem"):
120+
fileType = SSL_FILETYPE_PEM
121+
case .some("der"), .some("key"):
122+
fileType = SSL_FILETYPE_ASN1
123+
default:
124+
throw NIOSSLExtraError.unknownPrivateKeyFileType(path: path)
125+
}
126+
127+
let result = path.withCString { (pointer) -> CInt in
128+
switch self {
129+
case .sslContext(let context):
130+
CNIOBoringSSL_SSL_CTX_use_PrivateKey_file(context, pointer, fileType)
131+
case .ssl(let ssl):
132+
CNIOBoringSSL_SSL_use_PrivateKey_file(ssl, pointer, fileType)
133+
}
134+
}
135+
136+
guard result == 1 else {
137+
throw NIOSSLError.failedToLoadPrivateKey
138+
}
139+
}
140+
}

Sources/NIOSSLPerformanceTester/BenchManyWrites.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
import NIOCore
16+
import NIOEmbedded
1617
import NIOSSL
1718

1819
final class BenchManyWrites: Benchmark {

Sources/NIOSSLPerformanceTester/BenchRepeatedHandshakes.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
import NIOCore
16+
import NIOEmbedded
1617
import NIOSSL
1718

1819
final class BenchRepeatedHandshakes: Benchmark {

Tests/NIOSSLTests/NIOSSLIntegrationTest.swift

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,20 @@ public func assertNoThrowWithValue<T>(
4040
}
4141
}
4242

43-
internal func interactInMemory(clientChannel: EmbeddedChannel, serverChannel: EmbeddedChannel) throws {
43+
internal func interactInMemory(
44+
clientChannel: EmbeddedChannel,
45+
serverChannel: EmbeddedChannel,
46+
runLoops: Bool = true
47+
) throws {
4448
var workToDo = true
4549
while workToDo {
4650
workToDo = false
51+
52+
if runLoops {
53+
clientChannel.embeddedEventLoop.run()
54+
serverChannel.embeddedEventLoop.run()
55+
}
56+
4757
let clientDatum = try clientChannel.readOutbound(as: IOData.self)
4858
let serverDatum = try serverChannel.readOutbound(as: IOData.self)
4959

0 commit comments

Comments
 (0)