Skip to content

Commit 7e70592

Browse files
lpcoxCopilot
andcommitted
refactor: go-sdk usage improvements from module review
Address findings from Go Fan report on modelcontextprotocol/go-sdk: 1. Extract generic paginateAll() helper in connection.go to deduplicate identical cursor-loop pagination across listTools, listResources, and listPrompts (~45 lines of boilerplate removed). 2. Eliminate resourceContents intermediate type in tool_result.go by using sdk.ResourceContents directly for JSON unmarshaling, removing field-by-field copy in the resource content conversion. 3. Pass explicit &sdk.ServerOptions{} instead of nil in mcptest/server.go to guard against future SDK changes that might not accept nil options. 4. Add TTL-based eviction to filteredServerCache in routed.go to prevent unbounded memory growth. Cache entries now expire after the session timeout (30min), evicted lazily on each getOrCreate call. 5. Add transport ownership documentation to transportConnector type clarifying that the SDK session owns the transport after Connect(). Closes #2911 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent b4d8da9 commit 7e70592

5 files changed

Lines changed: 106 additions & 84 deletions

File tree

internal/mcp/connection.go

Lines changed: 55 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -536,31 +536,56 @@ func callParamMethod[P any](c *Connection, rawParams interface{}, fn func(P) (in
536536
return marshalToResponse(result)
537537
}
538538

539-
func (c *Connection) listTools() (*Response, error) {
540-
if err := c.requireSession(); err != nil {
541-
return nil, err
542-
}
543-
logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID)
544-
// Fetch first page to determine initial capacity
545-
first, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{})
539+
// paginatedPage holds a single page of results from a paginated SDK list call.
540+
type paginatedPage[T any] struct {
541+
Items []T
542+
NextCursor string
543+
}
544+
545+
// paginateAll collects all items across paginated SDK list calls.
546+
func paginateAll[T any](
547+
serverID string,
548+
itemKind string,
549+
fetch func(cursor string) (paginatedPage[T], error),
550+
) ([]T, error) {
551+
first, err := fetch("")
546552
if err != nil {
547553
return nil, err
548554
}
549-
allTools := make([]*sdk.Tool, len(first.Tools), max(len(first.Tools), 1))
550-
copy(allTools, first.Tools)
551-
logConn.Printf("listTools: received page of %d tools from serverID=%s", len(first.Tools), c.serverID)
555+
all := make([]T, len(first.Items), max(len(first.Items), 1))
556+
copy(all, first.Items)
557+
logConn.Printf("list%s: received page of %d %s from serverID=%s", itemKind, len(first.Items), itemKind, serverID)
558+
552559
cursor := first.NextCursor
553560
for cursor != "" {
554-
result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{Cursor: cursor})
561+
page, err := fetch(cursor)
555562
if err != nil {
556563
return nil, err
557564
}
558-
allTools = append(allTools, result.Tools...)
559-
logConn.Printf("listTools: received page of %d tools (total so far: %d) from serverID=%s", len(result.Tools), len(allTools), c.serverID)
560-
cursor = result.NextCursor
565+
all = append(all, page.Items...)
566+
logConn.Printf("list%s: received page of %d %s (total so far: %d) from serverID=%s", itemKind, len(page.Items), itemKind, len(all), serverID)
567+
cursor = page.NextCursor
561568
}
562-
logConn.Printf("listTools: received %d tools total from serverID=%s", len(allTools), c.serverID)
563-
return marshalToResponse(&sdk.ListToolsResult{Tools: allTools})
569+
logConn.Printf("list%s: received %d %s total from serverID=%s", itemKind, len(all), itemKind, serverID)
570+
return all, nil
571+
}
572+
573+
func (c *Connection) listTools() (*Response, error) {
574+
if err := c.requireSession(); err != nil {
575+
return nil, err
576+
}
577+
logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID)
578+
tools, err := paginateAll(c.serverID, "tools", func(cursor string) (paginatedPage[*sdk.Tool], error) {
579+
result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{Cursor: cursor})
580+
if err != nil {
581+
return paginatedPage[*sdk.Tool]{}, err
582+
}
583+
return paginatedPage[*sdk.Tool]{Items: result.Tools, NextCursor: result.NextCursor}, nil
584+
})
585+
if err != nil {
586+
return nil, err
587+
}
588+
return marshalToResponse(&sdk.ListToolsResult{Tools: tools})
564589
}
565590

