Skip to content

Commit 1631b99

Browse files
authored
refactor: Go SDK usage improvements from module review (#2967)
## Summary Addresses the quick-win and best-practice findings from the Go Fan module review of `modelcontextprotocol/go-sdk` (#2911). ## Changes ### 1. Extract pagination helper (`connection.go`) Introduces a generic `paginateAll[T]()` helper that replaces three identical cursor-loop pagination implementations in `listTools`, `listResources`, and `listPrompts`. Eliminates ~45 lines of duplicated boilerplate while preserving identical logging behavior. ### 2. Eliminate `resourceContents` intermediate type (`tool_result.go`) The local `resourceContents` struct mirrored `sdk.ResourceContents` field-for-field for JSON unmarshaling. Since the SDK type works directly for this purpose, the local type is removed and `sdk.ResourceContents` is used inline. The field-by-field copy in the resource content conversion is also eliminated. ### 3. Explicit server options (`mcptest/server.go`) Passes `&sdk.ServerOptions{}` instead of `nil` to `sdk.NewServer()` — more defensive and documents intent, guarding against future SDK changes. ### 4. Cache eviction for `filteredServerCache` (`routed.go`) The routed-mode server cache previously grew unboundedly (one entry per backend×session pair, never evicted). Now includes: - TTL-based expiry matching the SDK `SessionTimeout` (30 minutes) - Lazy eviction on each `getOrCreate()` call - Extracted `routedSessionTimeout` variable shared between cache TTL and SDK options ### 5. Transport ownership documentation (`http_transport.go`) Adds lifecycle/ownership documentation to the `transportConnector` type, clarifying that the SDK session owns the returned transport after `Connect()`. ## Testing - All Go unit and integration tests pass - `make agent-finished` passes (format, build, lint, all tests) Closes #2911
2 parents b4d8da9 + 7e70592 commit 1631b99

5 files changed

Lines changed: 106 additions & 84 deletions

File tree

internal/mcp/connection.go

Lines changed: 55 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -536,31 +536,56 @@ func callParamMethod[P any](c *Connection, rawParams interface{}, fn func(P) (in
536536
return marshalToResponse(result)
537537
}
538538

539-
func (c *Connection) listTools() (*Response, error) {
540-
if err := c.requireSession(); err != nil {
541-
return nil, err
542-
}
543-
logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID)
544-
// Fetch first page to determine initial capacity
545-
first, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{})
539+
// paginatedPage holds a single page of results from a paginated SDK list call.
540+
type paginatedPage[T any] struct {
541+
Items []T
542+
NextCursor string
543+
}
544+
545+
// paginateAll collects all items across paginated SDK list calls.
546+
func paginateAll[T any](
547+
serverID string,
548+
itemKind string,
549+
fetch func(cursor string) (paginatedPage[T], error),
550+
) ([]T, error) {
551+
first, err := fetch("")
546552
if err != nil {
547553
return nil, err
548554
}
549-
allTools := make([]*sdk.Tool, len(first.Tools), max(len(first.Tools), 1))
550-
copy(allTools, first.Tools)
551-
logConn.Printf("listTools: received page of %d tools from serverID=%s", len(first.Tools), c.serverID)
555+
all := make([]T, len(first.Items), max(len(first.Items), 1))
556+
copy(all, first.Items)
557+
logConn.Printf("list%s: received page of %d %s from serverID=%s", itemKind, len(first.Items), itemKind, serverID)
558+
552559
cursor := first.NextCursor
553560
for cursor != "" {
554-
result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{Cursor: cursor})
561+
page, err := fetch(cursor)
555562
if err != nil {
556563
return nil, err
557564
}
558-
allTools = append(allTools, result.Tools...)
559-
logConn.Printf("listTools: received page of %d tools (total so far: %d) from serverID=%s", len(result.Tools), len(allTools), c.serverID)
560-
cursor = result.NextCursor
565+
all = append(all, page.Items...)
566+
logConn.Printf("list%s: received page of %d %s (total so far: %d) from serverID=%s", itemKind, len(page.Items), itemKind, len(all), serverID)
567+
cursor = page.NextCursor
561568
}
562-
logConn.Printf("listTools: received %d tools total from serverID=%s", len(allTools), c.serverID)
563-
return marshalToResponse(&sdk.ListToolsResult{Tools: allTools})
569+
logConn.Printf("list%s: received %d %s total from serverID=%s", itemKind, len(all), itemKind, serverID)
570+
return all, nil
571+
}
572+
573+
func (c *Connection) listTools() (*Response, error) {
574+
if err := c.requireSession(); err != nil {
575+
return nil, err
576+
}
577+
logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID)
578+
tools, err := paginateAll(c.serverID, "tools", func(cursor string) (paginatedPage[*sdk.Tool], error) {
579+
result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{Cursor: cursor})
580+
if err != nil {
581+
return paginatedPage[*sdk.Tool]{}, err
582+
}
583+
return paginatedPage[*sdk.Tool]{Items: result.Tools, NextCursor: result.NextCursor}, nil
584+
})
585+
if err != nil {
586+
return nil, err
587+
}
588+
return marshalToResponse(&sdk.ListToolsResult{Tools: tools})
564589
}
565590

