@@ -121,17 +121,39 @@ func (p *Provider) Stream(ctx *context.Context, messages []context.Message, opti
121121 maxValidationRetries := 3
122122 var lastErr error
123123
124+ // Get Go context for cancellation support
125+ goCtx := ctx .Context
126+ if goCtx == nil {
127+ goCtx = gocontext .Background ()
128+ }
129+
124130 // Make a copy of messages to avoid modifying the original
125131 currentMessages := make ([]context.Message , len (messages ))
126132 copy (currentMessages , messages )
127133
128134 // Outer loop: handle network/API errors with exponential backoff
129135 for attempt := 0 ; attempt < maxRetries ; attempt ++ {
136+ // Check if context is cancelled before retry
137+ select {
138+ case <- goCtx .Done ():
139+ return nil , fmt .Errorf ("context cancelled: %w" , goCtx .Err ())
140+ default :
141+ }
142+
130143 if attempt > 0 {
131144 // Exponential backoff: 1s, 2s, 4s
132145 backoff := time .Duration (1 << uint (attempt - 1 )) * time .Second
133146 log .Warn ("OpenAI stream request failed, retrying in %v (attempt %d/%d): %v" , backoff , attempt + 1 , maxRetries , lastErr )
134- time .Sleep (backoff )
147+
148+ // Sleep with context cancellation support
149+ timer := time .NewTimer (backoff )
150+ select {
151+ case <- timer .C :
152+ // Continue to retry
153+ case <- goCtx .Done ():
154+ timer .Stop ()
155+ return nil , fmt .Errorf ("context cancelled during backoff: %w" , goCtx .Err ())
156+ }
135157 }
136158
137159 response , err := p .streamWithRetry (ctx , currentMessages , options , handler )
@@ -188,6 +210,19 @@ func (p *Provider) streamWithRetry(ctx *context.Context, messages []context.Mess
188210 streamStartTime := time .Now ()
189211 requestID := fmt .Sprintf ("req_%d" , streamStartTime .UnixNano ())
190212
213+ // Get Go context for cancellation support
214+ goCtx := ctx .Context
215+ if goCtx == nil {
216+ goCtx = gocontext .Background ()
217+ }
218+
219+ // Check if context is already cancelled
220+ select {
221+ case <- goCtx .Done ():
222+ return nil , fmt .Errorf ("context cancelled before stream start: %w" , goCtx .Err ())
223+ default :
224+ }
225+
191226 // Send stream_start event
192227 if handler != nil {
193228 model , _ := p .GetModel ()
@@ -256,6 +291,14 @@ func (p *Provider) streamWithRetry(ctx *context.Context, messages []context.Mess
256291
257292 // Stream handler
258293 streamHandler := func (data []byte ) int {
294+ // Check for context cancellation
295+ select {
296+ case <- goCtx .Done ():
297+ log .Warn ("Stream cancelled by context" )
298+ return http .HandlerReturnBreak
299+ default :
300+ }
301+
259302 if len (data ) == 0 {
260303 return http .HandlerReturnOk
261304 }
@@ -401,13 +444,30 @@ func (p *Provider) streamWithRetry(ctx *context.Context, messages []context.Mess
401444 return http .HandlerReturnOk
402445 }
403446
404- // Make streaming request
405- goCtx := ctx .Context
406- if goCtx == nil {
407- goCtx = gocontext .Background ()
447+ // Make streaming request (goCtx already set at function start)
448+ err = req .Stream (goCtx , "POST" , requestBody , streamHandler )
449+
450+ // Check if error is due to context cancellation
451+ if err != nil && goCtx .Err () != nil {
452+ // End current group if active
453+ groupTracker .endGroup (handler )
454+
455+ // Send stream_end with cancellation status
456+ if handler != nil {
457+ endData := & context.StreamEndData {
458+ RequestID : requestID ,
459+ Timestamp : time .Now ().UnixMilli (),
460+ DurationMs : time .Since (streamStartTime ).Milliseconds (),
461+ Status : "cancelled" ,
462+ Error : goCtx .Err ().Error (),
463+ }
464+ if endJSON , err := jsoniter .Marshal (endData ); err == nil {
465+ handler (context .ChunkStreamEnd , endJSON )
466+ }
467+ }
468+ return nil , fmt .Errorf ("stream cancelled: %w" , goCtx .Err ())
408469 }
409470
410- err = req .Stream (goCtx , "POST" , requestBody , streamHandler )
411471 if err != nil {
412472 // End current group if active
413473 groupTracker .endGroup (handler )
@@ -540,17 +600,39 @@ func (p *Provider) Post(ctx *context.Context, messages []context.Message, option
540600 maxValidationRetries := 3
541601 var lastErr error
542602
603+ // Get Go context for cancellation support
604+ goCtx := ctx .Context
605+ if goCtx == nil {
606+ goCtx = gocontext .Background ()
607+ }
608+
543609 // Make a copy of messages to avoid modifying the original
544610 currentMessages := make ([]context.Message , len (messages ))
545611 copy (currentMessages , messages )
546612
547613 // Outer loop: handle network/API errors with exponential backoff
548614 for attempt := 0 ; attempt < maxRetries ; attempt ++ {
615+ // Check if context is cancelled before retry
616+ select {
617+ case <- goCtx .Done ():
618+ return nil , fmt .Errorf ("context cancelled: %w" , goCtx .Err ())
619+ default :
620+ }
621+
549622 if attempt > 0 {
550623 // Exponential backoff
551624 backoff := time .Duration (1 << uint (attempt - 1 )) * time .Second
552625 log .Warn ("OpenAI post request failed, retrying in %v (attempt %d/%d): %v" , backoff , attempt + 1 , maxRetries , lastErr )
553- time .Sleep (backoff )
626+
627+ // Sleep with context cancellation support
628+ timer := time .NewTimer (backoff )
629+ select {
630+ case <- timer .C :
631+ // Continue to retry
632+ case <- goCtx .Done ():
633+ timer .Stop ()
634+ return nil , fmt .Errorf ("context cancelled during backoff: %w" , goCtx .Err ())
635+ }
554636 }
555637
556638 response , err := p .postWithRetry (ctx , currentMessages , options )
0 commit comments