566591
func (c *Connection) callTool(params interface{}) (*Response, error) {
@@ -583,26 +608,17 @@ func (c *Connection) listResources() (*Response, error) {
583608
return nil, err
584609
}
585610
logConn.Printf("listResources: requesting resource list from backend serverID=%s", c.serverID)
586-
// Fetch first page to determine initial capacity
587-
first, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{})
588-
if err != nil {
589-
return nil, err
590-
}
591-
allResources := make([]*sdk.Resource, len(first.Resources), max(len(first.Resources), 1))
592-
copy(allResources, first.Resources)
593-
logConn.Printf("listResources: received page of %d resources from serverID=%s", len(first.Resources), c.serverID)
594-
cursor := first.NextCursor
595-
for cursor != "" {
611+
resources, err := paginateAll(c.serverID, "resources", func(cursor string) (paginatedPage[*sdk.Resource], error) {
596612
result, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{Cursor: cursor})
597613
if err != nil {
598-
return nil, err
614+
return paginatedPage[*sdk.Resource]{}, err
599615
}
600-
allResources = append(allResources, result.Resources...)
601-
logConn.Printf("listResources: received page of %d resources (total so far: %d) from serverID=%s", len(result.Resources), len(allResources), c.serverID)
602-
cursor = result.NextCursor
616+
return paginatedPage[*sdk.Resource]{Items: result.Resources, NextCursor: result.NextCursor}, nil
617+
})
618+
if err != nil {
619+
return nil, err
603620
}
604-
logConn.Printf("listResources: received %d resources total from serverID=%s", len(allResources), c.serverID)
605-
return marshalToResponse(&sdk.ListResourcesResult{Resources: allResources})
621+
return marshalToResponse(&sdk.ListResourcesResult{Resources: resources})
606622
}
607623

608624
func (c *Connection) readResource(params interface{}) (*Response, error) {
@@ -622,26 +638,17 @@ func (c *Connection) listPrompts() (*Response, error) {
622638
return nil, err
623639
}
624640
logConn.Printf("listPrompts: requesting prompt list from backend serverID=%s", c.serverID)
625-
// Fetch first page to determine initial capacity
626-
first, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{})
627-
if err != nil {
628-
return nil, err
629-
}
630-
allPrompts := make([]*sdk.Prompt, len(first.Prompts), max(len(first.Prompts), 1))
631-
copy(allPrompts, first.Prompts)
632-
logConn.Printf("listPrompts: received page of %d prompts from serverID=%s", len(first.Prompts), c.serverID)
633-
cursor := first.NextCursor
634-
for cursor != "" {
641+
prompts, err := paginateAll(c.serverID, "prompts", func(cursor string) (paginatedPage[*sdk.Prompt], error) {
635642
result, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{Cursor: cursor})
636643
if err != nil {
637-
return nil, err
644+
return paginatedPage[*sdk.Prompt]{}, err
638645
}
639-
allPrompts = append(allPrompts, result.Prompts...)
640-
logConn.Printf("listPrompts: received page of %d prompts (total so far: %d) from serverID=%s", len(result.Prompts), len(allPrompts), c.serverID)
641-
cursor = result.NextCursor
646+
return paginatedPage[*sdk.Prompt]{Items: result.Prompts, NextCursor: result.NextCursor}, nil
647+
})
648+
if err != nil {
649+
return nil, err
642650
}
643-
logConn.Printf("listPrompts: received %d prompts total from serverID=%s", len(allPrompts), c.serverID)
644-
return marshalToResponse(&sdk.ListPromptsResult{Prompts: allPrompts})
651+
return marshalToResponse(&sdk.ListPromptsResult{Prompts: prompts})
645652
}
646653

