Skip to content

Commit 83ed676

Browse files
committed
Add the ability to provide a custom HTTP client for SSE connections
1 parent 439ea2f commit 83ed676

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

internal/mcp/sse.go

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -310,19 +310,30 @@ func (s sseServerStream) Close() error {
310310
// https://modelcontextprotocol.io/specification/2024-11-05/basic/transports
311311
type SSEClientTransport struct {
312312
sseEndpoint *url.URL
313+
httpClient *http.Client
313314
}
314315

315316
// NewSSEClientTransport returns a new client transport that connects to the
316-
// SSE server at the provided URL.
317+
// SSE server at the provided URL using the default HTTP client.
317318
//
318319
// NewSSEClientTransport panics if the given URL is invalid.
319320
func NewSSEClientTransport(baseURL string) *SSEClientTransport {
321+
// Use the default HTTP client.
322+
return NewSSEClientTransportWithHTTPClient(baseURL, http.DefaultClient)
323+
}
324+
325+
// NewSSEClientTransportWithHTTPClient returns a new client transport that connects to the
326+
// SSE server at the provided URL using the provided HTTP client.
327+
//
328+
// NewSSEClientTransportWithHTTPClient panics if the given URL is invalid.
329+
func NewSSEClientTransportWithHTTPClient(baseURL string, httpClient *http.Client) *SSEClientTransport {
320330
url, err := url.Parse(baseURL)
321331
if err != nil {
322332
panic(fmt.Sprintf("invalid base url: %v", err))
323333
}
324334
return &SSEClientTransport{
325335
sseEndpoint: url,
336+
httpClient: httpClient,
326337
}
327338
}
328339

@@ -333,7 +344,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Stream, error) {
333344
return nil, err
334345
}
335346
req.Header.Set("Accept", "text/event-stream")
336-
resp, err := http.DefaultClient.Do(req)
347+
resp, err := c.httpClient.Do(req)
337348
if err != nil {
338349
return nil, err
339350
}
@@ -404,6 +415,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Stream, error) {
404415
// From here on, the stream takes ownership of resp.Body.
405416
s := &sseClientStream{
406417
sseEndpoint: c.sseEndpoint,
418+
httpClient: c.httpClient,
407419
msgEndpoint: msgEndpoint,
408420
incoming: make(chan []byte, 100),
409421
body: resp.Body,
@@ -435,9 +447,10 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Stream, error) {
435447
// - Reads are SSE 'message' events, and pushes them onto a buffered channel.
436448
// - Close terminates the GET request.
437449
type sseClientStream struct {
438-
sseEndpoint *url.URL // SSE endpoint for the GET
439-
msgEndpoint *url.URL // session endpoint for POSTs
440-
incoming chan []byte // queue of incoming messages
450+
sseEndpoint *url.URL // SSE endpoint for the GET
451+
msgEndpoint *url.URL // session endpoint for POSTs
452+
httpClient *http.Client // HTTP client to use for requests
453+
incoming chan []byte // queue of incoming messages
441454

442455
mu sync.Mutex
443456
body io.ReadCloser // body of the hanging GET
@@ -484,7 +497,7 @@ func (c *sseClientStream) Write(ctx context.Context, msg jsonrpc2.Message) error
484497
return err
485498
}
486499
req.Header.Set("Content-Type", "application/json")
487-
resp, err := http.DefaultClient.Do(req)
500+
resp, err := c.httpClient.Do(req)
488501
if err != nil {
489502
return err
490503
}

0 commit comments

Comments
 (0)