566591
func (c *Connection) callTool(params interface{}) (*Response, error) {
@@ -583,26 +608,17 @@ func (c *Connection) listResources() (*Response, error) {
583608
return nil, err
584609
}
585610
logConn.Printf("listResources: requesting resource list from backend serverID=%s", c.serverID)
586-
// Fetch first page to determine initial capacity
587-
first, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{})
588-
if err != nil {
589-
return nil, err
590-
}
591-
allResources := make([]*sdk.Resource, len(first.Resources), max(len(first.Resources), 1))
592-
copy(allResources, first.Resources)
593-
logConn.Printf("listResources: received page of %d resources from serverID=%s", len(first.Resources), c.serverID)
594-
cursor := first.NextCursor
595-
for cursor != "" {
611+
resources, err := paginateAll(c.serverID, "resources", func(cursor string) (paginatedPage[*sdk.Resource], error) {
596612
result, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{Cursor: cursor})
597613
if err != nil {
598-
return nil, err
614+
return paginatedPage[*sdk.Resource]{}, err
599615
}
600-
allResources = append(allResources, result.Resources...)
601-
logConn.Printf("listResources: received page of %d resources (total so far: %d) from serverID=%s", len(result.Resources), len(allResources), c.serverID)
602-
cursor = result.NextCursor
616+
return paginatedPage[*sdk.Resource]{Items: result.Resources, NextCursor: result.NextCursor}, nil
617+
})
618+
if err != nil {
619+
return nil, err
603620
}
604-
logConn.Printf("listResources: received %d resources total from serverID=%s", len(allResources), c.serverID)
605-
return marshalToResponse(&sdk.ListResourcesResult{Resources: allResources})
621+
return marshalToResponse(&sdk.ListResourcesResult{Resources: resources})
606622
}
607623

608624
func (c *Connection) readResource(params interface{}) (*Response, error) {
@@ -622,26 +638,17 @@ func (c *Connection) listPrompts() (*Response, error) {
622638
return nil, err
623639
}
624640
logConn.Printf("listPrompts: requesting prompt list from backend serverID=%s", c.serverID)
625-
// Fetch first page to determine initial capacity
626-
first, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{})
627-
if err != nil {
628-
return nil, err
629-
}
630-
allPrompts := make([]*sdk.Prompt, len(first.Prompts), max(len(first.Prompts), 1))
631-
copy(allPrompts, first.Prompts)
632-
logConn.Printf("listPrompts: received page of %d prompts from serverID=%s", len(first.Prompts), c.serverID)
633-
cursor := first.NextCursor
634-
for cursor != "" {
641+
prompts, err := paginateAll(c.serverID, "prompts", func(cursor string) (paginatedPage[*sdk.Prompt], error) {
635642
result, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{Cursor: cursor})
636643
if err != nil {
637-
return nil, err
644+
return paginatedPage[*sdk.Prompt]{}, err
638645
}
639-
allPrompts = append(allPrompts, result.Prompts...)
640-
logConn.Printf("listPrompts: received page of %d prompts (total so far: %d) from serverID=%s", len(result.Prompts), len(allPrompts), c.serverID)
641-
cursor = result.NextCursor
646+
return paginatedPage[*sdk.Prompt]{Items: result.Prompts, NextCursor: result.NextCursor}, nil
647+
})
648+
if err != nil {
649+
return nil, err
642650
}
643-
logConn.Printf("listPrompts: received %d prompts total from serverID=%s", len(allPrompts), c.serverID)
644-
return marshalToResponse(&sdk.ListPromptsResult{Prompts: allPrompts})
651+
return marshalToResponse(&sdk.ListPromptsResult{Prompts: prompts})
645652
}
646653