647654
func (c *Connection) getPrompt(params interface{}) (*Response, error) {

internal/mcp/http_transport.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ type httpRequestResult struct {
4848
Header http.Header
4949
}
5050

51-
// transportConnector is a function that creates an SDK transport for a given URL and HTTP client
51+
// transportConnector is a function that creates an SDK transport for a given URL and HTTP client.
52+
// The returned transport is owned by the SDK client session after Connect() succeeds;
53+
// callers must not close it directly — it is cleaned up when the session is closed.
5254
type transportConnector func(url string, httpClient *http.Client) sdk.Transport
5355

5456
// isHTTPConnectionError checks if an error is a network connection error.

internal/mcp/tool_result.go

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,6 @@ import (
88
sdk "github.com/modelcontextprotocol/go-sdk/mcp"
99
)
1010

11-
// resourceContents mirrors sdk.ResourceContents for JSON unmarshaling of
12-
// embedded resource content items returned by backend MCP servers.
13-
type resourceContents struct {
14-
URI string `json:"uri"`
15-
MIMEType string `json:"mimeType,omitempty"`
16-
Text string `json:"text,omitempty"`
17-
Blob []byte `json:"blob,omitempty"`
18-
}
19-
2011
var logToolResult = logger.New("mcp:tool_result")
2112

2213
// ConvertToCallToolResult converts backend result data to SDK CallToolResult format.
@@ -70,11 +61,11 @@ func ConvertToCallToolResult(data interface{}) (*sdk.CallToolResult, error) {
7061
// Parse the backend result structure (standard MCP CallToolResult format)
7162
var backendResult struct {
7263
Content []struct {
73-
Type string `json:"type"`
74-
Text string `json:"text,omitempty"`
75-
Data []byte `json:"data,omitempty"` // image/audio binary data (automatically decoded from base64 JSON)
76-
MIMEType string `json:"mimeType,omitempty"` // image/audio MIME type
77-
Resource *resourceContents `json:"resource,omitempty"` // embedded resource
64+
Type string `json:"type"`
65+
Text string `json:"text,omitempty"`
66+
Data []byte `json:"data,omitempty"` // image/audio binary data (automatically decoded from base64 JSON)
67+
MIMEType string `json:"mimeType,omitempty"` // image/audio MIME type
68+
Resource *sdk.ResourceContents `json:"resource,omitempty"` // embedded resource
7869
} `json:"content"`
7970
IsError bool `json:"isError,omitempty"`
8071
}
@@ -114,12 +105,7 @@ func ConvertToCallToolResult(data interface{}) (*sdk.CallToolResult, error) {
114105
case "resource":
115106
if item.Resource != nil {
116107
content = append(content, &sdk.EmbeddedResource{
117-
Resource: &sdk.ResourceContents{
118-
URI: item.Resource.URI,
119-
MIMEType: item.Resource.MIMEType,
120-
Text: item.Resource.Text,
121-
Blob: item.Resource.Blob,
122-
},
108+
Resource: item.Resource,
123109
})
124110
} else {
125111
logToolResult.Printf("Resource content item missing 'resource' field, skipping")

internal/server/routed.go

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"time"
1111

1212
"github.com/github/gh-aw-mcpg/internal/logger"
13-
"github.com/github/gh-aw-mcpg/internal/syncutil"
1413
"github.com/github/gh-aw-mcpg/internal/version"
1514
sdk "github.com/modelcontextprotocol/go-sdk/mcp"
1615
)
@@ -31,27 +30,53 @@ func rejectIfShutdown(unifiedServer *UnifiedServer, next http.Handler, logNamesp
3130
})
3231
}
3332

34-
// filteredServerCache caches filtered server instances per (backend, session) key
33+
// filteredServerCache caches filtered server instances per (backend, session) key.
34+
// Entries are evicted after the configured TTL to prevent unbounded memory growth
35+
// in long-running deployments with many sessions.
3536
type filteredServerCache struct {
36-
servers map[string]*sdk.Server
37+
servers map[string]*filteredServerEntry
38+
ttl time.Duration
3739
mu sync.RWMutex
3840
}
3941

40-
// newFilteredServerCache creates a new server cache
41-
func newFilteredServerCache() *filteredServerCache {
42+
type filteredServerEntry struct {
43+
server *sdk.Server
44+
lastUsed time.Time
45+
}
46+
47+
// newFilteredServerCache creates a new server cache with the given entry TTL.
48+
func newFilteredServerCache(ttl time.Duration) *filteredServerCache {
4249
return &filteredServerCache{
43-
servers: make(map[string]*sdk.Server),
50+
servers: make(map[string]*filteredServerEntry),
51+
ttl: ttl,
4452
}
4553
}
4654

47-
// getOrCreate returns a cached server or creates a new one
55+
// getOrCreate returns a cached server or creates a new one.
56+
// Expired entries are lazily evicted on each call.
4857
func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator func() *sdk.Server) *sdk.Server {
4958
key := fmt.Sprintf("%s/%s", backendID, sessionID)
59+
now := time.Now()
5060

51-
server, _ := syncutil.GetOrCreate(&c.mu, c.servers, key, func() (*sdk.Server, error) {
52-
logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, sessionID)
53-
return creator(), nil
54-
})
61+
c.mu.Lock()
62+
defer c.mu.Unlock()
63+
64+
// Lazy eviction of expired entries
65+
for k, entry := range c.servers {
66+
if now.Sub(entry.lastUsed) > c.ttl {
67+
logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", k, now.Sub(entry.lastUsed).Round(time.Second))
68+
delete(c.servers, k)
69+
}
70+
}
71+
72+
if entry, ok := c.servers[key]; ok {
73+
entry.lastUsed = now
74+
return entry.server
75+
}
76+
77+
logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, sessionID)
78+
server := creator()
79+
c.servers[key] = &filteredServerEntry{server: server, lastUsed: now}
5580
return server
5681
}
5782

