Skip to content

Commit 711ee0d

Browse files
jmayclingoatgoose
andauthored
feat(bindings): expose context on cert chain (#5132)
Co-authored-by: Sam Clark <[email protected]>
1 parent ac1d098 commit 711ee0d

File tree

1 file changed

+160
-4
lines changed

1 file changed

+160
-4
lines changed

bindings/rust/extended/s2n-tls/src/cert_chain.rs

Lines changed: 160 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
// SPDX-License-Identifier: Apache-2.0
33

4-
use crate::error::{Error, Fallible};
4+
use crate::error::{Error, ErrorType, Fallible};
55
use s2n_tls_sys::*;
66
use std::{
7+
any::Any,
8+
ffi::c_void,
79
marker::PhantomData,
810
ptr::{self, NonNull},
911
sync::Arc,
@@ -13,6 +15,7 @@ use std::{
1315
///
1416
/// [CertificateChain] is internally reference counted. The reference counted `T`
1517
/// must have a drop implementation.
18+
#[derive(Debug)]
1619
pub(crate) struct CertificateChainHandle<'a> {
1720
pub(crate) cert: NonNull<s2n_cert_chain_and_key>,
1821
is_owned: bool,
@@ -45,20 +48,57 @@ impl CertificateChainHandle<'_> {
4548
_lifetime: PhantomData,
4649
}
4750
}
51+
52+
/// Corresponds to [s2n_cert_chain_and_key_get_ctx].
53+
fn context_mut(&mut self) -> Option<&mut Context> {
54+
let context = unsafe { s2n_cert_chain_and_key_get_ctx(self.cert.as_ptr()) };
55+
if context.is_null() {
56+
None
57+
} else {
58+
Some(unsafe { &mut *(context as *mut Context) })
59+
}
60+
}
61+
62+
/// Corresponds to [s2n_cert_chain_and_key_get_ctx].
63+
fn context(&self) -> Option<&Context> {
64+
let context = unsafe { s2n_cert_chain_and_key_get_ctx(self.cert.as_ptr()) };
65+
if context.is_null() {
66+
None
67+
} else {
68+
Some(unsafe { &*(context as *const Context) })
69+
}
70+
}
4871
}
4972

5073
impl Drop for CertificateChainHandle<'_> {
5174
/// Corresponds to [s2n_cert_chain_and_key_free].
5275
fn drop(&mut self) {
53-
// ignore failures since there's not much we can do about it
5476
if self.is_owned {
77+
if let Some(internal_context) = self.context_mut() {
78+
drop(unsafe { Box::from_raw(internal_context) });
79+
}
80+
// ignore failures since there's not much we can do about it
5581
unsafe {
82+
// null the cert chain context out of an abundance of caution
83+
let _ = s2n_cert_chain_and_key_set_ctx(self.cert.as_ptr(), std::ptr::null_mut())
84+
.into_result();
85+
5686
let _ = s2n_cert_chain_and_key_free(self.cert.as_ptr()).into_result();
5787
}
5888
}
5989
}
6090
}
6191

92+
/// An internal container to hold the customer supplied application context.
93+
///
94+
/// We can't directly store the application context on the `s2n_cert_chain_and_key`,
95+
/// because `*mut dyn Any` is a fat pointer (16 bytes) and can not be stored as
96+
/// a c_void (8 bytes).
97+
struct Context {
98+
application_context: Box<dyn Any + Send + Sync>,
99+
}
100+
101+
#[derive(Debug)]
62102
pub struct Builder {
63103
cert_handle: CertificateChainHandle<'static>,
64104
}
@@ -125,6 +165,39 @@ impl Builder {
125165
Ok(self)
126166
}
127167

168+
/// Associates an arbitrary application context with the CertificateChain to
169+
/// be later retrieved via [`CertificateChain::application_context()`].
170+
///
171+
/// This API will override an existing application context set on the Builder.
172+
///
173+
/// Corresponds to [s2n_cert_chain_and_key_set_ctx].
174+
pub fn set_application_context<T: Send + Sync + 'static>(
175+
&mut self,
176+
app_context: T,
177+
) -> Result<&mut Self, Error> {
178+
match self.cert_handle.context_mut() {
179+
Some(_) => Err(Error::bindings(
180+
ErrorType::UsageError,
181+
"cert builder error",
182+
"set_application_context can only be called once",
183+
)),
184+
None => {
185+
let app_context = Box::new(app_context);
186+
let internal_context = Box::new(Context {
187+
application_context: app_context,
188+
});
189+
unsafe {
190+
s2n_cert_chain_and_key_set_ctx(
191+
self.cert_handle.cert.as_ptr(),
192+
Box::into_raw(internal_context) as *mut c_void,
193+
)
194+
.into_result()
195+
}?;
196+
Ok(self)
197+
}
198+
}
199+
}
200+
128201
/// Return an immutable, internally-reference counted CertificateChain.
129202
pub fn build(self) -> Result<CertificateChain<'static>, Error> {
130203
// This method is currently infallible, but returning a result allows
@@ -177,6 +250,23 @@ impl CertificateChain<'_> {
177250
}
178251
}
179252

