@@ -13,34 +13,41 @@ use std::{
13
13
///
14
14
/// [CertificateChain] is internally reference counted. The reference counted `T`
15
15
/// must have a drop implementation.
16
- struct CertificateChainHandle {
17
- cert : NonNull < s2n_cert_chain_and_key > ,
16
+ pub ( crate ) struct CertificateChainHandle < ' a > {
17
+ pub ( crate ) cert : NonNull < s2n_cert_chain_and_key > ,
18
18
is_owned : bool ,
19
+ _lifetime : PhantomData < & ' a s2n_cert_chain_and_key > ,
19
20
}
20
21
21
22
// # Safety
22
23
//
23
24
// s2n_cert_chain_and_key objects can be sent across threads.
24
- unsafe impl Send for CertificateChainHandle { }
25
- unsafe impl Sync for CertificateChainHandle { }
25
+ unsafe impl Send for CertificateChainHandle < ' _ > { }
26
+ unsafe impl Sync for CertificateChainHandle < ' _ > { }
26
27
27
- impl CertificateChainHandle {
28
- fn from_owned ( cert : NonNull < s2n_cert_chain_and_key > ) -> Self {
29
- Self {
30
- cert,
28
+ impl CertificateChainHandle < ' _ > {
29
+ /// Allocate an uninitialized CertificateChainHandle.
30
+ ///
31
+ /// Corresponds to [s2n_cert_chain_and_key_new].
32
+ pub ( crate ) fn allocate ( ) -> Result < CertificateChainHandle < ' static > , crate :: error:: Error > {
33
+ crate :: init:: init ( ) ;
34
+ Ok ( CertificateChainHandle {
35
+ cert : unsafe { s2n_cert_chain_and_key_new ( ) . into_result ( ) } ?,
31
36
is_owned : true ,
32
- }
37
+ _lifetime : PhantomData ,
38
+ } )
33
39
}
34
40
35
41
fn from_reference ( cert : NonNull < s2n_cert_chain_and_key > ) -> Self {
36
42
Self {
37
43
cert,
38
44
is_owned : false ,
45
+ _lifetime : PhantomData ,
39
46
}
40
47
}
41
48
}
42
49
43
- impl Drop for CertificateChainHandle {
50
+ impl Drop for CertificateChainHandle < ' _ > {
44
51
/// Corresponds to [s2n_cert_chain_and_key_free].
45
52
fn drop ( & mut self ) {
46
53
// ignore failures since there's not much we can do about it
@@ -53,13 +60,13 @@ impl Drop for CertificateChainHandle {
53
60
}
54
61
55
62
pub struct Builder {
56
- cert : CertificateChain < ' static > ,
63
+ cert_handle : CertificateChainHandle < ' static > ,
57
64
}
58
65
59
66
impl Builder {
60
67
pub fn new ( ) -> Result < Self , Error > {
61
68
Ok ( Self {
62
- cert : CertificateChain :: allocate_owned ( ) ?,
69
+ cert_handle : CertificateChainHandle :: allocate ( ) ?,
63
70
} )
64
71
}
65
72
@@ -73,7 +80,7 @@ impl Builder {
73
80
// `private_key_pem` are not modified.
74
81
// https://github.com/aws/s2n-tls/issues/4140
75
82
s2n_cert_chain_and_key_load_pem_bytes (
76
- self . cert . as_mut_ptr ( ) ,
83
+ self . cert_handle . cert . as_ptr ( ) ,
77
84
chain. as_ptr ( ) as * mut _ ,
78
85
chain. len ( ) as u32 ,
79
86
key. as_ptr ( ) as * mut _ ,
@@ -95,7 +102,7 @@ impl Builder {
95
102
// is not modified
96
103
// https://github.com/aws/s2n-tls/issues/4140
97
104
s2n_cert_chain_and_key_load_public_pem_bytes (
98
- self . cert . as_mut_ptr ( ) ,
105
+ self . cert_handle . cert . as_ptr ( ) ,
99
106
chain. as_ptr ( ) as * mut _ ,
100
107
chain. len ( ) as u32 ,
101
108
)
@@ -109,7 +116,7 @@ impl Builder {
109
116
pub fn set_ocsp_data ( & mut self , data : & [ u8 ] ) -> Result < & mut Self , Error > {
110
117
unsafe {
111
118
s2n_cert_chain_and_key_set_ocsp_data (
112
- self . cert . as_mut_ptr ( ) ,
119
+ self . cert_handle . cert . as_ptr ( ) ,
113
120
data. as_ptr ( ) ,
114
121
data. len ( ) as u32 ,
115
122
)
@@ -122,7 +129,7 @@ impl Builder {
122
129
pub fn build ( self ) -> Result < CertificateChain < ' static > , Error > {
123
130
// This method is currently infallible, but returning a result allows
124
131
// us to add validation in the future.
125
- Ok ( self . cert )
132
+ Ok ( CertificateChain :: from_allocated ( self . cert_handle ) )
126
133
}
127
134
}
128
135
@@ -135,22 +142,16 @@ impl Builder {
135
142
// safe to mutate CertificateChains.
136
143
#[ derive( Clone ) ]
137
144
pub struct CertificateChain < ' a > {
138
- ptr : Arc < CertificateChainHandle > ,
139
- _lifetime : PhantomData < & ' a s2n_cert_chain_and_key > ,
145
+ cert_handle : Arc < CertificateChainHandle < ' a > > ,
140
146
}
141
147
142
148
impl CertificateChain < ' _ > {
143
- /// This allocates a new certificate chain from s2n.
144
- ///
145
- /// Corresponds to [s2n_cert_chain_and_key_new].
146
- pub ( crate ) fn allocate_owned ( ) -> Result < CertificateChain < ' static > , Error > {
147
- crate :: init:: init ( ) ;
148
- unsafe {
149
- let ptr = s2n_cert_chain_and_key_new ( ) . into_result ( ) ?;
150
- Ok ( CertificateChain {
151
- ptr : Arc :: new ( CertificateChainHandle :: from_owned ( ptr) ) ,
152
- _lifetime : PhantomData ,
153
- } )
149
+ /// Construct a CertificateChain from an allocated [CertificateChainHandle].
150
+ pub ( crate ) fn from_allocated (
151
+ handle : CertificateChainHandle < ' static > ,
152
+ ) -> CertificateChain < ' static > {
153
+ CertificateChain {
154
+ cert_handle : Arc :: new ( handle) ,
154
155
}
155
156
}
156
157
@@ -162,8 +163,7 @@ impl CertificateChain<'_> {
162
163
let handle = Arc :: new ( CertificateChainHandle :: from_reference ( ptr) ) ;
163
164
164
165
CertificateChain {
165
- ptr : handle,
166
- _lifetime : PhantomData ,
166
+ cert_handle : handle,
167
167
}
168
168
}
169
169
@@ -202,16 +202,8 @@ impl CertificateChain<'_> {
202
202
self . len ( ) == 0
203
203
}
204
204
205
- /// SAFETY: Only one instance of `CertificateChain` may exist when this method
206
- /// is called. s2n_cert_chain_and_key is not thread-safe, so it is not safe
207
- /// to mutate the certificate chain if references are held across multiple threads.
208
- pub ( crate ) unsafe fn as_mut_ptr ( & mut self ) -> * mut s2n_cert_chain_and_key {
209
- debug_assert_eq ! ( Arc :: strong_count( & self . ptr) , 1 ) ;
210
- self . ptr . cert . as_ptr ( )
211
- }
212
-
213
205
pub ( crate ) fn as_ptr ( & self ) -> * const s2n_cert_chain_and_key {
214
- self . ptr . cert . as_ptr ( ) as * const _
206
+ self . cert_handle . cert . as_ptr ( ) as * const _
215
207
}
216
208
}
217
209
@@ -339,28 +331,28 @@ mod tests {
339
331
#[ test]
340
332
fn reference_count_increment ( ) -> Result < ( ) , crate :: error:: Error > {
341
333
let cert = SniTestCerts :: AlligatorRsa . get ( ) . into_certificate_chain ( ) ;
342
- assert_eq ! ( Arc :: strong_count( & cert. ptr ) , 1 ) ;
334
+ assert_eq ! ( Arc :: strong_count( & cert. cert_handle ) , 1 ) ;
343
335
344
336
{
345
337
let mut server = config:: Builder :: new ( ) ;
346
338
server. load_chain ( cert. clone ( ) ) ?;
347
339
348
340
// after being added, the reference count should have increased
349
- assert_eq ! ( Arc :: strong_count( & cert. ptr ) , 2 ) ;
341
+ assert_eq ! ( Arc :: strong_count( & cert. cert_handle ) , 2 ) ;
350
342
}
351
343
352
344
// after the config goes out of scope and is dropped, the ref count should
353
345
// decrement
354
- assert_eq ! ( Arc :: strong_count( & cert. ptr ) , 1 ) ;
346
+ assert_eq ! ( Arc :: strong_count( & cert. cert_handle ) , 1 ) ;
355
347
Ok ( ( ) )
356
348
}
357
349
358
350
#[ test]
359
351
fn cert_is_dropped ( ) {
360
352
let weak_ref = {
361
353
let cert = SniTestCerts :: AlligatorEcdsa . get ( ) . into_certificate_chain ( ) ;
362
- assert_eq ! ( Arc :: strong_count( & cert. ptr ) , 1 ) ;
363
- Arc :: downgrade ( & cert. ptr )
354
+ assert_eq ! ( Arc :: strong_count( & cert. cert_handle ) , 1 ) ;
355
+ Arc :: downgrade ( & cert. cert_handle )
364
356
} ;
365
357
assert_eq ! ( weak_ref. strong_count( ) , 0 ) ;
366
358
assert ! ( weak_ref. upgrade( ) . is_none( ) ) ;
@@ -377,17 +369,17 @@ mod tests {
377
369
let mut test_pair_2 =
378
370
sni_test_pair ( vec ! [ cert. clone( ) ] , None , & [ SniTestCerts :: AlligatorRsa ] ) ?;
379
371
380
- assert_eq ! ( Arc :: strong_count( & cert. ptr ) , 3 ) ;
372
+ assert_eq ! ( Arc :: strong_count( & cert. cert_handle ) , 3 ) ;
381
373
382
374
assert ! ( test_pair_1. handshake( ) . is_ok( ) ) ;
383
375
assert ! ( test_pair_2. handshake( ) . is_ok( ) ) ;
384
376
385
- assert_eq ! ( Arc :: strong_count( & cert. ptr ) , 3 ) ;
377
+ assert_eq ! ( Arc :: strong_count( & cert. cert_handle ) , 3 ) ;
386
378
387
379
drop ( test_pair_1) ;
388
- assert_eq ! ( Arc :: strong_count( & cert. ptr ) , 2 ) ;
380
+ assert_eq ! ( Arc :: strong_count( & cert. cert_handle ) , 2 ) ;
389
381
drop ( test_pair_2) ;
390
- assert_eq ! ( Arc :: strong_count( & cert. ptr ) , 1 ) ;
382
+ assert_eq ! ( Arc :: strong_count( & cert. cert_handle ) , 1 ) ;
391
383
Ok ( ( ) )
392
384
}
393
385
@@ -396,7 +388,7 @@ mod tests {
396
388
// 5 certs in the maximum allowed, 6 should error.
397
389
const FAILING_NUMBER : usize = 6 ;
398
390
let certs = vec ! [ SniTestCerts :: AlligatorRsa . get( ) . into_certificate_chain( ) ; FAILING_NUMBER ] ;
399
- assert_eq ! ( Arc :: strong_count( & certs[ 0 ] . ptr ) , FAILING_NUMBER ) ;
391
+ assert_eq ! ( Arc :: strong_count( & certs[ 0 ] . cert_handle ) , FAILING_NUMBER ) ;
400
392
401
393
let mut config = config:: Builder :: new ( ) ;
402
394
let err = config. set_default_chains ( certs. clone ( ) ) . err ( ) . unwrap ( ) ;
@@ -405,7 +397,7 @@ mod tests {
405
397
406
398
// The config should not hold a reference when the error was detected
407
399
// in the bindings
408
- assert_eq ! ( Arc :: strong_count( & certs[ 0 ] . ptr ) , FAILING_NUMBER ) ;
400
+ assert_eq ! ( Arc :: strong_count( & certs[ 0 ] . cert_handle ) , FAILING_NUMBER ) ;
409
401
410
402
Ok ( ( ) )
411
403
}
@@ -430,8 +422,8 @@ mod tests {
430
422
& test_pair. client. peer_cert_chain( ) . unwrap( )
431
423
) ) ;
432
424
433
- assert_eq ! ( Arc :: strong_count( & alligator_cert. ptr ) , 2 ) ;
434
- assert_eq ! ( Arc :: strong_count( & beaver_cert. ptr ) , 2 ) ;
425
+ assert_eq ! ( Arc :: strong_count( & alligator_cert. cert_handle ) , 2 ) ;
426
+ assert_eq ! ( Arc :: strong_count( & beaver_cert. cert_handle ) , 2 ) ;
435
427
}
436
428
437
429
// set an explicit default
@@ -449,10 +441,10 @@ mod tests {
449
441
& test_pair. client. peer_cert_chain( ) . unwrap( )
450
442
) ) ;
451
443
452
- assert_eq ! ( Arc :: strong_count( & alligator_cert. ptr ) , 2 ) ;
444
+ assert_eq ! ( Arc :: strong_count( & alligator_cert. cert_handle ) , 2 ) ;
453
445
// beaver has an additional reference because it was used in multiple
454
446
// calls
455
- assert_eq ! ( Arc :: strong_count( & beaver_cert. ptr ) , 3 ) ;
447
+ assert_eq ! ( Arc :: strong_count( & beaver_cert. cert_handle ) , 3 ) ;
456
448
}
457
449
458
450
// set a default without adding it to the store
@@ -470,8 +462,8 @@ mod tests {
470
462
& test_pair. client. peer_cert_chain( ) . unwrap( )
471
463
) ) ;
472
464
473
- assert_eq ! ( Arc :: strong_count( & alligator_cert. ptr ) , 2 ) ;
474
- assert_eq ! ( Arc :: strong_count( & beaver_cert. ptr ) , 2 ) ;
465
+ assert_eq ! ( Arc :: strong_count( & alligator_cert. cert_handle ) , 2 ) ;
466
+ assert_eq ! ( Arc :: strong_count( & beaver_cert. cert_handle ) , 2 ) ;
475
467
}
476
468
477
469
Ok ( ( ) )
0 commit comments