@@ -71,8 +96,10 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap
7196
allBackends := unifiedServer.GetServerIDs()
7297
logRouted.Printf("Registering routes for %d backends: %v", len(allBackends), allBackends)
7398

74-
// Create server cache for session-aware server instances
75-
serverCache := newFilteredServerCache()
99+
// Create server cache for session-aware server instances.
100+
// TTL matches the SDK SessionTimeout so cache entries expire with sessions.
101+
routedSessionTimeout := 30 * time.Minute
102+
serverCache := newFilteredServerCache(routedSessionTimeout)
76103

77104
// Create a proxy for each backend server
78105
for _, serverID := range allBackends {
@@ -95,7 +122,7 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap
95122
}, &sdk.StreamableHTTPOptions{
96123
Stateless: false,
97124
Logger: logger.NewSlogLoggerWithHandler(logRouted),
98-
SessionTimeout: 30 * time.Minute,
125+
SessionTimeout: routedSessionTimeout,
99126
})
100127

101128
// Apply standard middleware stack (SDK logging → shutdown check → auth)

internal/testutil/mcptest/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func (s *Server) Start() error {
3636
Version: s.config.Version,
3737
}
3838

39-
s.server = sdk.NewServer(impl, nil)
39+
s.server = sdk.NewServer(impl, &sdk.ServerOptions{})
4040

4141
// Register tools
4242
for i, toolCfg := range s.config.Tools {

0 commit comments

Comments
 (0)