Skip to content

Commit 4ae43ec

Browse files
authored
fix(bindings): remove mutation behind Arc (#5124)
1 parent 2c47d43 commit 4ae43ec

File tree

2 files changed

+58
-63
lines changed

2 files changed

+58
-63
lines changed

bindings/rust/extended/s2n-tls/src/cert_chain.rs

Lines changed: 49 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -13,34 +13,41 @@ use std::{
1313
///
1414
/// [CertificateChain] is internally reference counted. The reference counted `T`
1515
/// must have a drop implementation.
16-
struct CertificateChainHandle {
17-
cert: NonNull<s2n_cert_chain_and_key>,
16+
pub(crate) struct CertificateChainHandle<'a> {
17+
pub(crate) cert: NonNull<s2n_cert_chain_and_key>,
1818
is_owned: bool,
19+
_lifetime: PhantomData<&'a s2n_cert_chain_and_key>,
1920
}
2021

2122
// # Safety
2223
//
2324
// s2n_cert_chain_and_key objects can be sent across threads.
24-
unsafe impl Send for CertificateChainHandle {}
25-
unsafe impl Sync for CertificateChainHandle {}
25+
unsafe impl Send for CertificateChainHandle<'_> {}
26+
unsafe impl Sync for CertificateChainHandle<'_> {}
2627

27-
impl CertificateChainHandle {
28-
fn from_owned(cert: NonNull<s2n_cert_chain_and_key>) -> Self {
29-
Self {
30-
cert,
28+
impl CertificateChainHandle<'_> {
29+
/// Allocate an uninitialized CertificateChainHandle.
30+
///
31+
/// Corresponds to [s2n_cert_chain_and_key_new].
32+
pub(crate) fn allocate() -> Result<CertificateChainHandle<'static>, crate::error::Error> {
33+
crate::init::init();
34+
Ok(CertificateChainHandle {
35+
cert: unsafe { s2n_cert_chain_and_key_new().into_result() }?,
3136
is_owned: true,
32-
}
37+
_lifetime: PhantomData,
38+
})
3339
}
3440

3541
fn from_reference(cert: NonNull<s2n_cert_chain_and_key>) -> Self {
3642
Self {
3743
cert,
3844
is_owned: false,
45+
_lifetime: PhantomData,
3946
}
4047
}
4148
}
4249

43-
impl Drop for CertificateChainHandle {
50+
impl Drop for CertificateChainHandle<'_> {
4451
/// Corresponds to [s2n_cert_chain_and_key_free].
4552
fn drop(&mut self) {
4653
// ignore failures since there's not much we can do about it
@@ -53,13 +60,13 @@ impl Drop for CertificateChainHandle {
5360
}
5461

5562
pub struct Builder {
56-
cert: CertificateChain<'static>,
63+
cert_handle: CertificateChainHandle<'static>,
5764
}
5865

5966
impl Builder {
6067
pub fn new() -> Result<Self, Error> {
6168
Ok(Self {
62-
cert: CertificateChain::allocate_owned()?,
69+
cert_handle: CertificateChainHandle::allocate()?,
6370
})
6471
}
6572

@@ -73,7 +80,7 @@ impl Builder {
7380
// `private_key_pem` are not modified.
7481
// https://github.com/aws/s2n-tls/issues/4140
7582
s2n_cert_chain_and_key_load_pem_bytes(
76-
self.cert.as_mut_ptr(),
83+
self.cert_handle.cert.as_ptr(),
7784
chain.as_ptr() as *mut _,
7885
chain.len() as u32,
7986
key.as_ptr() as *mut _,
@@ -95,7 +102,7 @@ impl Builder {
95102
// is not modified
96103
// https://github.com/aws/s2n-tls/issues/4140
97104
s2n_cert_chain_and_key_load_public_pem_bytes(
98-
self.cert.as_mut_ptr(),
105+
self.cert_handle.cert.as_ptr(),
99106
chain.as_ptr() as *mut _,
100107
chain.len() as u32,
101108
)
@@ -109,7 +116,7 @@ impl Builder {
109116
pub fn set_ocsp_data(&mut self, data: &[u8]) -> Result<&mut Self, Error> {
110117
unsafe {
111118
s2n_cert_chain_and_key_set_ocsp_data(
112-
self.cert.as_mut_ptr(),
119+
self.cert_handle.cert.as_ptr(),
113120
data.as_ptr(),
114121
data.len() as u32,
115122
)
@@ -122,7 +129,7 @@ impl Builder {
122129
pub fn build(self) -> Result<CertificateChain<'static>, Error> {
123130
// This method is currently infallible, but returning a result allows
124131
// us to add validation in the future.
125-
Ok(self.cert)
132+
Ok(CertificateChain::from_allocated(self.cert_handle))
126133
}
127134
}
128135

@@ -135,22 +142,16 @@ impl Builder {
135142
// safe to mutate CertificateChains.
136143
#[derive(Clone)]
137144
pub struct CertificateChain<'a> {
138-
ptr: Arc<CertificateChainHandle>,
139-
_lifetime: PhantomData<&'a s2n_cert_chain_and_key>,
145+
cert_handle: Arc<CertificateChainHandle<'a>>,
140146
}
141147

142148
impl CertificateChain<'_> {
143-
/// This allocates a new certificate chain from s2n.
144-
///
145-
/// Corresponds to [s2n_cert_chain_and_key_new].
146-
pub(crate) fn allocate_owned() -> Result<CertificateChain<'static>, Error> {
147-
crate::init::init();
148-
unsafe {
149-
let ptr = s2n_cert_chain_and_key_new().into_result()?;
150-
Ok(CertificateChain {
151-
ptr: Arc::new(CertificateChainHandle::from_owned(ptr)),
152-
_lifetime: PhantomData,
153-
})
149+
/// Construct a CertificateChain from an allocated [CertificateChainHandle].
150+
pub(crate) fn from_allocated(
151+
handle: CertificateChainHandle<'static>,
152+
) -> CertificateChain<'static> {
153+
CertificateChain {
154+
cert_handle: Arc::new(handle),
154155
}
155156
}
156157

@@ -162,8 +163,7 @@ impl CertificateChain<'_> {
162163
let handle = Arc::new(CertificateChainHandle::from_reference(ptr));
163164

164165
CertificateChain {
165-
ptr: handle,
166-
_lifetime: PhantomData,
166+
cert_handle: handle,
167167
}
168168
}
169169

@@ -202,16 +202,8 @@ impl CertificateChain<'_> {
202202
self.len() == 0
203203
}
204204

205-
/// SAFETY: Only one instance of `CertificateChain` may exist when this method
206-
/// is called. s2n_cert_chain_and_key is not thread-safe, so it is not safe
207-
/// to mutate the certificate chain if references are held across multiple threads.
208-
pub(crate) unsafe fn as_mut_ptr(&mut self) -> *mut s2n_cert_chain_and_key {
209-
debug_assert_eq!(Arc::strong_count(&self.ptr), 1);
210-
self.ptr.cert.as_ptr()
211-
}
212-
213205
pub(crate) fn as_ptr(&self) -> *const s2n_cert_chain_and_key {
214-
self.ptr.cert.as_ptr() as *const _
206+
self.cert_handle.cert.as_ptr() as *const _
215207
}
216208
}
217209

@@ -339,28 +331,28 @@ mod tests {
339331
#[test]
340332
fn reference_count_increment() -> Result<(), crate::error::Error> {
341333
let cert = SniTestCerts::AlligatorRsa.get().into_certificate_chain();
342-
assert_eq!(Arc::strong_count(&cert.ptr), 1);
334+
assert_eq!(Arc::strong_count(&cert.cert_handle), 1);
343335

344336
{
345337
let mut server = config::Builder::new();
346338
server.load_chain(cert.clone())?;
347339

348340
// after being added, the reference count should have increased
349-
assert_eq!(Arc::strong_count(&cert.ptr), 2);
341+
assert_eq!(Arc::strong_count(&cert.cert_handle), 2);
350342
}
351343

352344
// after the config goes out of scope and is dropped, the ref count should
353345
// decrement
354-
assert_eq!(Arc::strong_count(&cert.ptr), 1);
346+
assert_eq!(Arc::strong_count(&cert.cert_handle), 1);
355347
Ok(())
356348
}
357349

358350
#[test]
359351
fn cert_is_dropped() {
360352
let weak_ref = {
361353
let cert = SniTestCerts::AlligatorEcdsa.get().into_certificate_chain();
362-
assert_eq!(Arc::strong_count(&cert.ptr), 1);
363-
Arc::downgrade(&cert.ptr)
354+
assert_eq!(Arc::strong_count(&cert.cert_handle), 1);
355+
Arc::downgrade(&cert.cert_handle)
364356
};
365357
assert_eq!(weak_ref.strong_count(), 0);
366358
assert!(weak_ref.upgrade().is_none());
@@ -377,17 +369,17 @@ mod tests {
377369
let mut test_pair_2 =
378370
sni_test_pair(vec![cert.clone()], None, &[SniTestCerts::AlligatorRsa])?;
379371

380-
assert_eq!(Arc::strong_count(&cert.ptr), 3);
372+
assert_eq!(Arc::strong_count(&cert.cert_handle), 3);
381373

382374
assert!(test_pair_1.handshake().is_ok());
383375
assert!(test_pair_2.handshake().is_ok());
384376

385-
assert_eq!(Arc::strong_count(&cert.ptr), 3);
377+
assert_eq!(Arc::strong_count(&cert.cert_handle), 3);
386378

387379
drop(test_pair_1);
388-
assert_eq!(Arc::strong_count(&cert.ptr), 2);
380+
assert_eq!(Arc::strong_count(&cert.cert_handle), 2);
389381
drop(test_pair_2);
390-
assert_eq!(Arc::strong_count(&cert.ptr), 1);
382+
assert_eq!(Arc::strong_count(&cert.cert_handle), 1);
391383
Ok(())
392384
}
393385

@@ -396,7 +388,7 @@ mod tests {
396388
// 5 certs in the maximum allowed, 6 should error.
397389
const FAILING_NUMBER: usize = 6;
398390
let certs = vec![SniTestCerts::AlligatorRsa.get().into_certificate_chain(); FAILING_NUMBER];
399-
assert_eq!(Arc::strong_count(&certs[0].ptr), FAILING_NUMBER);
391+
assert_eq!(Arc::strong_count(&certs[0].cert_handle), FAILING_NUMBER);
400392

401393
let mut config = config::Builder::new();
402394
let err = config.set_default_chains(certs.clone()).err().unwrap();
@@ -405,7 +397,7 @@ mod tests {
405397

406398
// The config should not hold a reference when the error was detected
407399
// in the bindings
408-
assert_eq!(Arc::strong_count(&certs[0].ptr), FAILING_NUMBER);
400+
assert_eq!(Arc::strong_count(&certs[0].cert_handle), FAILING_NUMBER);
409401

410402
Ok(())
411403
}
@@ -430,8 +422,8 @@ mod tests {
430422
&test_pair.client.peer_cert_chain().unwrap()
431423
));
432424

433-
assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2);
434-
assert_eq!(Arc::strong_count(&beaver_cert.ptr), 2);
425+
assert_eq!(Arc::strong_count(&alligator_cert.cert_handle), 2);
426+
assert_eq!(Arc::strong_count(&beaver_cert.cert_handle), 2);
435427
}
436428

437429
// set an explicit default
@@ -449,10 +441,10 @@ mod tests {
449441
&test_pair.client.peer_cert_chain().unwrap()
450442
));
451443

452-
assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2);
444+
assert_eq!(Arc::strong_count(&alligator_cert.cert_handle), 2);
453445
// beaver has an additional reference because it was used in multiple
454446
// calls
455-
assert_eq!(Arc::strong_count(&beaver_cert.ptr), 3);
447+
assert_eq!(Arc::strong_count(&beaver_cert.cert_handle), 3);
456448
}
457449

458450
// set a default without adding it to the store
@@ -470,8 +462,8 @@ mod tests {
470462
&test_pair.client.peer_cert_chain().unwrap()
471463
));
472464

473-
assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2);
474-
assert_eq!(Arc::strong_count(&beaver_cert.ptr), 2);
465+
assert_eq!(Arc::strong_count(&alligator_cert.cert_handle), 2);
466+
assert_eq!(Arc::strong_count(&beaver_cert.cert_handle), 2);
475467
}
476468

477469
Ok(())

bindings/rust/extended/s2n-tls/src/connection.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
use crate::renegotiate::RenegotiateState;
88
use crate::{
99
callbacks::*,
10-
cert_chain::CertificateChain,
10+
cert_chain::{CertificateChain, CertificateChainHandle},
1111
config::Config,
1212
enums::*,
1313
error::{Error, Fallible, Pollable},
@@ -1219,11 +1219,14 @@ impl Connection {
12191219
/// Corresponds to [s2n_connection_get_peer_cert_chain].
12201220
pub fn peer_cert_chain(&self) -> Result<CertificateChain<'static>, Error> {
12211221
unsafe {
1222-
let mut chain = CertificateChain::allocate_owned()?;
1223-
s2n_connection_get_peer_cert_chain(self.connection.as_ptr(), chain.as_mut_ptr())
1224-
.into_result()
1225-
.map(|_| ())?;
1226-
Ok(chain)
1222+
let chain_handle = CertificateChainHandle::allocate()?;
1223+
s2n_connection_get_peer_cert_chain(
1224+
self.connection.as_ptr(),
1225+
chain_handle.cert.as_ptr(),
1226+
)
1227+
.into_result()
1228+
.map(|_| ())?;
1229+
Ok(CertificateChain::from_allocated(chain_handle))
12271230
}
12281231
}
12291232

0 commit comments

Comments
 (0)