647654
func (c *Connection) getPrompt(params interface{}) (*Response, error) {

internal/mcp/http_transport.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ type httpRequestResult struct {
4848
Header http.Header
4949
}
5050

51-
// transportConnector is a function that creates an SDK transport for a given URL and HTTP client
51+
// transportConnector is a function that creates an SDK transport for a given URL and HTTP client.
52+
// The returned transport is owned by the SDK client session after Connect() succeeds;
53+
// callers must not close it directly — it is cleaned up when the session is closed.
5254
type transportConnector func(url string, httpClient *http.Client) sdk.Transport
5355

5456
// isHTTPConnectionError checks if an error is a network connection error.

internal/mcp/tool_result.go

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,6 @@ import (
88
sdk "github.com/modelcontextprotocol/go-sdk/mcp"
99
)
1010

11-
// resourceContents mirrors sdk.ResourceContents for JSON unmarshaling of
12-
// embedded resource content items returned by backend MCP servers.
13-
type resourceContents struct {
14-
URI string `json:"uri"`
15-
MIMEType string `json:"mimeType,omitempty"`
16-
Text string `json:"text,omitempty"`
17-
Blob []byte `json:"blob,omitempty"`
18-
}
19-
2011
var logToolResult = logger.New("mcp:tool_result")
2112

2213
// ConvertToCallToolResult converts backend result data to SDK CallToolResult format.
@@ -70,11 +61,11 @@ func ConvertToCallToolResult(data interface{}) (*sdk.CallToolResult, error) {
7061
// Parse the backend result structure (standard MCP CallToolResult format)
7162
var backendResult struct {
7263
Content []struct {
73-
Type string `json:"type"`
74-
Text string `json:"text,omitempty"`
75-
Data []byte `json:"data,omitempty"` // image/audio binary data (automatically decoded from base64 JSON)
76-
MIMEType string `json:"mimeType,omitempty"` // image/audio MIME type
77-
Resource *resourceContents `json:"resource,omitempty"` // embedded resource
64+
Type string `json:"type"`
65+
Text string `json:"text,omitempty"`
66+
Data []byte `json:"data,omitempty"` // image/audio binary data (automatically decoded from base64 JSON)
67+
MIMEType string `json:"mimeType,omitempty"` // image/audio MIME type
68+
Resource *sdk.ResourceContents `json:"resource,omitempty"` // embedded resource
7869
} `json:"content"`
7970
IsError bool `json:"isError,omitempty"`
8071
}
@@ -114,12 +105,7 @@ func ConvertToCallToolResult(data interface{}) (*sdk.CallToolResult, error) {
114105
case "resource":
115106
if item.Resource != nil {
116107
content = append(content, &sdk.EmbeddedResource{
117-
Resource: &sdk.ResourceContents{
118-
URI: item.Resource.URI,
119-
MIMEType: item.Resource.MIMEType,
120-
Text: item.Resource.Text,
121-
Blob: item.Resource.Blob,
122-
},
108+
Resource: item.Resource,
123109
})
124110
} else {
125111
logToolResult.Printf("Resource content item missing 'resource' field, skipping")

internal/server/routed.go

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"time"
1111

1212
"github.com/github/gh-aw-mcpg/internal/logger"
13-
"github.com/github/gh-aw-mcpg/internal/syncutil"
1413
"github.com/github/gh-aw-mcpg/internal/version"
1514
sdk "github.com/modelcontextprotocol/go-sdk/mcp"
1615
)
@@ -31,27 +30,53 @@ func rejectIfShutdown(unifiedServer *UnifiedServer, next http.Handler, logNamesp
3130
})
3231
}
3332

34-
// filteredServerCache caches filtered server instances per (backend, session) key
33+
// filteredServerCache caches filtered server instances per (backend, session) key.
34+
// Entries are evicted after the configured TTL to prevent unbounded memory growth
35+
// in long-running deployments with many sessions.
3536
type filteredServerCache struct {
36-
servers map[string]*sdk.Server
37+
servers map[string]*filteredServerEntry
38+
ttl time.Duration
3739
mu sync.RWMutex
3840
}
3941

40-
// newFilteredServerCache creates a new server cache
41-
func newFilteredServerCache() *filteredServerCache {
42+
type filteredServerEntry struct {
43+
server *sdk.Server
44+
lastUsed time.Time
45+
}
46+
47+
// newFilteredServerCache creates a new server cache with the given entry TTL.
48+
func newFilteredServerCache(ttl time.Duration) *filteredServerCache {
4249
return &filteredServerCache{
43-
servers: make(map[string]*sdk.Server),
50+
servers: make(map[string]*filteredServerEntry),
51+
ttl: ttl,
4452
}
4553
}
4654

47-
// getOrCreate returns a cached server or creates a new one
55+
// getOrCreate returns a cached server or creates a new one.
56+
// Expired entries are lazily evicted on each call.
4857
func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator func() *sdk.Server) *sdk.Server {
4958
key := fmt.Sprintf("%s/%s", backendID, sessionID)
59+
now := time.Now()
5060

51-
server, _ := syncutil.GetOrCreate(&c.mu, c.servers, key, func() (*sdk.Server, error) {
52-
logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, sessionID)
53-
return creator(), nil
54-
})
61+
c.mu.Lock()
62+
defer c.mu.Unlock()
63+
64+
// Lazy eviction of expired entries
65+
for k, entry := range c.servers {
66+
if now.Sub(entry.lastUsed) > c.ttl {
67+
logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", k, now.Sub(entry.lastUsed).Round(time.Second))
68+
delete(c.servers, k)
69+
}
70+
}
71+
72+
if entry, ok := c.servers[key]; ok {
73+
entry.lastUsed = now
74+
return entry.server
75+
}
76+
77+
logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, sessionID)
78+
server := creator()
79+
c.servers[key] = &filteredServerEntry{server: server, lastUsed: now}
5580
return server
5681
}
5782

