|
7 | 7 | "net/http/httptest" |
8 | 8 | "strings" |
9 | 9 | "testing" |
| 10 | + "time" |
10 | 11 |
|
11 | 12 | "github.com/github/gh-aw-mcpg/internal/config" |
12 | 13 | sdk "github.com/modelcontextprotocol/go-sdk/mcp" |
@@ -484,3 +485,72 @@ func TestCallBackendTool_AllowedToolsError_MessageFormat(t *testing.T) { |
484 | 485 | assert.True(t, strings.Contains(text, `"blocked"`), "error message should include tool name: %s", text) |
485 | 486 | assert.True(t, strings.Contains(text, "allowed-tools"), "error message should mention allowed-tools: %s", text) |
486 | 487 | } |
| 488 | + |
| 489 | +// TestCallBackendTool_ToolTimeoutEnforcedViaContext verifies that the configured |
| 490 | +// toolTimeout is applied as a context deadline, causing slow backend calls to fail |
| 491 | +// with a deadline-exceeded error instead of hanging indefinitely. |
| 492 | +func TestCallBackendTool_ToolTimeoutEnforcedViaContext(t *testing.T) { |
| 493 | + require := require.New(t) |
| 494 | + |
| 495 | + // Create a slow backend that delays longer than our tool timeout |
| 496 | + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 497 | + var req map[string]interface{} |
| 498 | + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { |
| 499 | + w.WriteHeader(http.StatusBadRequest) |
| 500 | + return |
| 501 | + } |
| 502 | + |
| 503 | + method, _ := req["method"].(string) |
| 504 | + switch method { |
| 505 | + case "initialize": |
| 506 | + json.NewEncoder(w).Encode(map[string]interface{}{ |
| 507 | + "jsonrpc": "2.0", "id": req["id"], |
| 508 | + "result": map[string]interface{}{ |
| 509 | + "protocolVersion": "2024-11-05", |
| 510 | + "capabilities": map[string]interface{}{}, |
| 511 | + "serverInfo": map[string]interface{}{"name": "slow-backend", "version": "1.0.0"}, |
| 512 | + }, |
| 513 | + }) |
| 514 | + case "tools/list": |
| 515 | + json.NewEncoder(w).Encode(map[string]interface{}{ |
| 516 | + "jsonrpc": "2.0", "id": req["id"], |
| 517 | + "result": map[string]interface{}{"tools": []map[string]interface{}{}}, |
| 518 | + }) |
| 519 | + case "tools/call": |
| 520 | + // Simulate a slow tool: sleep longer than the configured toolTimeout. |
| 521 | + // The actual tool call should return after ~1s (timeout), but the |
| 522 | + // httptest.Server cleanup waits for this goroutine to finish. |
| 523 | + time.Sleep(3 * time.Second) |
| 524 | + json.NewEncoder(w).Encode(map[string]interface{}{ |
| 525 | + "jsonrpc": "2.0", "id": req["id"], |
| 526 | + "result": map[string]interface{}{ |
| 527 | + "content": []map[string]interface{}{{"type": "text", "text": "should not reach here"}}, |
| 528 | + }, |
| 529 | + }) |
| 530 | + } |
| 531 | + })) |
| 532 | + defer backend.Close() |
| 533 | + |
| 534 | + // Configure with a very short toolTimeout (1 second) |
| 535 | + cfg := &config.Config{ |
| 536 | + Gateway: &config.GatewayConfig{ |
| 537 | + ToolTimeout: 1, |
| 538 | + }, |
| 539 | + Servers: map[string]*config.ServerConfig{ |
| 540 | + "slow": {Type: "http", URL: backend.URL}, |
| 541 | + }, |
| 542 | + } |
| 543 | + |
| 544 | + us, err := NewUnified(context.Background(), cfg) |
| 545 | + require.NoError(err) |
| 546 | + defer us.Close() |
| 547 | + |
| 548 | + ctx := context.WithValue(context.Background(), SessionIDContextKey, "timeout-test") |
| 549 | + result, _, callErr := us.callBackendTool(ctx, "slow", "slow_tool", map[string]interface{}{}) |
| 550 | + |
| 551 | + // The call should fail due to context deadline exceeded |
| 552 | + require.Error(callErr, "Tool call should fail due to timeout") |
| 553 | + require.NotNil(result, "Should return a CallToolResult even on timeout") |
| 554 | + require.True(result.IsError, "Result should be marked as error") |
| 555 | + t.Logf("Tool call correctly timed out: %v", callErr) |
| 556 | +} |
0 commit comments