Skip to content

Commit 34f5e4b

Browse files
authored
fix(s2n-quic-transport): Fixes s2n-quic TLS state incongruity when s2n-tls returns pending frequently (#2574)
1 parent 9ed7d73 commit 34f5e4b

File tree

6 files changed

+303
-0
lines changed

6 files changed

+303
-0
lines changed

quic/s2n-quic-core/src/crypto/tls.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ pub mod testing;
1717
#[cfg(all(feature = "alloc", any(test, feature = "testing")))]
1818
pub mod null;
1919

20+
#[cfg(feature = "alloc")]
21+
pub mod slow_tls;
22+
2023
/// Holds all application parameters which are exchanged within the TLS handshake.
2124
#[derive(Debug)]
2225
pub struct ApplicationParameters<'a> {
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
use crate::{
4+
application,
5+
crypto::{tls, CryptoSuite},
6+
transport,
7+
};
8+
use alloc::{boxed::Box, vec::Vec};
9+
use core::{any::Any, task::Poll};
10+
11+
const DEFER_COUNT: u8 = 3;
12+
13+
pub struct SlowEndpoint<E: tls::Endpoint> {
14+
endpoint: E,
15+
}
16+
17+
impl<E: tls::Endpoint> SlowEndpoint<E> {
18+
pub fn new(endpoint: E) -> Self {
19+
SlowEndpoint { endpoint }
20+
}
21+
}
22+
23+
impl<E: tls::Endpoint> tls::Endpoint for SlowEndpoint<E> {
24+
type Session = SlowSession<<E as tls::Endpoint>::Session>;
25+
26+
fn new_server_session<Params: s2n_codec::EncoderValue>(
27+
&mut self,
28+
transport_parameters: &Params,
29+
) -> Self::Session {
30+
let inner_session = self.endpoint.new_server_session(transport_parameters);
31+
SlowSession {
32+
defer: DEFER_COUNT,
33+
inner_session,
34+
}
35+
}
36+
37+
fn new_client_session<Params: s2n_codec::EncoderValue>(
38+
&mut self,
39+
transport_parameters: &Params,
40+
server_name: application::ServerName,
41+
) -> Self::Session {
42+
let inner_session = self
43+
.endpoint
44+
.new_client_session(transport_parameters, server_name);
45+
SlowSession {
46+
defer: DEFER_COUNT,
47+
inner_session,
48+
}
49+
}
50+
51+
fn max_tag_length(&self) -> usize {
52+
self.endpoint.max_tag_length()
53+
}
54+
}
55+
56+
// SlowSession is a test TLS provider that is slow, namely, for each call to poll,
57+
// it returns Poll::Pending several times before actually polling the real TLS library.
58+
// This is used in an integration test to assert that our code is correct in the event
59+
// of any random pendings/wakeups that might occur when negotiating TLS.
60+
#[derive(Debug)]
61+
pub struct SlowSession<S: tls::Session> {
62+
defer: u8,
63+
inner_session: S,
64+
}
65+
66+
impl<S: tls::Session> tls::Session for SlowSession<S> {
67+
#[inline]
68+
fn poll<W>(&mut self, context: &mut W) -> Poll<Result<(), transport::Error>>
69+
where
70+
W: tls::Context<Self>,
71+
{
72+
// Self-wake and return Pending if defer is non-zero
73+
if let Some(d) = self.defer.checked_sub(1) {
74+
self.defer = d;
75+
context.waker().wake_by_ref();
76+
return Poll::Pending;
77+
}
78+
79+
// Otherwise we'll call the function to actually make progress
80+
// in the TLS handshake and set up to defer again the next time
81+
// we're here.
82+
self.defer = DEFER_COUNT;
83+
self.inner_session.poll(&mut SlowContext(context))
84+
}
85+
}
86+
87+
impl<S: tls::Session> CryptoSuite for SlowSession<S> {
88+
type HandshakeKey = <S as CryptoSuite>::HandshakeKey;
89+
type HandshakeHeaderKey = <S as CryptoSuite>::HandshakeHeaderKey;
90+
type InitialKey = <S as CryptoSuite>::InitialKey;
91+
type InitialHeaderKey = <S as CryptoSuite>::InitialHeaderKey;
92+
type ZeroRttKey = <S as CryptoSuite>::ZeroRttKey;
93+
type ZeroRttHeaderKey = <S as CryptoSuite>::ZeroRttHeaderKey;
94+
type OneRttKey = <S as CryptoSuite>::OneRttKey;
95+
type OneRttHeaderKey = <S as CryptoSuite>::OneRttHeaderKey;
96+
type RetryKey = <S as CryptoSuite>::RetryKey;
97+
}
98+
99+
struct SlowContext<'a, Inner>(&'a mut Inner);
100+
101+
impl<I, S: tls::Session> tls::Context<S> for SlowContext<'_, I>
102+
where
103+
I: tls::Context<SlowSession<S>>,
104+
{
105+
fn on_client_application_params(
106+
&mut self,
107+
client_params: tls::ApplicationParameters,
108+
server_params: &mut Vec<u8>,
109+
) -> Result<(), transport::Error> {
110+
self.0
111+
.on_client_application_params(client_params, server_params)
112+
}
113+
114+
fn on_handshake_keys(
115+
&mut self,
116+
key: <S as CryptoSuite>::HandshakeKey,
117+
header_key: <S as CryptoSuite>::HandshakeHeaderKey,
118+
) -> Result<(), transport::Error> {
119+
self.0.on_handshake_keys(key, header_key)
120+
}
121+
122+
fn on_zero_rtt_keys(
123+
&mut self,
124+
key: <S>::ZeroRttKey,
125+
header_key: <S>::ZeroRttHeaderKey,
126+
application_parameters: tls::ApplicationParameters,
127+
) -> Result<(), transport::Error> {
128+
self.0
129+
.on_zero_rtt_keys(key, header_key, application_parameters)
130+
}
131+
132+
fn on_one_rtt_keys(
133+
&mut self,
134+
key: <S>::OneRttKey,
135+
header_key: <S>::OneRttHeaderKey,
136+
application_parameters: tls::ApplicationParameters,
137+
) -> Result<(), transport::Error> {
138+
self.0
139+
.on_one_rtt_keys(key, header_key, application_parameters)
140+
}
141+
142+
fn on_server_name(
143+
&mut self,
144+
server_name: application::ServerName,
145+
) -> Result<(), transport::Error> {
146+
self.0.on_server_name(server_name)
147+
}
148+
149+
fn on_application_protocol(
150+
&mut self,
151+
application_protocol: tls::Bytes,
152+
) -> Result<(), transport::Error> {
153+
self.0.on_application_protocol(application_protocol)
154+
}
155+
156+
fn on_handshake_complete(&mut self) -> Result<(), transport::Error> {
157+
self.0.on_handshake_complete()
158+
}
159+
160+
fn on_tls_exporter_ready(
161+
&mut self,
162+
session: &impl tls::TlsSession,
163+
) -> Result<(), transport::Error> {
164+
self.0.on_tls_exporter_ready(session)
165+
}
166+
167+
fn receive_initial(&mut self, max_len: Option<usize>) -> Option<tls::Bytes> {
168+
self.0.receive_initial(max_len)
169+
}
170+
171+
fn receive_handshake(&mut self, max_len: Option<usize>) -> Option<tls::Bytes> {
172+
self.0.receive_handshake(max_len)
173+
}
174+
175+
fn receive_application(&mut self, max_len: Option<usize>) -> Option<tls::Bytes> {
176+
self.0.receive_application(max_len)
177+
}
178+
179+
fn can_send_initial(&self) -> bool {
180+
self.0.can_send_initial()
181+
}
182+
183+
fn send_initial(&mut self, transmission: tls::Bytes) {
184+
self.0.send_initial(transmission);
185+
}
186+
187+
fn can_send_handshake(&self) -> bool {
188+
self.0.can_send_handshake()
189+
}
190+
191+
fn send_handshake(&mut self, transmission: tls::Bytes) {
192+
self.0.send_handshake(transmission);
193+
}
194+
195+
fn can_send_application(&self) -> bool {
196+
self.0.can_send_application()
197+
}
198+
199+
fn send_application(&mut self, transmission: tls::Bytes) {
200+
self.0.send_application(transmission)
201+
}
202+
203+
fn waker(&self) -> &core::task::Waker {
204+
self.0.waker()
205+
}
206+
207+
fn on_key_exchange_group(
208+
&mut self,
209+
named_group: tls::NamedGroup,
210+
) -> Result<(), transport::Error> {
211+
self.0.on_key_exchange_group(named_group)
212+
}
213+
214+
fn on_tls_context(&mut self, context: Box<dyn Any + Send>) {
215+
self.0.on_tls_context(context)
216+
}
217+
}

