Skip to content

Commit a5ac0b3

Browse files
committed
refactor: make CallToolRequest.Arguments more flexible
This change allows for more flexible argument types in CallToolRequest. Arguments field is now of type 'any' instead of 'map[string]any'. Added GetArguments() and GetRawArguments() methods for backward compatibility. Fixes #104
1 parent e767652 commit a5ac0b3

File tree

10 files changed

+138
-21
lines changed

10 files changed

+138
-21
lines changed

README.md

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,10 @@ func main() {
149149

150150
// Add the calculator handler
151151
s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
152-
op := request.Params.Arguments["operation"].(string)
153-
x := request.Params.Arguments["x"].(float64)
154-
y := request.Params.Arguments["y"].(float64)
152+
args := request.GetArguments()
153+
op := args["operation"].(string)
154+
x := args["x"].(float64)
155+
y := args["y"].(float64)
155156

156157
var result float64
157158
switch op {
@@ -312,9 +313,10 @@ calculatorTool := mcp.NewTool("calculate",
312313
)
313314

314315
s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
315-
op := request.Params.Arguments["operation"].(string)
316-
x := request.Params.Arguments["x"].(float64)
317-
y := request.Params.Arguments["y"].(float64)
316+
args := request.GetArguments()
317+
op := args["operation"].(string)
318+
x := args["x"].(float64)
319+
y := args["y"].(float64)
318320

319321
var result float64
320322
switch op {
@@ -355,10 +357,11 @@ httpTool := mcp.NewTool("http_request",
355357
)
356358

357359
s.AddTool(httpTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
358-
method := request.Params.Arguments["method"].(string)
359-
url := request.Params.Arguments["url"].(string)
360+
args := request.GetArguments()
361+
method := args["method"].(string)
362+
url := args["url"].(string)
360363
body := ""
361-
if b, ok := request.Params.Arguments["body"].(string); ok {
364+
if b, ok := args["body"].(string); ok {
362365
body = b
363366
}
364367

client/inprocess_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func TestInProcessMCPClient(t *testing.T) {
3232
Content: []mcp.Content{
3333
mcp.TextContent{
3434
Type: "text",
35-
Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string),
35+
Text: "Input parameter: " + request.GetArguments()["parameter-1"].(string),
3636
},
3737
mcp.AudioContent{
3838
Type: "audio",

client/sse_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func TestSSEMCPClient(t *testing.T) {
3636
Content: []mcp.Content{
3737
mcp.TextContent{
3838
Type: "text",
39-
Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string),
39+
Text: "Input parameter: " + request.GetArguments()["parameter-1"].(string),
4040
},
4141
},
4242
}, nil

examples/custom_context/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func handleMakeAuthenticatedRequestTool(
8181
ctx context.Context,
8282
request mcp.CallToolRequest,
8383
) (*mcp.CallToolResult, error) {
84-
message, ok := request.Params.Arguments["message"].(string)
84+
message, ok := request.GetArguments()["message"].(string)
8585
if !ok {
8686
return nil, fmt.Errorf("missing message")
8787
}

examples/dynamic_path/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func main() {
2020

2121
// Add a trivial tool for demonstration
2222
mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
23-
return mcp.NewToolResultText(fmt.Sprintf("Echo: %v", req.Params.Arguments["message"])), nil
23+
return mcp.NewToolResultText(fmt.Sprintf("Echo: %v", req.GetArguments()["message"])), nil
2424
})
2525

2626
// Use a dynamic base path based on a path parameter (Go 1.22+)

examples/everything/main.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ func handleEchoTool(
312312
ctx context.Context,
313313
request mcp.CallToolRequest,
314314
) (*mcp.CallToolResult, error) {
315-
arguments := request.Params.Arguments
315+
arguments := request.GetArguments()
316316
message, ok := arguments["message"].(string)
317317
if !ok {
318318
return nil, fmt.Errorf("invalid message argument")
@@ -331,7 +331,7 @@ func handleAddTool(
331331
ctx context.Context,
332332
request mcp.CallToolRequest,
333333
) (*mcp.CallToolResult, error) {
334-
arguments := request.Params.Arguments
334+
arguments := request.GetArguments()
335335
a, ok1 := arguments["a"].(float64)
336336
b, ok2 := arguments["b"].(float64)
337337
if !ok1 || !ok2 {
@@ -382,7 +382,7 @@ func handleLongRunningOperationTool(
382382
ctx context.Context,
383383
request mcp.CallToolRequest,
384384
) (*mcp.CallToolResult, error) {
385-
arguments := request.Params.Arguments
385+
arguments := request.GetArguments()
386386
progressToken := request.Params.Meta.ProgressToken
387387
duration, _ := arguments["duration"].(float64)
388388
steps, _ := arguments["steps"].(float64)

mcp/tools.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ type CallToolResult struct {
4444
type CallToolRequest struct {
4545
Request
4646
Params struct {
47-
Name string `json:"name"`
48-
Arguments map[string]any `json:"arguments,omitempty"`
47+
Name string `json:"name"`
48+
Arguments any `json:"arguments,omitempty"` // Can be map[string]any or any other type
4949
Meta *struct {
5050
// If specified, the caller is requesting out-of-band progress
5151
// notifications for this request (as represented by
@@ -58,6 +58,21 @@ type CallToolRequest struct {
5858
} `json:"params"`
5959
}
6060

61+
// GetArguments returns the Arguments as map[string]any for backward compatibility
62+
// If Arguments is not a map, it returns an empty map
63+
func (r CallToolRequest) GetArguments() map[string]any {
64+
if args, ok := r.Params.Arguments.(map[string]any); ok {
65+
return args
66+
}
67+
return map[string]any{}
68+
}
69+
70+
// GetRawArguments returns the Arguments as-is without type conversion
71+
// This allows users to access the raw arguments in any format
72+
func (r CallToolRequest) GetRawArguments() any {
73+
return r.Params.Arguments
74+
}
75+
6176
// ToolListChangedNotification is an optional notification from the server to
6277
// the client, informing it that the list of tools it offers has changed. This may
6378
// be issued by servers without any previous subscription from the client.

mcp/tools_arguments_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package mcp
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestCallToolRequestWithMapArguments(t *testing.T) {
11+
// Create a request with map arguments
12+
req := CallToolRequest{}
13+
req.Params.Name = "test-tool"
14+
req.Params.Arguments = map[string]any{
15+
"key1": "value1",
16+
"key2": 123,
17+
}
18+
19+
// Test GetArguments
20+
args := req.GetArguments()
21+
assert.Equal(t, "value1", args["key1"])
22+
assert.Equal(t, 123, args["key2"])
23+
24+
// Test GetRawArguments
25+
rawArgs := req.GetRawArguments()
26+
mapArgs, ok := rawArgs.(map[string]any)
27+
assert.True(t, ok)
28+
assert.Equal(t, "value1", mapArgs["key1"])
29+
assert.Equal(t, 123, mapArgs["key2"])
30+
}
31+
32+
func TestCallToolRequestWithNonMapArguments(t *testing.T) {
33+
// Create a request with non-map arguments
34+
req := CallToolRequest{}
35+
req.Params.Name = "test-tool"
36+
req.Params.Arguments = "string-argument"
37+
38+
// Test GetArguments (should return empty map)
39+
args := req.GetArguments()
40+
assert.Empty(t, args)
41+
42+
// Test GetRawArguments
43+
rawArgs := req.GetRawArguments()
44+
strArg, ok := rawArgs.(string)
45+
assert.True(t, ok)
46+
assert.Equal(t, "string-argument", strArg)
47+
}
48+
49+
func TestCallToolRequestWithStructArguments(t *testing.T) {
50+
// Create a custom struct
51+
type CustomArgs struct {
52+
Field1 string `json:"field1"`
53+
Field2 int `json:"field2"`
54+
}
55+
56+
// Create a request with struct arguments
57+
req := CallToolRequest{}
58+
req.Params.Name = "test-tool"
59+
req.Params.Arguments = CustomArgs{
60+
Field1: "test",
61+
Field2: 42,
62+
}
63+
64+
// Test GetArguments (should return empty map)
65+
args := req.GetArguments()
66+
assert.Empty(t, args)
67+
68+
// Test GetRawArguments
69+
rawArgs := req.GetRawArguments()
70+
structArg, ok := rawArgs.(CustomArgs)
71+
assert.True(t, ok)
72+
assert.Equal(t, "test", structArg.Field1)
73+
assert.Equal(t, 42, structArg.Field2)
74+
}
75+
76+
func TestCallToolRequestJSONMarshalUnmarshal(t *testing.T) {
77+
// Create a request with map arguments
78+
req := CallToolRequest{}
79+
req.Params.Name = "test-tool"
80+
req.Params.Arguments = map[string]any{
81+
"key1": "value1",
82+
"key2": 123,
83+
}
84+
85+
// Marshal to JSON
86+
data, err := json.Marshal(req)
87+
assert.NoError(t, err)
88+
89+
// Unmarshal from JSON
90+
var unmarshaledReq CallToolRequest
91+
err = json.Unmarshal(data, &unmarshaledReq)
92+
assert.NoError(t, err)
93+
94+
// Check if arguments are correctly unmarshaled
95+
args := unmarshaledReq.GetArguments()
96+
assert.Equal(t, "value1", args["key1"])
97+
assert.Equal(t, float64(123), args["key2"]) // JSON numbers are unmarshaled as float64
98+
}

mcp/utils.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -675,10 +675,11 @@ func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult,
675675
}
676676

677677
func ParseArgument(request CallToolRequest, key string, defaultVal any) any {
678-
if _, ok := request.Params.Arguments[key]; !ok {
678+
args := request.GetArguments()
679+
if _, ok := args[key]; !ok {
679680
return defaultVal
680681
} else {
681-
return request.Params.Arguments[key]
682+
return args[key]
682683
}
683684
}
684685

mcptest/mcptest_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func TestServer(t *testing.T) {
5252

5353
func helloWorldHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
5454
// Extract name from request arguments
55-
name, ok := request.Params.Arguments["name"].(string)
55+
name, ok := request.GetArguments()["name"].(string)
5656
if !ok {
5757
name = "World"
5858
}

0 commit comments

Comments
 (0)