Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions dc/s2n-quic-dc/src/psk/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::io::{self, HandshakeFailed};
use crate::path::secret;
use s2n_quic::{
provider::{event::Subscriber as Sub, tls::Provider as Prov},
server::Name,
Connection,
};
use std::{net::SocketAddr, sync::Arc, time::Duration};
Expand Down Expand Up @@ -101,6 +102,7 @@ impl Provider {
subscriber: Subscriber,
query_event_callback: fn(&mut Connection, Duration),
builder: Builder<Event>,
server_name: Name,
) -> io::Result<Self> {
let state = State::new_runtime(
addr,
Expand All @@ -117,11 +119,13 @@ impl Provider {
if let Some(state) = weak.upgrade() {
let runtime = state.runtime.as_ref().map(|v| &v.0).unwrap();
let client = state.client.clone();
let server_name = server_name.clone();
// Drop the JoinHandle -- we're not actually going to block on the join handle's
// result. The future will keep running in the background.
runtime.spawn(async move {
if let Err(HandshakeFailed { .. }) =
client.connect(peer, query_event_callback).await
if let Err(HandshakeFailed { .. }) = client
.connect(peer, query_event_callback, server_name)
.await
{
// failure has already been logged, no further action required.
}
Expand All @@ -140,9 +144,10 @@ impl Provider {
&self,
peer: SocketAddr,
query_event_callback: fn(&mut Connection, Duration),
server_name: Name,
) -> std::io::Result<HandshakeKind> {
let (_peer, kind) = self
.handshake_with_entry(peer, query_event_callback)
.handshake_with_entry(peer, query_event_callback, server_name)
.await?;
Ok(kind)
}
Expand All @@ -156,11 +161,12 @@ impl Provider {
&self,
peer: SocketAddr,
query_event_callback: fn(&mut Connection, Duration),
server_name: Name,
) -> std::io::Result<(secret::map::Peer, HandshakeKind)> {
// Unconditionally request a background handshake. This schedules any re-handshaking
// needed.
if self.state.runtime.is_some() {
let _ = self.background_handshake_with(peer, query_event_callback);
let _ = self.background_handshake_with(peer, query_event_callback, server_name.clone());
}

if let Some(peer) = self.state.map.get_tracked(peer) {
Expand All @@ -170,10 +176,18 @@ impl Provider {
let state = self.state.clone();
if let Some((runtime, _)) = self.state.runtime.as_ref() {
runtime
.spawn(async move { state.client.connect(peer, query_event_callback).await })
.spawn(async move {
state
.client
.connect(peer, query_event_callback, server_name)
.await
})
.await??;
} else {
state.client.connect(peer, query_event_callback).await?;
state
.client
.connect(peer, query_event_callback, server_name)
.await?;
}

// already recorded a metric above in get_tracked.
Expand All @@ -193,18 +207,21 @@ impl Provider {
&self,
peer: SocketAddr,
query_event_callback: fn(&mut Connection, Duration),
server_name: Name,
) -> std::io::Result<HandshakeKind> {
if self.state.map.contains(&peer) {
return Ok(HandshakeKind::Cached);
}

let client = self.state.client.clone();
if let Some((runtime, _)) = self.state.runtime.as_ref() {
let server_name = server_name.clone();
// Drop the JoinHandle -- we're not actually going to block on the join handle's
// result. The future will keep running in the background.
runtime.spawn(async move {
if let Err(HandshakeFailed { .. }) =
client.connect(peer, query_event_callback).await
if let Err(HandshakeFailed { .. }) = client
.connect(peer, query_event_callback, server_name)
.await
{
// error already logged
}
Expand All @@ -229,18 +246,22 @@ impl Provider {
&self,
peer: SocketAddr,
query_event_callback: fn(&mut Connection, Duration),
server_name: Name,
) -> std::io::Result<HandshakeKind> {
// Unconditionally request a background handshake. This schedules any re-handshaking
// needed.
if self.state.runtime.is_some() {
let _ = self.background_handshake_with(peer, query_event_callback);
let _ = self.background_handshake_with(peer, query_event_callback, server_name.clone());
}

if self.state.map.contains(&peer) {
return Ok(HandshakeKind::Cached);
}

let fut = self.state.client.connect(peer, query_event_callback);
let fut = self
.state
.client
.connect(peer, query_event_callback, server_name);
if let Some((runtime, _)) = self.state.runtime.as_ref() {
runtime.block_on(fut)?
} else {
Expand All @@ -260,11 +281,17 @@ impl Provider {
&self,
peer: SocketAddr,
query_event_callback: fn(&mut Connection, Duration),
server_name: Name,
) -> std::io::Result<secret::map::Peer> {
let state = self.state.clone();
if let Some((runtime, _)) = self.state.runtime.as_ref() {
runtime
.spawn(async move { state.client.connect(peer, query_event_callback).await })
.spawn(async move {
state
.client
.connect(peer, query_event_callback, server_name)
.await
})
.await??;
} else {
return Err(std::io::Error::new(
Expand Down
3 changes: 3 additions & 0 deletions dc/s2n-quic-dc/src/psk/client/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
};
use s2n_quic::{
provider::{event::Subscriber as Sub, tls::Provider as Prov},
server::Name,
Connection,
};
use std::{net::SocketAddr, time::Duration};
Expand Down Expand Up @@ -102,6 +103,7 @@ impl<Event: s2n_quic::provider::event::Subscriber> Builder<Event> {
tls_materials_provider: TlsProvider,
subscriber: Subscriber,
query_event_callback: fn(&mut Connection, Duration),
server_name: Name,
) -> Result<Provider> {
Provider::new::<TlsProvider, Subscriber, Event>(
addr,
Expand All @@ -110,6 +112,7 @@ impl<Event: s2n_quic::provider::event::Subscriber> Builder<Event> {
subscriber,
query_event_callback,
self,
server_name,
)
}
}
7 changes: 5 additions & 2 deletions dc/s2n-quic-dc/src/psk/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use s2n_quic::{
event::Subscriber as Sub,
tls::Provider as Prov,
},
server::Name,
Connection,
};
use std::{
Expand Down Expand Up @@ -209,10 +210,11 @@ impl Client {
&self,
peer: SocketAddr,
query_event_callback: fn(&mut Connection, Duration),
server_name: Name,
) -> Result<(), HandshakeFailed> {
self.queue
.clone()
.handshake(&self.client, peer, query_event_callback)
.handshake(&self.client, peer, query_event_callback, server_name)
.await
}
}
Expand Down Expand Up @@ -314,6 +316,7 @@ impl HandshakeQueue {
client: &s2n_quic::Client,
peer: SocketAddr,
query_event_callback: fn(&mut Connection, Duration),
server_name: Name,
) -> Result<(), HandshakeFailed> {
let entry = self.allocate_entry(peer);
let entry2 = entry.clone();
Expand All @@ -328,7 +331,7 @@ impl HandshakeQueue {
let limiter_duration = start.elapsed();

let mut connection = client
.connect(s2n_quic::client::Connect::new(peer).with_server_name("anyhostname"))
.connect(s2n_quic::client::Connect::new(peer).with_server_name(server_name))
.await?;

query_event_callback(&mut connection, limiter_duration);
Expand Down
60 changes: 46 additions & 14 deletions dc/s2n-quic-dc/src/stream/client/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::{
recv, socket,
},
};
use s2n_quic::server::Name;
use s2n_quic_core::time::Clock;
use std::{io, net::SocketAddr, time::Duration};
use tokio::net::TcpStream;
Expand All @@ -33,6 +34,7 @@ pub trait Handshake: Clone {
async fn handshake_with_entry(
&self,
remote_handshake_addr: SocketAddr,
server_name: Name,
) -> std::io::Result<(secret::map::Peer, secret::HandshakeKind)>;

fn local_addr(&self) -> std::io::Result<SocketAddr>;
Expand All @@ -44,8 +46,9 @@ impl Handshake for crate::psk::client::Provider {
async fn handshake_with_entry(
&self,
remote_handshake_addr: SocketAddr,
server_name: Name,
) -> std::io::Result<(secret::map::Peer, secret::HandshakeKind)> {
self.handshake_with_entry(remote_handshake_addr, |_conn, _duration| {})
self.handshake_with_entry(remote_handshake_addr, |_conn, _duration| {}, server_name)
.await
}

Expand Down Expand Up @@ -89,10 +92,11 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
pub async fn handshake_with(
&self,
remote_handshake_addr: SocketAddr,
server_name: Name,
) -> io::Result<secret::HandshakeKind> {
let (_peer, kind) = self
.handshake
.handshake_with_entry(remote_handshake_addr)
.handshake_with_entry(remote_handshake_addr, server_name)
.await?;
Ok(kind)
}
Expand All @@ -101,10 +105,11 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
async fn handshake_for_connect(
&self,
remote_handshake_addr: SocketAddr,
server_name: Name,
) -> io::Result<secret::map::Peer> {
let (peer, _kind) = self
.handshake
.handshake_with_entry(remote_handshake_addr)
.handshake_with_entry(remote_handshake_addr, server_name)
.await?;
Ok(peer)
}
Expand All @@ -115,10 +120,17 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
&self,
handshake_addr: SocketAddr,
acceptor_addr: SocketAddr,
server_name: Name,
) -> io::Result<Stream<S>> {
match self.default_protocol {
socket::Protocol::Udp => self.connect_udp(handshake_addr, acceptor_addr).await,
socket::Protocol::Tcp => self.connect_tcp(handshake_addr, acceptor_addr).await,
socket::Protocol::Udp => {
self.connect_udp(handshake_addr, acceptor_addr, server_name)
.await
}
socket::Protocol::Tcp => {
self.connect_tcp(handshake_addr, acceptor_addr, server_name)
.await
}
protocol => Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("invalid default protocol {protocol:?}"),
Expand All @@ -133,19 +145,32 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
acceptor_addr: SocketAddr,
request: Req,
response: Res,
server_name: Name,
) -> io::Result<Res::Output>
where
Req: rpc::Request,
Res: rpc::Response,
{
match self.default_protocol {
socket::Protocol::Udp => {
self.rpc_udp(handshake_addr, acceptor_addr, request, response)
.await
self.rpc_udp(
handshake_addr,
acceptor_addr,
request,
response,
server_name,
)
.await
}
socket::Protocol::Tcp => {
self.rpc_tcp(handshake_addr, acceptor_addr, request, response)
.await
self.rpc_tcp(
handshake_addr,
acceptor_addr,
request,
response,
server_name,
)
.await
}
protocol => Err(io::Error::new(
io::ErrorKind::InvalidInput,
Expand All @@ -160,9 +185,10 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
&self,
handshake_addr: SocketAddr,
acceptor_addr: SocketAddr,
server_name: Name,
) -> io::Result<Stream<S>> {
// ensure we have a secret for the peer
let handshake = self.handshake_for_connect(handshake_addr);
let handshake = self.handshake_for_connect(handshake_addr, server_name);

let mut stream = client::connect_udp(handshake, acceptor_addr, &self.env).await?;
Self::write_prelude(&mut stream).await?;
Expand All @@ -177,13 +203,14 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
acceptor_addr: SocketAddr,
request: Req,
response: Res,
server_name: Name,
) -> io::Result<Res::Output>
where
Req: rpc::Request,
Res: rpc::Response,
{
// ensure we have a secret for the peer
let handshake = self.handshake_for_connect(handshake_addr);
let handshake = self.handshake_for_connect(handshake_addr, server_name);

let stream = client::connect_udp(handshake, acceptor_addr, &self.env).await?;
rpc_internal::from_stream(stream, request, response).await
Expand All @@ -195,9 +222,10 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
&self,
handshake_addr: SocketAddr,
acceptor_addr: SocketAddr,
server_name: Name,
) -> io::Result<Stream<S>> {
// ensure we have a secret for the peer
let handshake = self.handshake_for_connect(handshake_addr);
let handshake = self.handshake_for_connect(handshake_addr, server_name);

let mut stream =
client::connect_tcp(handshake, acceptor_addr, &self.env, self.linger).await?;
Expand All @@ -213,13 +241,14 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
acceptor_addr: SocketAddr,
request: Req,
response: Res,
server_name: Name,
) -> io::Result<Res::Output>
where
Req: rpc::Request,
Res: rpc::Response,
{
// ensure we have a secret for the peer
let handshake = self.handshake_for_connect(handshake_addr);
let handshake = self.handshake_for_connect(handshake_addr, server_name);

let stream = client::connect_tcp(handshake, acceptor_addr, &self.env, self.linger).await?;
rpc_internal::from_stream(stream, request, response).await
Expand All @@ -231,9 +260,12 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
&self,
handshake_addr: SocketAddr,
stream: TcpStream,
server_name: Name,
) -> io::Result<Stream<S>> {
// ensure we have a secret for the peer
let handshake = self.handshake_for_connect(handshake_addr).await?;
let handshake = self
.handshake_for_connect(handshake_addr, server_name)
.await?;

let mut stream = client::connect_tcp_with(handshake, stream, &self.env).await?;
Self::write_prelude(&mut stream).await?;
Expand Down
Loading