@@ -71,8 +96,10 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap
7196
allBackends := unifiedServer.GetServerIDs()
7297
logRouted.Printf("Registering routes for %d backends: %v", len(allBackends), allBackends)
7398

74-
// Create server cache for session-aware server instances
75-
serverCache := newFilteredServerCache()
99+
// Create server cache for session-aware server instances.
100+
// TTL matches the SDK SessionTimeout so cache entries expire with sessions.
101+
routedSessionTimeout := 30 * time.Minute
102+
serverCache := newFilteredServerCache(routedSessionTimeout)
76103

77104
// Create a proxy for each backend server
78105
for _, serverID := range allBackends {
@@ -95,7 +122,7 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap
95122
}, &sdk.StreamableHTTPOptions{
96123
Stateless: false,
97124
Logger: logger.NewSlogLoggerWithHandler(logRouted),
98-
SessionTimeout: 30 * time.Minute,
125+
SessionTimeout: routedSessionTimeout,
99126
})
100127

101128
// Apply standard middleware stack (SDK logging → shutdown check → auth)

internal/testutil/mcptest/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func (s *Server) Start() error {
3636
Version: s.config.Version,
3737
}
3838

39-
s.server = sdk.NewServer(impl, nil)
39+
s.server = sdk.NewServer(impl, &sdk.ServerOptions{})
4040

4141
// Register tools
4242
for i, toolCfg := range s.config.Tools {

0 commit comments

Comments
 (0)