253+
/// Retrieves a reference to the application context associated with the
254+
/// CertificateChain.
255+
///
256+
/// If an application context hasn't been set on the CertificateChain or if
257+
/// the set application context isn't of type `T`, `None` will be returned.
258+
///
259+
/// To set a context on the connection, use [`Builder::set_application_context()`].
260+
///
261+
/// Corresponds to [s2n_cert_chain_and_key_get_ctx].
262+
pub fn application_context<T: Send + Sync + 'static>(&self) -> Option<&T> {
263+
if let Some(internal_context) = self.cert_handle.context() {
264+
internal_context.application_context.downcast_ref()
265+
} else {
266+
None
267+
}
268+
}
269+
180270
/// Return the length of this certificate chain.
181271
///
182272
/// Note that the underlying API currently traverses a linked list, so this is a relatively
@@ -273,9 +363,12 @@ unsafe impl Send for Certificate<'_> {}
273363
mod tests {
274364
use crate::{
275365
config,
276-
error::{ErrorSource, ErrorType},
366+
error::{Error as S2NError, ErrorSource, ErrorType},
277367
security::DEFAULT_TLS13,
278-
testing::{InsecureAcceptAllCertificatesHandler, SniTestCerts, TestPair},
368+
testing::{
369+
config_builder, CertKeyPair, InsecureAcceptAllCertificatesHandler, SniTestCerts,
370+
TestPair,
371+
},
279372
};
280373

281374
use super::*;
@@ -495,4 +588,67 @@ mod tests {
495588
fn assert_send_sync<T: 'static + Send + Sync>() {}
496589
assert_send_sync::<CertificateChain<'static>>();
497590
}
591+
592+
/// sanity check for basic cert chain context interactions
593+
#[test]
594+
fn application_context_workflow() -> Result<(), S2NError> {
595+
let context: Arc<u64> = Arc::new(0xC0FFEE);
596+
let handle = Arc::clone(&context);
597+
assert_eq!(Arc::strong_count(&handle), 2);
598+
599+
let default = CertKeyPair::default();
600+
let mut chain = Builder::new()?;
601+
chain.load_pem(default.cert(), default.key())?;
602+
chain.set_application_context(context)?;
603+
let chain = chain.build()?;
604+
605+
let invalid_type_get = chain.application_context::<u64>();
606+
assert!(invalid_type_get.is_none());
607+
608+
let retrieved_context = chain.application_context::<Arc<u64>>().unwrap();
609+
assert_eq!(*retrieved_context.as_ref(), 0xC0FFEE);
610+
assert_eq!(Arc::strong_count(&handle), 2);
611+
drop(chain);
612+
assert_eq!(Arc::strong_count(&handle), 1);
613+
Ok(())
614+
}
615+
616+
/// When an application context is overridden, it should be error.
617+
#[test]
618+
fn application_context_override() -> Result<(), S2NError> {
619+
let initial: Arc<u64> = Arc::new(0xC0FFEE);
620+
let overridden: Arc<[u8; 6]> = Arc::new(*b"coffee");
621+
622+
let mut builder = Builder::new()?;
623+
builder.set_application_context(initial)?;
624+
let err = builder.set_application_context(overridden).unwrap_err();
625+
assert_eq!(err.kind(), ErrorType::UsageError);
626+
627+
Ok(())
628+
}
629+
630+
/// An application context should be retrievable from a selected cert after
631+
/// the handshake.
632+
#[test]
633+
fn application_context_from_selected_cert() -> Result<(), S2NError> {
634+
let default = CertKeyPair::default();
635+
let mut chain = Builder::new()?;
636+
chain.load_pem(default.cert(), default.key())?;
637+
chain.set_application_context(0xC0FFEE_u64)?;
638+
639+
let mut server_config = config::Builder::new();
640+
server_config.load_chain(chain.build()?)?;
641+
642+
let client_config = config_builder(&crate::security::DEFAULT).unwrap();
643+
644+
let mut test_pair =
645+
TestPair::from_configs(&client_config.build()?, &server_config.build()?);
646+
test_pair.handshake()?;
647+
648+
let selected_cert = test_pair.server.selected_cert().unwrap();
649+
let context = selected_cert.application_context::<u64>();
650+
assert_eq!(context, Some(&0xC0FFEE_u64));
651+
652+
Ok(())
653+
}
498654
}

0 commit comments

Comments
 (0)