@@ -310,19 +310,30 @@ func (s sseServerStream) Close() error {
310
310
// https://modelcontextprotocol.io/specification/2024-11-05/basic/transports
311
311
type SSEClientTransport struct {
312
312
sseEndpoint * url.URL
313
+ httpClient * http.Client
313
314
}
314
315
315
316
// NewSSEClientTransport returns a new client transport that connects to the
316
- // SSE server at the provided URL.
317
+ // SSE server at the provided URL using the default HTTP client .
317
318
//
318
319
// NewSSEClientTransport panics if the given URL is invalid.
319
320
func NewSSEClientTransport (baseURL string ) * SSEClientTransport {
321
+ // Use the default HTTP client.
322
+ return NewSSEClientTransportWithHTTPClient (baseURL , http .DefaultClient )
323
+ }
324
+
325
+ // NewSSEClientTransportWithHTTPClient returns a new client transport that connects to the
326
+ // SSE server at the provided URL using the provided HTTP client.
327
+ //
328
+ // NewSSEClientTransportWithHTTPClient panics if the given URL is invalid.
329
+ func NewSSEClientTransportWithHTTPClient (baseURL string , httpClient * http.Client ) * SSEClientTransport {
320
330
url , err := url .Parse (baseURL )
321
331
if err != nil {
322
332
panic (fmt .Sprintf ("invalid base url: %v" , err ))
323
333
}
324
334
return & SSEClientTransport {
325
335
sseEndpoint : url ,
336
+ httpClient : httpClient ,
326
337
}
327
338
}
328
339
@@ -333,7 +344,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Stream, error) {
333
344
return nil , err
334
345
}
335
346
req .Header .Set ("Accept" , "text/event-stream" )
336
- resp , err := http . DefaultClient .Do (req )
347
+ resp , err := c . httpClient .Do (req )
337
348
if err != nil {
338
349
return nil , err
339
350
}
@@ -404,6 +415,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Stream, error) {
404
415
// From here on, the stream takes ownership of resp.Body.
405
416
s := & sseClientStream {
406
417
sseEndpoint : c .sseEndpoint ,
418
+ httpClient : c .httpClient ,
407
419
msgEndpoint : msgEndpoint ,
408
420
incoming : make (chan []byte , 100 ),
409
421
body : resp .Body ,
@@ -435,9 +447,10 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Stream, error) {
435
447
// - Reads are SSE 'message' events, and pushes them onto a buffered channel.
436
448
// - Close terminates the GET request.
437
449
type sseClientStream struct {
438
- sseEndpoint * url.URL // SSE endpoint for the GET
439
- msgEndpoint * url.URL // session endpoint for POSTs
440
- incoming chan []byte // queue of incoming messages
450
+ sseEndpoint * url.URL // SSE endpoint for the GET
451
+ msgEndpoint * url.URL // session endpoint for POSTs
452
+ httpClient * http.Client // HTTP client to use for requests
453
+ incoming chan []byte // queue of incoming messages
441
454
442
455
mu sync.Mutex
443
456
body io.ReadCloser // body of the hanging GET
@@ -484,7 +497,7 @@ func (c *sseClientStream) Write(ctx context.Context, msg jsonrpc2.Message) error
484
497
return err
485
498
}
486
499
req .Header .Set ("Content-Type" , "application/json" )
487
- resp , err := http . DefaultClient .Do (req )
500
+ resp , err := c . httpClient .Do (req )
488
501
if err != nil {
489
502
return err
490
503
}
0 commit comments