@@ -20,7 +20,7 @@ use crate::{
2020} ;
2121use futures_core:: Stream ;
2222use futures_sink:: Sink ;
23- use serde:: { Deserialize , de:: DeserializeOwned } ;
23+ use serde:: { Deserialize , Serialize , de:: DeserializeOwned } ;
2424use std:: {
2525 env:: consts:: OS ,
2626 error:: Error ,
@@ -29,12 +29,13 @@ use std::{
2929 io,
3030 pin:: Pin ,
3131 str,
32+ sync:: Arc ,
3233 task:: { Context , Poll , ready} ,
3334} ;
3435use tokio:: {
3536 net:: TcpStream ,
3637 sync:: oneshot,
37- time:: { self , Duration , Instant , Interval , MissedTickBehavior , error :: Elapsed , timeout } ,
38+ time:: { self , Duration , Instant , Interval , MissedTickBehavior } ,
3839} ;
3940use tokio_websockets:: { ClientBuilder , Error as WebsocketError , Limits , MaybeTlsStream } ;
4041use twilight_model:: gateway:: {
@@ -67,38 +68,11 @@ const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
6768/// [`tokio_websockets`] library Websocket connection.
6869type Connection = tokio_websockets:: WebSocketStream < MaybeTlsStream < TcpStream > > ;
6970
70- /// Wrapper enum around [`WebsocketError`] with a timeout case.
71- enum ConnectionError {
72- /// Connection attempt timed out.
73- Timeout ( Elapsed ) ,
74- /// Error from the websocket library, [`tokio_websockets`].
75- Websocket ( WebsocketError ) ,
76- }
77-
78- impl ConnectionError {
79- /// Returns the boxed wrapped error.
80- fn into_boxed_error ( self ) -> Box < dyn Error + Send + Sync > {
81- match self {
82- Self :: Websocket ( e) => Box :: new ( e) ,
83- Self :: Timeout ( e) => Box :: new ( e) ,
84- }
85- }
86- }
87-
88- impl From < WebsocketError > for ConnectionError {
89- fn from ( value : WebsocketError ) -> Self {
90- Self :: Websocket ( value)
91- }
92- }
93-
94- impl From < Elapsed > for ConnectionError {
95- fn from ( value : Elapsed ) -> Self {
96- Self :: Timeout ( value)
97- }
98- }
71+ /// Dynamically dispatched [`Error`].
72+ type GenericError = Box < dyn Error + Send + Sync > ;
9973
10074/// Wrapper struct around an `async fn` with a `Debug` implementation.
101- struct ConnectionFuture ( Pin < Box < dyn Future < Output = Result < Connection , ConnectionError > > + Send > > ) ;
75+ struct ConnectionFuture ( Pin < Box < dyn Future < Output = Result < Connection , GenericError > > + Send > > ) ;
10276
10377impl fmt:: Debug for ConnectionFuture {
10478 fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
@@ -213,9 +187,11 @@ struct Pending {
213187
214188impl Pending {
215189 /// Constructor for a pending gateway event.
216- const fn text ( json : String , is_heartbeat : bool ) -> Option < Self > {
190+ fn event < T : Serialize > ( event : T , is_heartbeat : bool ) -> Option < Self > {
217191 Some ( Self {
218- gateway_event : Some ( Message :: Text ( json) ) ,
192+ gateway_event : Some ( Message :: Text (
193+ json:: to_string ( & event) . expect ( "json serialization is infallible" ) ,
194+ ) ) ,
219195 is_heartbeat,
220196 } )
221197 }
@@ -568,6 +544,67 @@ impl<Q> Shard<Q> {
568544 source : Some ( Box :: new ( source) ) ,
569545 } )
570546 }
547+
548+ /// Attempts to connect to the gateway.
549+ ///
550+ /// # Returns
551+ ///
552+ /// * `Poll::Pending` if connection is in progress
553+ /// * `Poll::Ready(Ok)` if connected
554+ /// * `Poll::Ready(Err)` if connecting to the gateway failed.
555+ fn poll_connect (
556+ & mut self ,
557+ cx : & mut Context < ' _ > ,
558+ attempt : u8 ,
559+ ) -> Poll < Result < ( ) , ReceiveMessageError > > {
560+ let fut = self . connection_future . get_or_insert_with ( || {
561+ let base_url = self
562+ . resume_url
563+ . as_deref ( )
564+ . or_else ( || self . config . proxy_url ( ) )
565+ . unwrap_or ( GATEWAY_URL ) ;
566+ let base_url_len = base_url. len ( ) ;
567+ let uri = format ! ( "{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}" ) ;
568+
569+ let tls = Arc :: clone ( & self . config . tls ) ;
570+ ConnectionFuture ( Box :: pin ( async move {
571+ let delay = Duration :: from_secs ( 2u8 . saturating_pow ( attempt. into ( ) ) . into ( ) ) ;
572+ time:: sleep ( delay) . await ;
573+ tracing:: debug!( url = & uri[ ..base_url_len] , "connecting" ) ;
574+
575+ let builder = ClientBuilder :: new ( )
576+ . uri ( & uri)
577+ . expect ( "valid URL" )
578+ . limits ( Limits :: unlimited ( ) )
579+ . connector ( & tls) ;
580+ Ok ( time:: timeout ( CONNECT_TIMEOUT , builder. connect ( ) ) . await ??. 0 )
581+ } ) )
582+ } ) ;
583+
584+ let res = ready ! ( Pin :: new( & mut fut. 0 ) . poll( cx) ) ;
585+ self . connection_future = None ;
586+ match res {
587+ Ok ( connection) => {
588+ self . connection = Some ( connection) ;
589+ self . state = ShardState :: Identifying ;
590+ #[ cfg( any( feature = "zlib" , feature = "zstd" ) ) ]
591+ self . decompressor . reset ( ) ;
592+ }
593+ Err ( source) => {
594+ self . resume_url = None ;
595+ self . state = ShardState :: Disconnected {
596+ reconnect_attempts : attempt + 1 ,
597+ } ;
598+
599+ return Poll :: Ready ( Err ( ReceiveMessageError {
600+ kind : ReceiveMessageErrorType :: Reconnect ,
601+ source : Some ( source) ,
602+ } ) ) ;
603+ }
604+ }
605+
606+ Poll :: Ready ( Ok ( ( ) ) )
607+ }
571608}
572609
573610impl < Q : Queue > Shard < Q > {
@@ -583,16 +620,14 @@ impl<Q: Queue> Shard<Q> {
583620 if let Some ( pending) = self . pending . as_mut ( ) {
584621 ready ! ( Pin :: new( self . connection. as_mut( ) . unwrap( ) ) . poll_ready( cx) ) ?;
585622
586- if let Some ( message) = & pending. gateway_event {
587- if let Some ( ratelimiter) = self . ratelimiter . as_mut ( )
588- && message. is_text ( )
589- && !pending. is_heartbeat
590- {
591- ready ! ( ratelimiter. poll_acquire( cx) ) ;
592- }
623+ let is_ratelimited = pending. gateway_event . as_ref ( ) . is_some_and ( Message :: is_text)
624+ && !pending. is_heartbeat ;
625+ if is_ratelimited && let Some ( ratelimiter) = & mut self . ratelimiter {
626+ ready ! ( ratelimiter. poll_acquire( cx) ) ;
627+ }
593628
594- let ws_message = pending. gateway_event . take ( ) . unwrap ( ) . into_websocket_msg ( ) ;
595- Pin :: new ( self . connection . as_mut ( ) . unwrap ( ) ) . start_send ( ws_message ) ?;
629+ if let Some ( msg ) = pending. gateway_event . take ( ) . map ( Message :: into_websocket ) {
630+ Pin :: new ( self . connection . as_mut ( ) . unwrap ( ) ) . start_send ( msg ) ?;
596631 }
597632
598633 ready ! ( Pin :: new( self . connection. as_mut( ) . unwrap( ) ) . poll_flush( cx) ) ?;
@@ -614,10 +649,8 @@ impl<Q: Queue> Shard<Q> {
614649 continue ;
615650 }
616651
617- if self
618- . heartbeat_interval
619- . as_mut ( )
620- . is_some_and ( |heartbeater| heartbeater. poll_tick ( cx) . is_ready ( ) )
652+ if let Some ( heartbeater) = & mut self . heartbeat_interval
653+ && heartbeater. poll_tick ( cx) . is_ready ( )
621654 {
622655 // Discord never responded after the last heartbeat, connection
623656 // is failed or "zombied", see
@@ -631,11 +664,8 @@ impl<Q: Queue> Shard<Q> {
631664 }
632665
633666 tracing:: debug!( "sending heartbeat" ) ;
634- self . pending = Pending :: text (
635- json:: to_string ( & Heartbeat :: new ( self . session ( ) . map ( Session :: sequence) ) )
636- . expect ( "serialization cannot fail" ) ,
637- true ,
638- ) ;
667+ self . pending =
668+ Pending :: event ( Heartbeat :: new ( self . session ( ) . map ( Session :: sequence) ) , true ) ;
639669 self . heartbeat_interval_event = false ;
640670
641671 continue ;
@@ -647,10 +677,8 @@ impl<Q: Queue> Shard<Q> {
647677 . is_none_or ( |ratelimiter| ratelimiter. poll_available ( cx) . is_ready ( ) ) ;
648678
649679 if not_ratelimited
650- && let Some ( Poll :: Ready ( canceled) ) = self
651- . identify_rx
652- . as_mut ( )
653- . map ( |rx| Pin :: new ( rx) . poll ( cx) . map ( |r| r. is_err ( ) ) )
680+ && let Some ( rx) = & mut self . identify_rx
681+ && let Poll :: Ready ( canceled) = Pin :: new ( rx) . poll ( cx) . map ( |r| r. is_err ( ) )
654682 {
655683 if canceled {
656684 self . identify_rx = Some ( self . config . queue ( ) . enqueue ( self . id . number ( ) ) ) ;
@@ -659,8 +687,8 @@ impl<Q: Queue> Shard<Q> {
659687
660688 tracing:: debug!( "sending identify" ) ;
661689
662- self . pending = Pending :: text (
663- json :: to_string ( & Identify :: new ( IdentifyInfo {
690+ self . pending = Pending :: event (
691+ Identify :: new ( IdentifyInfo {
664692 compress : false ,
665693 intents : self . config . intents ( ) ,
666694 large_threshold : self . config . large_threshold ( ) ,
@@ -672,8 +700,7 @@ impl<Q: Queue> Shard<Q> {
672700 . unwrap_or_else ( default_identify_properties) ,
673701 shard : Some ( self . id ) ,
674702 token : self . config . token ( ) . to_owned ( ) ,
675- } ) )
676- . expect ( "serialization cannot fail" ) ,
703+ } ) ,
677704 false ,
678705 ) ;
679706 self . identify_rx = None ;
@@ -757,11 +784,8 @@ impl<Q: Queue> Shard<Q> {
757784 }
758785 Some ( OpCode :: Heartbeat ) => {
759786 tracing:: debug!( "received heartbeat" ) ;
760- self . pending = Pending :: text (
761- json:: to_string ( & Heartbeat :: new ( self . session ( ) . map ( Session :: sequence) ) )
762- . expect ( "serialization cannot fail" ) ,
763- true ,
764- ) ;
787+ self . pending =
788+ Pending :: event ( Heartbeat :: new ( self . session ( ) . map ( Session :: sequence) ) , true ) ;
765789 }
766790 Some ( OpCode :: HeartbeatAck ) => {
767791 let requested = self . latency . received ( ) . is_none ( ) && self . latency . sent ( ) . is_some ( ) ;
@@ -793,13 +817,8 @@ impl<Q: Queue> Shard<Q> {
793817 self . latency = Latency :: new ( ) ;
794818
795819 if let Some ( session) = & self . session {
796- self . pending = Pending :: text (
797- json:: to_string ( & Resume :: new (
798- session. sequence ( ) ,
799- session. id ( ) ,
800- self . config . token ( ) ,
801- ) )
802- . expect ( "serialization cannot fail" ) ,
820+ self . pending = Pending :: event (
821+ Resume :: new ( session. sequence ( ) , session. id ( ) , self . config . token ( ) ) ,
803822 false ,
804823 ) ;
805824 self . state = ShardState :: Resuming ;
@@ -830,7 +849,6 @@ impl<Q: Queue> Shard<Q> {
830849impl < Q : Queue + Unpin > Stream for Shard < Q > {
831850 type Item = Result < Message , ReceiveMessageError > ;
832851
833- #[ allow( clippy:: too_many_lines) ]
834852 fn poll_next ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
835853 let message = loop {
836854 match self . state {
@@ -847,59 +865,7 @@ impl<Q: Queue + Unpin> Stream for Shard<Q> {
847865 return Poll :: Ready ( None ) ;
848866 }
849867 ShardState :: Disconnected { reconnect_attempts } if self . connection . is_none ( ) => {
850- if self . connection_future . is_none ( ) {
851- let base_url = self
852- . resume_url
853- . as_deref ( )
854- . or_else ( || self . config . proxy_url ( ) )
855- . unwrap_or ( GATEWAY_URL ) ;
856- let uri = format ! (
857- "{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}"
858- ) ;
859-
860- tracing:: debug!( url = base_url, "connecting to gateway" ) ;
861-
862- let tls = self . config . tls . clone ( ) ;
863- self . connection_future = Some ( ConnectionFuture ( Box :: pin ( async move {
864- let secs = 2u8 . saturating_pow ( reconnect_attempts. into ( ) ) ;
865- time:: sleep ( Duration :: from_secs ( secs. into ( ) ) ) . await ;
866-
867- Ok ( timeout (
868- CONNECT_TIMEOUT ,
869- ClientBuilder :: new ( )
870- . uri ( & uri)
871- . expect ( "URL should be valid" )
872- . limits ( Limits :: unlimited ( ) )
873- . connector ( & tls)
874- . connect ( ) ,
875- )
876- . await ??
877- . 0 )
878- } ) ) ) ;
879- }
880-
881- let res =
882- ready ! ( Pin :: new( & mut self . connection_future. as_mut( ) . unwrap( ) . 0 ) . poll( cx) ) ;
883- self . connection_future = None ;
884- match res {
885- Ok ( connection) => {
886- self . connection = Some ( connection) ;
887- self . state = ShardState :: Identifying ;
888- #[ cfg( any( feature = "zlib" , feature = "zstd" ) ) ]
889- self . decompressor . reset ( ) ;
890- }
891- Err ( source) => {
892- self . resume_url = None ;
893- self . state = ShardState :: Disconnected {
894- reconnect_attempts : reconnect_attempts + 1 ,
895- } ;
896-
897- return Poll :: Ready ( Some ( Err ( ReceiveMessageError {
898- kind : ReceiveMessageErrorType :: Reconnect ,
899- source : Some ( source. into_boxed_error ( ) ) ,
900- } ) ) ) ;
901- }
902- }
868+ ready ! ( self . poll_connect( cx, reconnect_attempts) ) ?;
903869 }
904870 _ => { }
905871 }
0 commit comments