Commit 8636b0c

Enhance context cancellation support in OpenAI provider and tests
- Added context cancellation handling in the OpenAI provider's Stream and Post methods, allowing graceful termination of operations when the context is cancelled.
- Implemented a new test, TestOpenAIStreamContextCancellation, to validate that streaming respects context cancellation, ensuring proper error handling and event emission during cancellation scenarios.
- Refactored context usage in existing methods to improve consistency and reliability in handling context across streaming operations.
1 parent da53917 commit 8636b0c
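
For orientation, a minimal caller-side sketch of how the new cancellation path is driven. The llmInstance, messages, options, and handler values are assumed to be set up as in the new test below; only the ctx.Context field and the Stream call are taken from this commit:

// Hypothetical caller sketch (not part of this commit).
goCtx, cancel := gocontext.WithTimeout(gocontext.Background(), 30*time.Second)
defer cancel()
ctx.Context = goCtx // Stream and Post read this field for cancellation checks

_, err := llmInstance.Stream(ctx, messages, options, handler)
if err != nil && goCtx.Err() != nil {
	// Cancelled or timed out: the response is nil, and the handler has
	// already received a stream_end event with status "cancelled".
	fmt.Println("stream aborted:", err)
}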

2 files changed: +174, −9 lines

agent/llm/providers/openai/openai.go

Lines changed: 89 additions & 7 deletions
@@ -121,17 +121,39 @@ func (p *Provider) Stream(ctx *context.Context, messages []context.Message, opti
 	maxValidationRetries := 3
 	var lastErr error

+	// Get Go context for cancellation support
+	goCtx := ctx.Context
+	if goCtx == nil {
+		goCtx = gocontext.Background()
+	}
+
 	// Make a copy of messages to avoid modifying the original
 	currentMessages := make([]context.Message, len(messages))
 	copy(currentMessages, messages)

 	// Outer loop: handle network/API errors with exponential backoff
 	for attempt := 0; attempt < maxRetries; attempt++ {
+		// Check if context is cancelled before retry
+		select {
+		case <-goCtx.Done():
+			return nil, fmt.Errorf("context cancelled: %w", goCtx.Err())
+		default:
+		}
+
 		if attempt > 0 {
 			// Exponential backoff: 1s, 2s, 4s
 			backoff := time.Duration(1<<uint(attempt-1)) * time.Second
 			log.Warn("OpenAI stream request failed, retrying in %v (attempt %d/%d): %v", backoff, attempt+1, maxRetries, lastErr)
-			time.Sleep(backoff)
+
+			// Sleep with context cancellation support
+			timer := time.NewTimer(backoff)
+			select {
+			case <-timer.C:
+				// Continue to retry
+			case <-goCtx.Done():
+				timer.Stop()
+				return nil, fmt.Errorf("context cancelled during backoff: %w", goCtx.Err())
+			}
 		}

 		response, err := p.streamWithRetry(ctx, currentMessages, options, handler)
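
The timer-plus-select sleep added above is a standard Go idiom; extracted as a standalone, runnable sketch (the sleepCtx helper is hypothetical, not in the diff):

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// sleepCtx blocks for d or until ctx is done, whichever comes first:
// the same timer+select pattern the diff inlines into the retry loops
// of Stream and Post.
func sleepCtx(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop() // release the timer even on the early-exit path
	select {
	case <-timer.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	if err := sleepCtx(ctx, 2*time.Second); errors.Is(err, context.DeadlineExceeded) {
		fmt.Println("backoff interrupted:", err)
	}
}

Stopping the timer on the cancellation path releases its resources immediately instead of letting it linger until it fires, which is presumably why the diff uses time.NewTimer rather than time.After.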
@@ -188,6 +210,19 @@ func (p *Provider) streamWithRetry(ctx *context.Context, messages []context.Mess
 	streamStartTime := time.Now()
 	requestID := fmt.Sprintf("req_%d", streamStartTime.UnixNano())

+	// Get Go context for cancellation support
+	goCtx := ctx.Context
+	if goCtx == nil {
+		goCtx = gocontext.Background()
+	}
+
+	// Check if context is already cancelled
+	select {
+	case <-goCtx.Done():
+		return nil, fmt.Errorf("context cancelled before stream start: %w", goCtx.Err())
+	default:
+	}
+
 	// Send stream_start event
 	if handler != nil {
 		model, _ := p.GetModel()
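
Since ctx.Err() is non-nil exactly when the Done channel is closed, the non-blocking select above could equivalently be written as a plain error check; a sketch with the same semantics:

// Equivalent pre-flight check without select (sketch, not in the diff):
if err := goCtx.Err(); err != nil {
	return nil, fmt.Errorf("context cancelled before stream start: %w", err)
}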
@@ -256,6 +291,14 @@ func (p *Provider) streamWithRetry(ctx *context.Context, messages []context.Mess

 	// Stream handler
 	streamHandler := func(data []byte) int {
+		// Check for context cancellation
+		select {
+		case <-goCtx.Done():
+			log.Warn("Stream cancelled by context")
+			return http.HandlerReturnBreak
+		default:
+		}
+
 		if len(data) == 0 {
 			return http.HandlerReturnOk
 		}
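
The per-chunk check above is cooperative: it runs only when the HTTP layer delivers data, and the hard stop comes from passing goCtx to req.Stream later in this function. The same pattern in plain Go, as a compilable sketch with hypothetical names:

package sketch

import "context"

// processChunks drains ch until it closes or ctx is done. As with the
// streamHandler above, the cancellation check only fires when a chunk
// actually arrives, so the producer must also honour ctx for a hard stop.
func processChunks(ctx context.Context, ch <-chan []byte, handle func([]byte)) error {
	for data := range ch {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		handle(data)
	}
	return nil
}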
@@ -401,13 +444,30 @@ func (p *Provider) streamWithRetry(ctx *context.Context, messages []context.Mess
 		return http.HandlerReturnOk
 	}

-	// Make streaming request
-	goCtx := ctx.Context
-	if goCtx == nil {
-		goCtx = gocontext.Background()
+	// Make streaming request (goCtx already set at function start)
+	err = req.Stream(goCtx, "POST", requestBody, streamHandler)
+
+	// Check if error is due to context cancellation
+	if err != nil && goCtx.Err() != nil {
+		// End current group if active
+		groupTracker.endGroup(handler)
+
+		// Send stream_end with cancellation status
+		if handler != nil {
+			endData := &context.StreamEndData{
+				RequestID:  requestID,
+				Timestamp:  time.Now().UnixMilli(),
+				DurationMs: time.Since(streamStartTime).Milliseconds(),
+				Status:     "cancelled",
+				Error:      goCtx.Err().Error(),
+			}
+			if endJSON, err := jsoniter.Marshal(endData); err == nil {
+				handler(context.ChunkStreamEnd, endJSON)
+			}
+		}
+		return nil, fmt.Errorf("stream cancelled: %w", goCtx.Err())
 	}

-	err = req.Stream(goCtx, "POST", requestBody, streamHandler)
 	if err != nil {
 		// End current group if active
 		groupTracker.endGroup(handler)
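
Because the error is wrapped with %w, callers can separate caller cancellation from deadline expiry using errors.Is; a hypothetical caller-side sketch:

package sketch

import (
	"context"
	"errors"
)

// classifyStreamErr shows how the wrapped "stream cancelled: %w" error
// returned above can be inspected by callers (hypothetical helper).
func classifyStreamErr(err error) string {
	switch {
	case errors.Is(err, context.Canceled):
		return "cancelled by caller"
	case errors.Is(err, context.DeadlineExceeded):
		return "timed out"
	case err != nil:
		return "other failure"
	default:
		return "ok"
	}
}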
@@ -540,17 +600,39 @@ func (p *Provider) Post(ctx *context.Context, messages []context.Message, option
 	maxValidationRetries := 3
 	var lastErr error

+	// Get Go context for cancellation support
+	goCtx := ctx.Context
+	if goCtx == nil {
+		goCtx = gocontext.Background()
+	}
+
 	// Make a copy of messages to avoid modifying the original
 	currentMessages := make([]context.Message, len(messages))
 	copy(currentMessages, messages)

 	// Outer loop: handle network/API errors with exponential backoff
 	for attempt := 0; attempt < maxRetries; attempt++ {
+		// Check if context is cancelled before retry
+		select {
+		case <-goCtx.Done():
+			return nil, fmt.Errorf("context cancelled: %w", goCtx.Err())
+		default:
+		}
+
 		if attempt > 0 {
 			// Exponential backoff
 			backoff := time.Duration(1<<uint(attempt-1)) * time.Second
 			log.Warn("OpenAI post request failed, retrying in %v (attempt %d/%d): %v", backoff, attempt+1, maxRetries, lastErr)
-			time.Sleep(backoff)
+
+			// Sleep with context cancellation support
+			timer := time.NewTimer(backoff)
+			select {
+			case <-timer.C:
+				// Continue to retry
+			case <-goCtx.Done():
+				timer.Stop()
+				return nil, fmt.Errorf("context cancelled during backoff: %w", goCtx.Err())
+			}
 		}

 		response, err := p.postWithRetry(ctx, currentMessages, options)

agent/llm/providers/openai/openai_test.go

Lines changed: 85 additions & 2 deletions
@@ -1,10 +1,11 @@
 package openai_test

 import (
-	stdContext "context"
+	gocontext "context"
 	"encoding/json"
 	"strings"
 	"testing"
+	"time"

 	"github.com/yaoapp/gou/connector"
 	"github.com/yaoapp/gou/plan"
@@ -1388,6 +1389,88 @@ func TestOpenAIStreamLifecycleEvents(t *testing.T) {
 	t.Log("Lifecycle events test completed successfully")
 }

+// TestOpenAIStreamContextCancellation tests that stream respects context cancellation
+func TestOpenAIStreamContextCancellation(t *testing.T) {
+	test.Prepare(t, config.Conf)
+	defer test.Clean()
+
+	conn, err := connector.Select("openai.gpt-4o")
+	if err != nil {
+		t.Fatalf("Failed to select connector: %v", err)
+	}
+
+	trueVal := true
+	options := &context.CompletionOptions{
+		Capabilities: &context.ModelCapabilities{
+			Streaming: &trueVal,
+			ToolCalls: &trueVal,
+		},
+	}
+
+	llmInstance, err := llm.New(conn, options)
+	if err != nil {
+		t.Fatalf("Failed to create LLM instance: %v", err)
+	}
+
+	messages := []context.Message{
+		{
+			Role:    context.RoleUser,
+			Content: "Write a very long essay about the history of computing", // Long task
+		},
+	}
+
+	// Create a context with a very short timeout
+	ctx := newTestContext("test-cancel", "openai.gpt-4o")
+	goCtx, cancel := gocontext.WithTimeout(gocontext.Background(), 100*time.Millisecond)
+	defer cancel()
+	ctx.Context = goCtx
+
+	var receivedChunks int
+	var receivedStreamEnd bool
+
+	handler := func(chunkType context.StreamChunkType, data []byte) int {
+		if chunkType == context.ChunkText || chunkType == context.ChunkToolCall {
+			receivedChunks++
+		}
+		if chunkType == context.ChunkStreamEnd {
+			receivedStreamEnd = true
+			var endData context.StreamEndData
+			if err := json.Unmarshal(data, &endData); err == nil {
+				t.Logf("stream_end status: %s, error: %s", endData.Status, endData.Error)
+			}
+		}
+		return 0
+	}
+
+	response, err := llmInstance.Stream(ctx, messages, options, handler)
+
+	// Should get an error due to context cancellation
+	if err == nil {
+		t.Error("Expected error due to context cancellation, but got nil")
+	} else {
+		t.Logf("✓ Got expected cancellation error: %v", err)
+
+		// Check if error message indicates cancellation
+		errStr := err.Error()
+		if !strings.Contains(errStr, "context") && !strings.Contains(errStr, "cancel") {
+			t.Errorf("Error should mention context/cancellation: %v", err)
+		}
+	}
+
+	// Response should be nil due to cancellation
+	if response != nil {
+		t.Logf("Warning: Response is not nil despite cancellation (partial response)")
+	}
+
+	// Should have received stream_end event (even for cancellation)
+	if !receivedStreamEnd {
+		t.Error("Expected stream_end event even for cancelled stream")
+	}
+
+	t.Logf("Received %d chunks before cancellation", receivedChunks)
+	t.Log("Context cancellation test completed successfully")
+}
+
 // TestOpenAIStreamWithTemperature tests different temperature settings
 func TestOpenAIStreamWithTemperature(t *testing.T) {
 	test.Prepare(t, config.Conf)
@@ -1476,7 +1559,7 @@ func TestOpenAIStreamWithTemperature(t *testing.T) {
 // newTestContext creates a real Context for testing OpenAI provider
 func newTestContext(chatID, connectorID string) *context.Context {
 	return &context.Context{
-		Context:     stdContext.Background(),
+		Context:     gocontext.Background(),
 		Space:       plan.NewMemorySharedSpace(),
 		ChatID:      chatID,
 		AssistantID: "test-assistant",