Skip to content

Commit a54a9a7

Browse files
authored
refactor(gateway): misc shard clean up (#2468)
- move connect into its own fn - move serialization into `Pending::event` - separate the `poll_acquire` and `start_send` blocks - simplify `next_event` - use more if-let statements
1 parent 007889d commit a54a9a7

File tree

3 files changed

+100
-137
lines changed

3 files changed

+100
-137
lines changed

twilight-gateway/src/message.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ impl Message {
5858

5959
/// Convert a `twilight` websocket message into a `tokio-websockets` websocket
6060
/// message.
61-
pub(crate) fn into_websocket_msg(self) -> WebsocketMessage {
61+
pub(crate) fn into_websocket(self) -> WebsocketMessage {
6262
match self {
6363
Self::Close(frame) => WebsocketMessage::close(
6464
frame

twilight-gateway/src/shard.rs

Lines changed: 92 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use crate::{
2020
};
2121
use futures_core::Stream;
2222
use futures_sink::Sink;
23-
use serde::{Deserialize, de::DeserializeOwned};
23+
use serde::{Deserialize, Serialize, de::DeserializeOwned};
2424
use std::{
2525
env::consts::OS,
2626
error::Error,
@@ -29,12 +29,13 @@ use std::{
2929
io,
3030
pin::Pin,
3131
str,
32+
sync::Arc,
3233
task::{Context, Poll, ready},
3334
};
3435
use tokio::{
3536
net::TcpStream,
3637
sync::oneshot,
37-
time::{self, Duration, Instant, Interval, MissedTickBehavior, error::Elapsed, timeout},
38+
time::{self, Duration, Instant, Interval, MissedTickBehavior},
3839
};
3940
use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
4041
use twilight_model::gateway::{
@@ -67,38 +68,11 @@ const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
6768
/// [`tokio_websockets`] library Websocket connection.
6869
type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
6970

70-
/// Wrapper enum around [`WebsocketError`] with a timeout case.
71-
enum ConnectionError {
72-
/// Connection attempt timed out.
73-
Timeout(Elapsed),
74-
/// Error from the websocket library, [`tokio_websockets`].
75-
Websocket(WebsocketError),
76-
}
77-
78-
impl ConnectionError {
79-
/// Returns the boxed wrapped error.
80-
fn into_boxed_error(self) -> Box<dyn Error + Send + Sync> {
81-
match self {
82-
Self::Websocket(e) => Box::new(e),
83-
Self::Timeout(e) => Box::new(e),
84-
}
85-
}
86-
}
87-
88-
impl From<WebsocketError> for ConnectionError {
89-
fn from(value: WebsocketError) -> Self {
90-
Self::Websocket(value)
91-
}
92-
}
93-
94-
impl From<Elapsed> for ConnectionError {
95-
fn from(value: Elapsed) -> Self {
96-
Self::Timeout(value)
97-
}
98-
}
71+
/// Dynamically dispatched [`Error`].
72+
type GenericError = Box<dyn Error + Send + Sync>;
9973

10074
/// Wrapper struct around an `async fn` with a `Debug` implementation.
101-
struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, ConnectionError>> + Send>>);
75+
struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, GenericError>> + Send>>);
10276

10377
impl fmt::Debug for ConnectionFuture {
10478
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -213,9 +187,11 @@ struct Pending {
213187

214188
impl Pending {
215189
/// Constructor for a pending gateway event.
216-
const fn text(json: String, is_heartbeat: bool) -> Option<Self> {
190+
fn event<T: Serialize>(event: T, is_heartbeat: bool) -> Option<Self> {
217191
Some(Self {
218-
gateway_event: Some(Message::Text(json)),
192+
gateway_event: Some(Message::Text(
193+
json::to_string(&event).expect("json serialization is infallible"),
194+
)),
219195
is_heartbeat,
220196
})
221197
}
@@ -568,6 +544,67 @@ impl<Q> Shard<Q> {
568544
source: Some(Box::new(source)),
569545
})
570546
}
547+
548+
/// Attempts to connect to the gateway.
549+
///
550+
/// # Returns
551+
///
552+
/// * `Poll::Pending` if connection is in progress
553+
/// * `Poll::Ready(Ok)` if connected
554+
/// * `Poll::Ready(Err)` if connecting to the gateway failed.
555+
fn poll_connect(
556+
&mut self,
557+
cx: &mut Context<'_>,
558+
attempt: u8,
559+
) -> Poll<Result<(), ReceiveMessageError>> {
560+
let fut = self.connection_future.get_or_insert_with(|| {
561+
let base_url = self
562+
.resume_url
563+
.as_deref()
564+
.or_else(|| self.config.proxy_url())
565+
.unwrap_or(GATEWAY_URL);
566+
let base_url_len = base_url.len();
567+
let uri = format!("{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}");
568+
569+
let tls = Arc::clone(&self.config.tls);
570+
ConnectionFuture(Box::pin(async move {
571+
let delay = Duration::from_secs(2u8.saturating_pow(attempt.into()).into());
572+
time::sleep(delay).await;
573+
tracing::debug!(url = &uri[..base_url_len], "connecting");
574+
575+
let builder = ClientBuilder::new()
576+
.uri(&uri)
577+
.expect("valid URL")
578+
.limits(Limits::unlimited())
579+
.connector(&tls);
580+
Ok(time::timeout(CONNECT_TIMEOUT, builder.connect()).await??.0)
581+
}))
582+
});
583+
584+
let res = ready!(Pin::new(&mut fut.0).poll(cx));
585+
self.connection_future = None;
586+
match res {
587+
Ok(connection) => {
588+
self.connection = Some(connection);
589+
self.state = ShardState::Identifying;
590+
#[cfg(any(feature = "zlib", feature = "zstd"))]
591+
self.decompressor.reset();
592+
}
593+
Err(source) => {
594+
self.resume_url = None;
595+
self.state = ShardState::Disconnected {
596+
reconnect_attempts: attempt + 1,
597+
};
598+
599+
return Poll::Ready(Err(ReceiveMessageError {
600+
kind: ReceiveMessageErrorType::Reconnect,
601+
source: Some(source),
602+
}));
603+
}
604+
}
605+
606+
Poll::Ready(Ok(()))
607+
}
571608
}
572609

573610
impl<Q: Queue> Shard<Q> {
@@ -583,16 +620,14 @@ impl<Q: Queue> Shard<Q> {
583620
if let Some(pending) = self.pending.as_mut() {
584621
ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;
585622

586-
if let Some(message) = &pending.gateway_event {
587-
if let Some(ratelimiter) = self.ratelimiter.as_mut()
588-
&& message.is_text()
589-
&& !pending.is_heartbeat
590-
{
591-
ready!(ratelimiter.poll_acquire(cx));
592-
}
623+
let is_ratelimited = pending.gateway_event.as_ref().is_some_and(Message::is_text)
624+
&& !pending.is_heartbeat;
625+
if is_ratelimited && let Some(ratelimiter) = &mut self.ratelimiter {
626+
ready!(ratelimiter.poll_acquire(cx));
627+
}
593628

594-
let ws_message = pending.gateway_event.take().unwrap().into_websocket_msg();
595-
Pin::new(self.connection.as_mut().unwrap()).start_send(ws_message)?;
629+
if let Some(msg) = pending.gateway_event.take().map(Message::into_websocket) {
630+
Pin::new(self.connection.as_mut().unwrap()).start_send(msg)?;
596631
}
597632

598633
ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;
@@ -614,10 +649,8 @@ impl<Q: Queue> Shard<Q> {
614649
continue;
615650
}
616651

617-
if self
618-
.heartbeat_interval
619-
.as_mut()
620-
.is_some_and(|heartbeater| heartbeater.poll_tick(cx).is_ready())
652+
if let Some(heartbeater) = &mut self.heartbeat_interval
653+
&& heartbeater.poll_tick(cx).is_ready()
621654
{
622655
// Discord never responded after the last heartbeat, connection
623656
// is failed or "zombied", see
@@ -631,11 +664,8 @@ impl<Q: Queue> Shard<Q> {
631664
}
632665

633666
tracing::debug!("sending heartbeat");
634-
self.pending = Pending::text(
635-
json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
636-
.expect("serialization cannot fail"),
637-
true,
638-
);
667+
self.pending =
668+
Pending::event(Heartbeat::new(self.session().map(Session::sequence)), true);
639669
self.heartbeat_interval_event = false;
640670

641671
continue;
@@ -647,10 +677,8 @@ impl<Q: Queue> Shard<Q> {
647677
.is_none_or(|ratelimiter| ratelimiter.poll_available(cx).is_ready());
648678

649679
if not_ratelimited
650-
&& let Some(Poll::Ready(canceled)) = self
651-
.identify_rx
652-
.as_mut()
653-
.map(|rx| Pin::new(rx).poll(cx).map(|r| r.is_err()))
680+
&& let Some(rx) = &mut self.identify_rx
681+
&& let Poll::Ready(canceled) = Pin::new(rx).poll(cx).map(|r| r.is_err())
654682
{
655683
if canceled {
656684
self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
@@ -659,8 +687,8 @@ impl<Q: Queue> Shard<Q> {
659687

660688
tracing::debug!("sending identify");
661689

662-
self.pending = Pending::text(
663-
json::to_string(&Identify::new(IdentifyInfo {
690+
self.pending = Pending::event(
691+
Identify::new(IdentifyInfo {
664692
compress: false,
665693
intents: self.config.intents(),
666694
large_threshold: self.config.large_threshold(),
@@ -672,8 +700,7 @@ impl<Q: Queue> Shard<Q> {
672700
.unwrap_or_else(default_identify_properties),
673701
shard: Some(self.id),
674702
token: self.config.token().to_owned(),
675-
}))
676-
.expect("serialization cannot fail"),
703+
}),
677704
false,
678705
);
679706
self.identify_rx = None;
@@ -757,11 +784,8 @@ impl<Q: Queue> Shard<Q> {
757784
}
758785
Some(OpCode::Heartbeat) => {
759786
tracing::debug!("received heartbeat");
760-
self.pending = Pending::text(
761-
json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
762-
.expect("serialization cannot fail"),
763-
true,
764-
);
787+
self.pending =
788+
Pending::event(Heartbeat::new(self.session().map(Session::sequence)), true);
765789
}
766790
Some(OpCode::HeartbeatAck) => {
767791
let requested = self.latency.received().is_none() && self.latency.sent().is_some();
@@ -793,13 +817,8 @@ impl<Q: Queue> Shard<Q> {
793817
self.latency = Latency::new();
794818

795819
if let Some(session) = &self.session {
796-
self.pending = Pending::text(
797-
json::to_string(&Resume::new(
798-
session.sequence(),
799-
session.id(),
800-
self.config.token(),
801-
))
802-
.expect("serialization cannot fail"),
820+
self.pending = Pending::event(
821+
Resume::new(session.sequence(), session.id(), self.config.token()),
803822
false,
804823
);
805824
self.state = ShardState::Resuming;
@@ -830,7 +849,6 @@ impl<Q: Queue> Shard<Q> {
830849
impl<Q: Queue + Unpin> Stream for Shard<Q> {
831850
type Item = Result<Message, ReceiveMessageError>;
832851

833-
#[allow(clippy::too_many_lines)]
834852
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
835853
let message = loop {
836854
match self.state {
@@ -847,59 +865,7 @@ impl<Q: Queue + Unpin> Stream for Shard<Q> {
847865
return Poll::Ready(None);
848866
}
849867
ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
850-
if self.connection_future.is_none() {
851-
let base_url = self
852-
.resume_url
853-
.as_deref()
854-
.or_else(|| self.config.proxy_url())
855-
.unwrap_or(GATEWAY_URL);
856-
let uri = format!(
857-
"{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}"
858-
);
859-
860-
tracing::debug!(url = base_url, "connecting to gateway");
861-
862-
let tls = self.config.tls.clone();
863-
self.connection_future = Some(ConnectionFuture(Box::pin(async move {
864-
let secs = 2u8.saturating_pow(reconnect_attempts.into());
865-
time::sleep(Duration::from_secs(secs.into())).await;
866-
867-
Ok(timeout(
868-
CONNECT_TIMEOUT,
869-
ClientBuilder::new()
870-
.uri(&uri)
871-
.expect("URL should be valid")
872-
.limits(Limits::unlimited())
873-
.connector(&tls)
874-
.connect(),
875-
)
876-
.await??
877-
.0)
878-
})));
879-
}
880-
881-
let res =
882-
ready!(Pin::new(&mut self.connection_future.as_mut().unwrap().0).poll(cx));
883-
self.connection_future = None;
884-
match res {
885-
Ok(connection) => {
886-
self.connection = Some(connection);
887-
self.state = ShardState::Identifying;
888-
#[cfg(any(feature = "zlib", feature = "zstd"))]
889-
self.decompressor.reset();
890-
}
891-
Err(source) => {
892-
self.resume_url = None;
893-
self.state = ShardState::Disconnected {
894-
reconnect_attempts: reconnect_attempts + 1,
895-
};
896-
897-
return Poll::Ready(Some(Err(ReceiveMessageError {
898-
kind: ReceiveMessageErrorType::Reconnect,
899-
source: Some(source.into_boxed_error()),
900-
})));
901-
}
902-
}
868+
ready!(self.poll_connect(cx, reconnect_attempts))?;
903869
}
904870
_ => {}
905871
}

twilight-gateway/src/stream.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,17 +110,14 @@ mod private {
110110
type Output = Option<Result<Event, ReceiveMessageError>>;
111111

112112
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
113-
let events = self.events;
114-
let try_from_message = |message| match message {
115-
Message::Text(json) => parse(json, events).map(|opt| opt.map(Into::into)),
116-
Message::Close(frame) => Ok(Some(Event::GatewayClose(frame))),
117-
};
118-
119113
loop {
120-
match ready!(Pin::new(&mut self.stream).poll_next(cx)) {
121-
Some(item) => {
122-
if let Some(event) = item.and_then(try_from_message).transpose() {
123-
return Poll::Ready(Some(event));
114+
match ready!(Pin::new(&mut self.stream).poll_next(cx)?) {
115+
Some(Message::Close(frame)) => {
116+
return Poll::Ready(Some(Ok(Event::GatewayClose(frame))));
117+
}
118+
Some(Message::Text(json)) => {
119+
if let Some(event) = parse(json, self.events).transpose() {
120+
return Poll::Ready(Some(event.map(Into::into)));
124121
}
125122
}
126123
None => return Poll::Ready(None),

0 commit comments

Comments
 (0)