1
1
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
2
// SPDX-License-Identifier: Apache-2.0
3
3
4
- use crate :: error:: { Error , Fallible } ;
4
+ use crate :: error:: { Error , ErrorType , Fallible } ;
5
5
use s2n_tls_sys:: * ;
6
6
use std:: {
7
+ any:: Any ,
8
+ ffi:: c_void,
7
9
marker:: PhantomData ,
8
10
ptr:: { self , NonNull } ,
9
11
sync:: Arc ,
@@ -13,6 +15,7 @@ use std::{
13
15
///
14
16
/// [CertificateChain] is internally reference counted. The reference counted `T`
15
17
/// must have a drop implementation.
18
+ #[ derive( Debug ) ]
16
19
pub ( crate ) struct CertificateChainHandle < ' a > {
17
20
pub ( crate ) cert : NonNull < s2n_cert_chain_and_key > ,
18
21
is_owned : bool ,
@@ -45,20 +48,57 @@ impl CertificateChainHandle<'_> {
45
48
_lifetime : PhantomData ,
46
49
}
47
50
}
51
+
52
+ /// Corresponds to [s2n_cert_chain_and_key_get_ctx].
53
+ fn context_mut ( & mut self ) -> Option < & mut Context > {
54
+ let context = unsafe { s2n_cert_chain_and_key_get_ctx ( self . cert . as_ptr ( ) ) } ;
55
+ if context. is_null ( ) {
56
+ None
57
+ } else {
58
+ Some ( unsafe { & mut * ( context as * mut Context ) } )
59
+ }
60
+ }
61
+
62
+ /// Corresponds to [s2n_cert_chain_and_key_get_ctx].
63
+ fn context ( & self ) -> Option < & Context > {
64
+ let context = unsafe { s2n_cert_chain_and_key_get_ctx ( self . cert . as_ptr ( ) ) } ;
65
+ if context. is_null ( ) {
66
+ None
67
+ } else {
68
+ Some ( unsafe { & * ( context as * const Context ) } )
69
+ }
70
+ }
48
71
}
49
72
50
73
impl Drop for CertificateChainHandle < ' _ > {
51
74
/// Corresponds to [s2n_cert_chain_and_key_free].
52
75
fn drop ( & mut self ) {
53
- // ignore failures since there's not much we can do about it
54
76
if self . is_owned {
77
+ if let Some ( internal_context) = self . context_mut ( ) {
78
+ drop ( unsafe { Box :: from_raw ( internal_context) } ) ;
79
+ }
80
+ // ignore failures since there's not much we can do about it
55
81
unsafe {
82
+ // null the cert chain context out of an abundance of caution
83
+ let _ = s2n_cert_chain_and_key_set_ctx ( self . cert . as_ptr ( ) , std:: ptr:: null_mut ( ) )
84
+ . into_result ( ) ;
85
+
56
86
let _ = s2n_cert_chain_and_key_free ( self . cert . as_ptr ( ) ) . into_result ( ) ;
57
87
}
58
88
}
59
89
}
60
90
}
61
91
92
+ /// An internal container to hold the customer supplied application context.
93
+ ///
94
+ /// We can't directly store the application context on the `s2n_cert_chain_and_key`,
95
+ /// because `*mut dyn Any` is a fat pointer (16 bytes) and can not be stored as
96
+ /// a c_void (8 bytes).
97
+ struct Context {
98
+ application_context : Box < dyn Any + Send + Sync > ,
99
+ }
100
+
101
+ #[ derive( Debug ) ]
62
102
pub struct Builder {
63
103
cert_handle : CertificateChainHandle < ' static > ,
64
104
}
@@ -125,6 +165,39 @@ impl Builder {
125
165
Ok ( self )
126
166
}
127
167
168
+ /// Associates an arbitrary application context with the CertificateChain to
169
+ /// be later retrieved via [`CertificateChain::application_context()`].
170
+ ///
171
+ /// This API will override an existing application context set on the Builder.
172
+ ///
173
+ /// Corresponds to [s2n_cert_chain_and_key_set_ctx].
174
+ pub fn set_application_context < T : Send + Sync + ' static > (
175
+ & mut self ,
176
+ app_context : T ,
177
+ ) -> Result < & mut Self , Error > {
178
+ match self . cert_handle . context_mut ( ) {
179
+ Some ( _) => Err ( Error :: bindings (
180
+ ErrorType :: UsageError ,
181
+ "cert builder error" ,
182
+ "set_application_context can only be called once" ,
183
+ ) ) ,
184
+ None => {
185
+ let app_context = Box :: new ( app_context) ;
186
+ let internal_context = Box :: new ( Context {
187
+ application_context : app_context,
188
+ } ) ;
189
+ unsafe {
190
+ s2n_cert_chain_and_key_set_ctx (
191
+ self . cert_handle . cert . as_ptr ( ) ,
192
+ Box :: into_raw ( internal_context) as * mut c_void ,
193
+ )
194
+ . into_result ( )
195
+ } ?;
196
+ Ok ( self )
197
+ }
198
+ }
199
+ }
200
+
128
201
/// Return an immutable, internally-reference counted CertificateChain.
129
202
pub fn build ( self ) -> Result < CertificateChain < ' static > , Error > {
130
203
// This method is currently infallible, but returning a result allows
@@ -177,6 +250,23 @@ impl CertificateChain<'_> {
177
250
}
178
251
}
179
252
253
+ /// Retrieves a reference to the application context associated with the
254
+ /// CertificateChain.
255
+ ///
256
+ /// If an application context hasn't been set on the CertificateChain or if
257
+ /// the set application context isn't of type `T`, `None` will be returned.
258
+ ///
259
+ /// To set a context on the connection, use [`Builder::set_application_context()`].
260
+ ///
261
+ /// Corresponds to [s2n_cert_chain_and_key_get_ctx].
262
+ pub fn application_context < T : Send + Sync + ' static > ( & self ) -> Option < & T > {
263
+ if let Some ( internal_context) = self . cert_handle . context ( ) {
264
+ internal_context. application_context . downcast_ref ( )
265
+ } else {
266
+ None
267
+ }
268
+ }
269
+
180
270
/// Return the length of this certificate chain.
181
271
///
182
272
/// Note that the underlying API currently traverses a linked list, so this is a relatively
@@ -273,9 +363,12 @@ unsafe impl Send for Certificate<'_> {}
273
363
mod tests {
274
364
use crate :: {
275
365
config,
276
- error:: { ErrorSource , ErrorType } ,
366
+ error:: { Error as S2NError , ErrorSource , ErrorType } ,
277
367
security:: DEFAULT_TLS13 ,
278
- testing:: { InsecureAcceptAllCertificatesHandler , SniTestCerts , TestPair } ,
368
+ testing:: {
369
+ config_builder, CertKeyPair , InsecureAcceptAllCertificatesHandler , SniTestCerts ,
370
+ TestPair ,
371
+ } ,
279
372
} ;
280
373
281
374
use super :: * ;
@@ -495,4 +588,67 @@ mod tests {
495
588
fn assert_send_sync < T : ' static + Send + Sync > ( ) { }
496
589
assert_send_sync :: < CertificateChain < ' static > > ( ) ;
497
590
}
591
+
592
+ /// sanity check for basic cert chain context interactions
593
+ #[ test]
594
+ fn application_context_workflow ( ) -> Result < ( ) , S2NError > {
595
+ let context: Arc < u64 > = Arc :: new ( 0xC0FFEE ) ;
596
+ let handle = Arc :: clone ( & context) ;
597
+ assert_eq ! ( Arc :: strong_count( & handle) , 2 ) ;
598
+
599
+ let default = CertKeyPair :: default ( ) ;
600
+ let mut chain = Builder :: new ( ) ?;
601
+ chain. load_pem ( default. cert ( ) , default. key ( ) ) ?;
602
+ chain. set_application_context ( context) ?;
603
+ let chain = chain. build ( ) ?;
604
+
605
+ let invalid_type_get = chain. application_context :: < u64 > ( ) ;
606
+ assert ! ( invalid_type_get. is_none( ) ) ;
607
+
608
+ let retrieved_context = chain. application_context :: < Arc < u64 > > ( ) . unwrap ( ) ;
609
+ assert_eq ! ( * retrieved_context. as_ref( ) , 0xC0FFEE ) ;
610
+ assert_eq ! ( Arc :: strong_count( & handle) , 2 ) ;
611
+ drop ( chain) ;
612
+ assert_eq ! ( Arc :: strong_count( & handle) , 1 ) ;
613
+ Ok ( ( ) )
614
+ }
615
+
616
+ /// When an application context is overridden, it should be error.
617
+ #[ test]
618
+ fn application_context_override ( ) -> Result < ( ) , S2NError > {
619
+ let initial: Arc < u64 > = Arc :: new ( 0xC0FFEE ) ;
620
+ let overridden: Arc < [ u8 ; 6 ] > = Arc :: new ( * b"coffee" ) ;
621
+
622
+ let mut builder = Builder :: new ( ) ?;
623
+ builder. set_application_context ( initial) ?;
624
+ let err = builder. set_application_context ( overridden) . unwrap_err ( ) ;
625
+ assert_eq ! ( err. kind( ) , ErrorType :: UsageError ) ;
626
+
627
+ Ok ( ( ) )
628
+ }
629
+
630
+ /// An application context should be retrievable from a selected cert after
631
+ /// the handshake.
632
+ #[ test]
633
+ fn application_context_from_selected_cert ( ) -> Result < ( ) , S2NError > {
634
+ let default = CertKeyPair :: default ( ) ;
635
+ let mut chain = Builder :: new ( ) ?;
636
+ chain. load_pem ( default. cert ( ) , default. key ( ) ) ?;
637
+ chain. set_application_context ( 0xC0FFEE_u64 ) ?;
638
+
639
+ let mut server_config = config:: Builder :: new ( ) ;
640
+ server_config. load_chain ( chain. build ( ) ?) ?;
641
+
642
+ let client_config = config_builder ( & crate :: security:: DEFAULT ) . unwrap ( ) ;
643
+
644
+ let mut test_pair =
645
+ TestPair :: from_configs ( & client_config. build ( ) ?, & server_config. build ( ) ?) ;
646
+ test_pair. handshake ( ) ?;
647
+
648
+ let selected_cert = test_pair. server . selected_cert ( ) . unwrap ( ) ;
649
+ let context = selected_cert. application_context :: < u64 > ( ) ;
650
+ assert_eq ! ( context, Some ( & 0xC0FFEE_u64 ) ) ;
651
+
652
+ Ok ( ( ) )
653
+ }
498
654
}
0 commit comments