Skip to content

Commit b5a282a

Browse files
authored
fix(adk): return modified message from AfterChatModel #717 (#792)
1 parent e685d86 commit b5a282a

File tree

3 files changed

+326
-71
lines changed

3 files changed

+326
-71
lines changed

adk/chatmodel_retry_test.go

Lines changed: 126 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -175,92 +175,148 @@ func (m *streamErrorModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallin
175175
}
176176

177177
func TestChatModelAgentRetry_StreamError(t *testing.T) {
178-
tests := []struct {
179-
name string
180-
withTool bool
181-
}{
182-
{"NoTools", false},
183-
{"WithTools", true},
184-
}
185-
186-
for _, tt := range tests {
187-
t.Run(tt.name, func(t *testing.T) {
188-
ctx := context.Background()
178+
t.Run("WithTools", func(t *testing.T) {
179+
ctx := context.Background()
189180

190-
m := &streamErrorModel{
191-
failAtChunk: 2,
192-
maxFailures: 2,
193-
returnTool: false,
194-
}
181+
m := &streamErrorModel{
182+
failAtChunk: 2,
183+
maxFailures: 2,
184+
returnTool: false,
185+
}
195186

196-
config := &ChatModelAgentConfig{
197-
Name: "RetryTestAgent",
198-
Description: "Test agent for retry functionality",
199-
Instruction: "You are a helpful assistant.",
200-
Model: m,
201-
ModelRetryConfig: &ModelRetryConfig{
202-
MaxRetries: 3,
203-
IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) },
187+
config := &ChatModelAgentConfig{
188+
Name: "RetryTestAgent",
189+
Description: "Test agent for retry functionality",
190+
Instruction: "You are a helpful assistant.",
191+
Model: m,
192+
ModelRetryConfig: &ModelRetryConfig{
193+
MaxRetries: 3,
194+
IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) },
195+
},
196+
ToolsConfig: ToolsConfig{
197+
ToolsNodeConfig: compose.ToolsNodeConfig{
198+
Tools: []tool.BaseTool{&fakeToolForTest{tarCount: 0}},
204199
},
200+
},
201+
}
202+
203+
agent, err := NewChatModelAgent(ctx, config)
204+
assert.NoError(t, err)
205+
206+
input := &AgentInput{
207+
Messages: []Message{schema.UserMessage("Hello")},
208+
EnableStreaming: true,
209+
}
210+
iterator := agent.Run(ctx, input)
211+
212+
var events []*AgentEvent
213+
for {
214+
event, ok := iterator.Next()
215+
if !ok {
216+
break
205217
}
218+
events = append(events, event)
219+
}
206220

207-
if tt.withTool {
208-
config.ToolsConfig = ToolsConfig{
209-
ToolsNodeConfig: compose.ToolsNodeConfig{
210-
Tools: []tool.BaseTool{&fakeToolForTest{tarCount: 0}},
211-
},
221+
assert.Equal(t, 3, len(events))
222+
223+
var streamErrEventCount int
224+
var errs []error
225+
for i, event := range events {
226+
if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming {
227+
sr := event.Output.MessageOutput.MessageStream
228+
for {
229+
msg, err := sr.Recv()
230+
if err == io.EOF {
231+
break
232+
}
233+
if err != nil {
234+
streamErrEventCount++
235+
errs = append(errs, err)
236+
t.Logf("event %d: err: %v", i, err)
237+
break
238+
}
239+
t.Logf("event %d: %v", i, msg.Content)
212240
}
213241
}
242+
}
214243

215-
agent, err := NewChatModelAgent(ctx, config)
216-
assert.NoError(t, err)
244+
assert.Equal(t, 2, streamErrEventCount)
245+
assert.Equal(t, 2, len(errs))
246+
var willRetryErr *WillRetryError
247+
assert.True(t, errors.As(errs[0], &willRetryErr))
248+
assert.True(t, errors.As(errs[1], &willRetryErr))
249+
assert.Equal(t, int32(3), atomic.LoadInt32(&m.callCount))
250+
})
217251

218-
input := &AgentInput{
219-
Messages: []Message{schema.UserMessage("Hello")},
220-
EnableStreaming: true,
221-
}
222-
iterator := agent.Run(ctx, input)
252+
t.Run("NoTools", func(t *testing.T) {
253+
ctx := context.Background()
223254

224-
var events []*AgentEvent
225-
for {
226-
event, ok := iterator.Next()
227-
if !ok {
228-
break
229-
}
230-
events = append(events, event)
255+
m := &streamErrorModel{
256+
failAtChunk: 2,
257+
maxFailures: 2,
258+
returnTool: false,
259+
}
260+
261+
config := &ChatModelAgentConfig{
262+
Name: "RetryTestAgent",
263+
Description: "Test agent for retry functionality",
264+
Instruction: "You are a helpful assistant.",
265+
Model: m,
266+
ModelRetryConfig: &ModelRetryConfig{
267+
MaxRetries: 3,
268+
IsRetryAble: func(ctx context.Context, err error) bool { return errors.Is(err, errRetryAble) },
269+
},
270+
}
271+
272+
agent, err := NewChatModelAgent(ctx, config)
273+
assert.NoError(t, err)
274+
275+
input := &AgentInput{
276+
Messages: []Message{schema.UserMessage("Hello")},
277+
EnableStreaming: true,
278+
}
279+
iterator := agent.Run(ctx, input)
280+
281+
var events []*AgentEvent
282+
for {
283+
event, ok := iterator.Next()
284+
if !ok {
285+
break
231286
}
287+
events = append(events, event)
288+
}
232289

233-
assert.Equal(t, 3, len(events))
234-
235-
var streamErrEventCount int
236-
var errs []error
237-
for i, event := range events {
238-
if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming {
239-
sr := event.Output.MessageOutput.MessageStream
240-
for {
241-
msg, err := sr.Recv()
242-
if err == io.EOF {
243-
break
244-
}
245-
if err != nil {
246-
streamErrEventCount++
247-
errs = append(errs, err)
248-
t.Logf("event %d: err: %v", i, err)
249-
break
250-
}
251-
t.Logf("event %d: %v", i, msg.Content)
290+
assert.Equal(t, 3, len(events))
291+
292+
var streamErrEventCount int
293+
var errs []error
294+
for i, event := range events {
295+
if event.Output != nil && event.Output.MessageOutput != nil && event.Output.MessageOutput.IsStreaming {
296+
sr := event.Output.MessageOutput.MessageStream
297+
for {
298+
msg, err := sr.Recv()
299+
if err == io.EOF {
300+
break
252301
}
302+
if err != nil {
303+
streamErrEventCount++
304+
errs = append(errs, err)
305+
t.Logf("event %d: err: %v", i, err)
306+
break
307+
}
308+
t.Logf("event %d: %v", i, msg.Content)
253309
}
254310
}
311+
}
255312

256-
assert.Equal(t, 2, streamErrEventCount)
257-
assert.Equal(t, 2, len(errs))
258-
var willRetryErr *WillRetryError
259-
assert.True(t, errors.As(errs[0], &willRetryErr))
260-
assert.True(t, errors.As(errs[1], &willRetryErr))
261-
assert.Equal(t, int32(3), atomic.LoadInt32(&m.callCount))
262-
})
263-
}
313+
assert.Equal(t, 2, streamErrEventCount)
314+
assert.Equal(t, 2, len(errs))
315+
var willRetryErr *WillRetryError
316+
assert.True(t, errors.As(errs[0], &willRetryErr))
317+
assert.True(t, errors.As(errs[1], &willRetryErr))
318+
assert.Equal(t, int32(3), atomic.LoadInt32(&m.callCount))
319+
})
264320
}
265321

266322
func TestChatModelAgentRetry_WithTools_DirectError_Generate(t *testing.T) {

0 commit comments

Comments
 (0)