quic/s2n-quic-transport/src/connection/connection_impl.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,6 +1141,16 @@ impl<Config: endpoint::Config> connection::Trait for ConnectionImpl<Config> {
11411141
// check if crypto progress can be made
11421142
self.update_crypto_state(timestamp, subscriber, datagram, dc, conn_limits)?;
11431143

1144+
if self.space_manager.handshake().is_some() && self.space_manager.is_handshake_confirmed() {
1145+
let mut publisher = self.event_context.publisher(timestamp, subscriber);
1146+
1147+
//= https://www.rfc-editor.org/rfc/rfc9001#section-4.9.2
1148+
//# An endpoint MUST discard its handshake keys when the TLS handshake is
1149+
//# confirmed (Section 4.1.2).
1150+
self.space_manager
1151+
.discard_handshake(&mut self.path_manager, &mut publisher);
1152+
}
1153+
11441154
// return an error if the application set one
11451155
self.error?;
11461156

quic/s2n-quic/src/tests.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ mod recorder;
2727
mod connection_limits;
2828
mod resumption;
2929
mod setup;
30+
mod slow_tls;
3031
use setup::*;
3132

3233
mod blackhole;

quic/s2n-quic/src/tests/setup.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,28 @@ mod mtls {
242242
}
243243
}
244244

