Skip to content

Commit 67bf227

Browse files
authored
chore: refactor transport trait and runtime architecture (#74)
* chore: refactor transport trait * chore: tidy * chore: improve naming
1 parent 205e39a commit 67bf227

File tree

7 files changed

+153
-89
lines changed

7 files changed

+153
-89
lines changed

crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use axum::{
1919
Extension, Router,
2020
};
2121
use futures::stream::{self};
22+
use rust_mcp_schema::schema_utils::ClientMessage;
2223
use rust_mcp_transport::{error::TransportError, SessionId, SseTransport};
2324
use std::{convert::Infallible, sync::Arc, time::Duration};
2425
use tokio::{
@@ -78,7 +79,7 @@ pub async fn handle_sse(
7879
State(state): State<Arc<AppState>>,
7980
) -> TransportServerResult<impl IntoResponse> {
8081
let messages_endpoint =
81-
SseTransport::message_endpoint(&state.sse_message_endpoint, &session_id);
82+
SseTransport::<ClientMessage>::message_endpoint(&state.sse_message_endpoint, &session_id);
8283

8384
// readable stream of string to be used in transport
8485
let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE);

crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,10 @@ pub struct ClientRuntime {
2828
client_details: InitializeRequestParams,
2929
// Details about the connected server
3030
server_details: Arc<RwLock<Option<InitializeResult>>>,
31-
message_sender: tokio::sync::RwLock<Option<MessageDispatcher<ServerMessage>>>,
3231
handlers: Mutex<Vec<tokio::task::JoinHandle<Result<(), McpSdkError>>>>,
3332
}
3433

3534
impl ClientRuntime {
36-
pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<ServerMessage>) {
37-
let mut lock = self.message_sender.write().await;
38-
*lock = Some(sender);
39-
}
40-
4135
pub(crate) fn new(
4236
client_details: InitializeRequestParams,
4337
transport: impl Transport<ServerMessage, MessageFromClient>,
@@ -48,7 +42,6 @@ impl ClientRuntime {
4842
handler,
4943
client_details,
5044
server_details: Arc::new(RwLock::new(None)),
51-
message_sender: tokio::sync::RwLock::new(None),
5245
handlers: Mutex::new(vec![]),
5346
}
5447
}
@@ -83,20 +76,22 @@ impl McpClient for ClientRuntime {
8376
where
8477
MessageDispatcher<ServerMessage>: McpDispatch<ServerMessage, MessageFromClient>,
8578
{
86-
(&self.message_sender) as _
79+
(self.transport.message_sender().await) as _
8780
}
8881

8982
async fn start(self: Arc<Self>) -> SdkResult<()> {
90-
let (mut stream, sender, error_io) = self.transport.start().await?;
91-
self.set_message_sender(sender).await;
83+
let mut stream = self.transport.start().await?;
84+
85+
let mut error_io_stream = self.transport.error_stream().await.write().await;
86+
let error_io_stream = error_io_stream.take();
9287

9388
let self_clone = Arc::clone(&self);
9489
let self_clone_err = Arc::clone(&self);
9590

9691
let err_task = tokio::spawn(async move {
9792
let self_ref = &*self_clone_err;
9893

99-
if let IoStream::Readable(error_input) = error_io {
94+
if let Some(IoStream::Readable(error_input)) = error_io_stream {
10095
let mut reader = BufReader::new(error_input).lines();
10196
loop {
10297
tokio::select! {
@@ -126,6 +121,7 @@ impl McpClient for ClientRuntime {
126121
}
127122
}
128123
}
124+
129125
Ok::<(), McpSdkError>(())
130126
});
131127

crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs

Lines changed: 5 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use async_trait::async_trait;
77
use futures::StreamExt;
88
use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport};
99
use schema_utils::ClientMessage;
10-
use std::pin::Pin;
1110
use std::sync::{Arc, RwLock};
1211
use tokio::io::AsyncWriteExt;
1312

@@ -27,9 +26,6 @@ pub struct ServerRuntime {
2726
server_details: Arc<InitializeResult>,
2827
// Details about the connected client
2928
client_details: Arc<RwLock<Option<InitializeRequestParams>>>,
30-
31-
message_sender: tokio::sync::RwLock<Option<MessageDispatcher<ClientMessage>>>,
32-
error_stream: tokio::sync::RwLock<Option<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>>,
3329
#[cfg(feature = "hyper-server")]
3430
session_id: Option<SessionId>,
3531
}
@@ -70,24 +66,14 @@ impl McpServer for ServerRuntime {
7066
where
7167
MessageDispatcher<ClientMessage>: McpDispatch<ClientMessage, MessageFromServer>,
7268
{
73-
(&self.message_sender) as _
69+
(self.transport.message_sender().await) as _
7470
}
7571

7672
/// Main runtime loop, processes incoming messages and handles requests
7773
async fn start(&self) -> SdkResult<()> {
78-
// Start the transport layer to begin handling messages
79-
// self.transport.start().await?;
80-
// Open the transport stream
81-
// let mut stream = self.transport.open();
82-
let (mut stream, sender, error_io) = self.transport.start().await?;
83-
84-
self.set_message_sender(sender).await;
85-
86-
if let IoStream::Writable(error_stream) = error_io {
87-
self.set_error_stream(error_stream).await;
88-
}
74+
let mut stream = self.transport.start().await?;
8975

90-
let sender = self.sender().await.read().await;
76+
let sender = self.transport.message_sender().await.read().await;
9177
let sender = sender
9278
.as_ref()
9379
.ok_or(schema_utils::SdkError::connection_closed())?;
@@ -138,8 +124,8 @@ impl McpServer for ServerRuntime {
138124
}
139125

140126
async fn stderr_message(&self, message: String) -> SdkResult<()> {
141-
let mut lock = self.error_stream.write().await;
142-
if let Some(stderr) = lock.as_mut() {
127+
let mut lock = self.transport.error_stream().await.write().await;
128+
if let Some(IoStream::Writable(stderr)) = lock.as_mut() {
143129
stderr.write_all(message.as_bytes()).await?;
144130
stderr.write_all(b"\n").await?;
145131
stderr.flush().await?;
@@ -149,24 +135,11 @@ impl McpServer for ServerRuntime {
149135
}
150136

151137
impl ServerRuntime {
152-
pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<ClientMessage>) {
153-
let mut lock = self.message_sender.write().await;
154-
*lock = Some(sender);
155-
}
156-
157138
#[cfg(feature = "hyper-server")]
158139
pub(crate) async fn session_id(&self) -> Option<SessionId> {
159140
self.session_id.to_owned()
160141
}
161142

162-
pub(crate) async fn set_error_stream(
163-
&self,
164-
error_stream: Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>,
165-
) {
166-
let mut lock = self.error_stream.write().await;
167-
*lock = Some(error_stream);
168-
}
169-
170143
#[cfg(feature = "hyper-server")]
171144
pub(crate) fn new_instance(
172145
server_details: Arc<InitializeResult>,
@@ -179,8 +152,6 @@ impl ServerRuntime {
179152
client_details: Arc::new(RwLock::new(None)),
180153
transport: Box::new(transport),
181154
handler,
182-
message_sender: tokio::sync::RwLock::new(None),
183-
error_stream: tokio::sync::RwLock::new(None),
184155
session_id: Some(session_id),
185156
}
186157
}
@@ -195,8 +166,6 @@ impl ServerRuntime {
195166
client_details: Arc::new(RwLock::new(None)),
196167
transport: Box::new(transport),
197168
handler,
198-
message_sender: tokio::sync::RwLock::new(None),
199-
error_stream: tokio::sync::RwLock::new(None),
200169
#[cfg(feature = "hyper-server")]
201170
session_id: None,
202171
}

crates/rust-mcp-transport/src/client_sse.rs

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ impl Default for ClientSseTransportOptions {
5252
/// Client-side Server-Sent Events (SSE) transport implementation
5353
///
5454
/// Manages SSE connections, HTTP POST requests, and message streaming for client-server communication.
55-
pub struct ClientSseTransport {
55+
pub struct ClientSseTransport<R>
56+
where
57+
R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
58+
{
5659
/// Optional cancellation token source for shutting down the transport
5760
shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
5861
/// Flag indicating if the transport is shut down
@@ -73,9 +76,14 @@ pub struct ClientSseTransport {
7376
custom_headers: Option<HeaderMap>,
7477
sse_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
7578
post_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
79+
message_sender: tokio::sync::RwLock<Option<MessageDispatcher<R>>>,
80+
error_stream: tokio::sync::RwLock<Option<IoStream>>,
7681
}
7782

78-
impl ClientSseTransport {
83+
impl<R> ClientSseTransport<R>
84+
where
85+
R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
86+
{
7987
/// Creates a new ClientSseTransport instance
8088
///
8189
/// Initializes the transport with the provided server URL and options.
@@ -111,6 +119,8 @@ impl ClientSseTransport {
111119
custom_headers: headers,
112120
sse_task: tokio::sync::RwLock::new(None),
113121
post_task: tokio::sync::RwLock::new(None),
122+
message_sender: tokio::sync::RwLock::new(None),
123+
error_stream: tokio::sync::RwLock::new(None),
114124
})
115125
}
116126

@@ -161,10 +171,23 @@ impl ClientSseTransport {
161171
}
162172
Ok(endpoint)
163173
}
174+
175+
pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
176+
let mut lock = self.message_sender.write().await;
177+
*lock = Some(sender);
178+
}
179+
180+
pub(crate) async fn set_error_stream(
181+
&self,
182+
error_stream: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>,
183+
) {
184+
let mut lock = self.error_stream.write().await;
185+
*lock = Some(IoStream::Readable(error_stream));
186+
}
164187
}
165188

166189
#[async_trait]
167-
impl<R, S> Transport<R, S> for ClientSseTransport
190+
impl<R, S> Transport<R, S> for ClientSseTransport<R>
168191
where
169192
R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
170193
S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
@@ -176,13 +199,7 @@ where
176199
/// # Returns
177200
/// * `TransportResult<(Pin<Box<dyn Stream<Item = R> + Send>>, MessageDispatcher<R>, IoStream)>`
178201
/// - The message stream, dispatcher, and error stream
179-
async fn start(
180-
&self,
181-
) -> TransportResult<(
182-
Pin<Box<dyn Stream<Item = R> + Send>>,
183-
MessageDispatcher<R>,
184-
IoStream,
185-
)>
202+
async fn start(&self) -> TransportResult<Pin<Box<dyn Stream<Item = R> + Send>>>
186203
where
187204
MessageDispatcher<R>: McpDispatch<R, S>,
188205
{
@@ -290,7 +307,21 @@ where
290307
cancellation_token,
291308
);
292309

293-
Ok((stream, sender, error_stream))
310+
self.set_message_sender(sender).await;
311+
312+
if let IoStream::Readable(error_stream) = error_stream {
313+
self.set_error_stream(error_stream).await;
314+
}
315+
316+
Ok(stream)
317+
}
318+
319+
async fn message_sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<R>>> {
320+
&self.message_sender as _
321+
}
322+
323+
async fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
324+
&self.error_stream as _
294325
}
295326

296327
/// Checks if the transport has been shut down

crates/rust-mcp-transport/src/sse.rs

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::schema::schema_utils::{McpMessage, RpcMessage};
22
use crate::schema::RequestId;
33
use async_trait::async_trait;
44
use futures::Stream;
5+
use serde::de::DeserializeOwned;
56
use std::collections::HashMap;
67
use std::pin::Pin;
78
use std::sync::Arc;
@@ -15,15 +16,23 @@ use crate::transport::Transport;
1516
use crate::utils::{endpoint_with_session_id, CancellationTokenSource};
1617
use crate::{IoStream, McpDispatch, SessionId, TransportOptions};
1718

18-
pub struct SseTransport {
19+
pub struct SseTransport<R>
20+
where
21+
R: RpcMessage + Clone + Send + Sync + DeserializeOwned + 'static,
22+
{
1923
shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
2024
is_shut_down: Mutex<bool>,
2125
read_write_streams: Mutex<Option<(DuplexStream, DuplexStream)>>,
2226
options: Arc<TransportOptions>,
27+
message_sender: tokio::sync::RwLock<Option<MessageDispatcher<R>>>,
28+
error_stream: tokio::sync::RwLock<Option<IoStream>>,
2329
}
2430

2531
/// Server-Sent Events (SSE) transport implementation
26-
impl SseTransport {
32+
impl<R> SseTransport<R>
33+
where
34+
R: RpcMessage + Clone + Send + Sync + DeserializeOwned + 'static,
35+
{
2736
/// Creates a new SseTransport instance
2837
///
2938
/// Initializes the transport with provided read and write duplex streams and options.
@@ -45,16 +54,31 @@ impl SseTransport {
4554
options,
4655
shutdown_source: tokio::sync::RwLock::new(None),
4756
is_shut_down: Mutex::new(false),
57+
message_sender: tokio::sync::RwLock::new(None),
58+
error_stream: tokio::sync::RwLock::new(None),
4859
})
4960
}
5061

5162
pub fn message_endpoint(endpoint: &str, session_id: &SessionId) -> String {
5263
endpoint_with_session_id(endpoint, session_id)
5364
}
65+
66+
pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
67+
let mut lock = self.message_sender.write().await;
68+
*lock = Some(sender);
69+
}
70+
71+
pub(crate) async fn set_error_stream(
72+
&self,
73+
error_stream: Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>,
74+
) {
75+
let mut lock = self.error_stream.write().await;
76+
*lock = Some(IoStream::Writable(error_stream));
77+
}
5478
}
5579

5680
#[async_trait]
57-
impl<R, S> Transport<R, S> for SseTransport
81+
impl<R, S> Transport<R, S> for SseTransport<R>
5882
where
5983
R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
6084
S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
@@ -69,13 +93,7 @@ where
6993
///
7094
/// # Errors
7195
/// * Returns `TransportError` if streams are already taken or not initialized
72-
async fn start(
73-
&self,
74-
) -> TransportResult<(
75-
Pin<Box<dyn Stream<Item = R> + Send>>,
76-
MessageDispatcher<R>,
77-
IoStream,
78-
)>
96+
async fn start(&self) -> TransportResult<Pin<Box<dyn Stream<Item = R> + Send>>>
7997
where
8098
MessageDispatcher<R>: McpDispatch<R, S>,
8199
{
@@ -103,7 +121,13 @@ where
103121
cancellation_token,
104122
);
105123

106-
Ok((stream, sender, error_stream))
124+
self.set_message_sender(sender).await;
125+
126+
if let IoStream::Writable(error_stream) = error_stream {
127+
self.set_error_stream(error_stream).await;
128+
}
129+
130+
Ok(stream)
107131
}
108132

109133
/// Checks if the transport has been shut down
@@ -115,6 +139,14 @@ where
115139
*result
116140
}
117141

142+
async fn message_sender(&self) -> &tokio::sync::RwLock<Option<MessageDispatcher<R>>> {
143+
&self.message_sender as _
144+
}
145+
146+
async fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
147+
&self.error_stream as _
148+
}
149+
118150
/// Shuts down the transport, terminating tasks and signaling closure
119151
///
120152
/// Cancels any running tasks and clears the cancellation source.

0 commit comments

Comments
 (0)