245+
mod slow_tls {
246+
use crate::provider::tls::Provider;
247+
use s2n_quic_core::crypto::tls::{slow_tls::SlowEndpoint, Endpoint};
248+
pub struct SlowTlsProvider<E: Endpoint> {
249+
pub endpoint: E,
250+
}
251+
252+
impl<E: Endpoint> Provider for SlowTlsProvider<E> {
253+
type Server = SlowEndpoint<E>;
254+
type Client = SlowEndpoint<E>;
255+
type Error = String;
256+
257+
fn start_server(self) -> Result<Self::Server, Self::Error> {
258+
Ok(SlowEndpoint::new(self.endpoint))
259+
}
260+
261+
fn start_client(self) -> Result<Self::Client, Self::Error> {
262+
Ok(SlowEndpoint::new(self.endpoint))
263+
}
264+
}
265+
}
266+
245267
#[cfg(feature = "s2n-quic-tls")]
246268
mod resumption {
247269
use super::*;
@@ -326,3 +348,6 @@ pub use mtls::*;
326348

327349
#[cfg(feature = "s2n-quic-tls")]
328350
pub use resumption::*;
351+
352+
#[cfg(not(feature = "provider-tls-fips"))]
353+
pub use slow_tls::SlowTlsProvider;

quic/s2n-quic/src/tests/slow_tls.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
#[test]
5+
#[cfg(not(feature = "provider-tls-fips"))]
6+
fn slow_default_tls() {
7+
use super::*;
8+
use crate::provider::tls::default;
9+
use s2n_quic_core::crypto::tls::testing::certificates::{CERT_PEM, KEY_PEM};
10+
11+
let model = Model::default();
12+
13+
let server_endpoint = default::Server::builder()
14+
.with_certificate(CERT_PEM, KEY_PEM)
15+
.unwrap()
16+
.build()
17+
.unwrap();
18+
let slow_server = SlowTlsProvider {
19+
endpoint: server_endpoint,
20+
};
21+
22+
let client_endpoint = default::Client::builder()
23+
.with_certificate(CERT_PEM)
24+
.unwrap()
25+
.build()
26+
.unwrap();
27+
let slow_client = SlowTlsProvider {
28+
endpoint: client_endpoint,
29+
};
30+
31+
test(model, |handle| {
32+
let server = Server::builder()
33+
.with_io(handle.builder().build()?)?
34+
.with_tls(slow_server)?
35+
.start()?;
36+
37+
let client = Client::builder()
38+
.with_io(handle.builder().build().unwrap())?
39+
.with_tls(slow_client)?
40+
.start()?;
41+
let addr = start_server(server)?;
42+
start_client(client, addr, Data::new(1000))?;
43+
44+
Ok(addr)
45+
})
46+
.unwrap();
47+
}

0 commit comments

Comments
 (0)