Skip to content

Commit 4a4f0c1

Browse files
milestone: 0.7
Change-Id: I3e1e2969b7aa08b53df6b0acb1ef7eda39a6d70b
1 parent bc53f35 commit 4a4f0c1

28 files changed

+7480
-1684
lines changed

adk/agent_tool.go

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -91,40 +91,19 @@ func (at *agentTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
9191
}
9292

9393
func (at *agentTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
94-
var intData *agentToolInterruptInfo
95-
var bResume bool
96-
err := compose.ProcessState(ctx, func(ctx context.Context, s *State) error {
97-
toolCallID := compose.GetToolCallID(ctx)
98-
intData, bResume = s.AgentToolInterruptData[toolCallID]
99-
if bResume {
100-
delete(s.AgentToolInterruptData, toolCallID)
101-
}
102-
return nil
103-
})
104-
if err != nil {
105-
// cannot resume
106-
bResume = false
107-
}
108-
10994
var ms *mockStore
11095
var iter *AsyncIterator[*AgentEvent]
111-
if bResume {
112-
ms = newResumeStore(intData.Data)
96+
var err error
11397

114-
iter, err = newInvokableAgentToolRunner(at.agent, ms).Resume(ctx, mockCheckPointID, getOptionsByAgentName(at.agent.Name(ctx), opts)...)
115-
if err != nil {
116-
return "", err
117-
}
118-
} else {
98+
wasInterrupted, hasState, state := compose.GetInterruptState[[]byte](ctx)
99+
if !wasInterrupted {
119100
ms = newEmptyStore()
120101
var input []Message
121102
if at.fullChatHistoryAsInput {
122-
history, err := getReactChatHistory(ctx, at.agent.Name(ctx))
103+
input, err = getReactChatHistory(ctx, at.agent.Name(ctx))
123104
if err != nil {
124105
return "", err
125106
}
126-
127-
input = history
128107
} else {
129108
if at.inputSchema == nil {
130109
// default input schema
@@ -144,7 +123,20 @@ func (at *agentTool) InvokableRun(ctx context.Context, argumentsInJSON string, o
144123
}
145124
}
146125

147-
iter = newInvokableAgentToolRunner(at.agent, ms).Run(ctx, input, append(getOptionsByAgentName(at.agent.Name(ctx), opts), WithCheckPointID(mockCheckPointID))...)
126+
iter = newInvokableAgentToolRunner(at.agent, ms).Run(ctx, input,
127+
append(getOptionsByAgentName(at.agent.Name(ctx), opts), WithCheckPointID(mockCheckPointID))...)
128+
} else {
129+
if !hasState {
130+
return "", fmt.Errorf("agent tool '%s' interrupt has happened, but cannot find interrupt state", at.agent.Name(ctx))
131+
}
132+
133+
ms = newResumeStore(state)
134+
135+
iter, err = newInvokableAgentToolRunner(at.agent, ms).
136+
Resume(ctx, mockCheckPointID, getOptionsByAgentName(at.agent.Name(ctx), opts)...)
137+
if err != nil {
138+
return "", err
139+
}
148140
}
149141

150142
var lastEvent *AgentEvent
@@ -169,17 +161,9 @@ func (at *agentTool) InvokableRun(ctx context.Context, argumentsInJSON string, o
169161
if !existed {
170162
return "", fmt.Errorf("interrupt has happened, but cannot find interrupt info")
171163
}
172-
err = compose.ProcessState(ctx, func(ctx context.Context, st *State) error {
173-
st.AgentToolInterruptData[compose.GetToolCallID(ctx)] = &agentToolInterruptInfo{
174-
LastEvent: lastEvent,
175-
Data: data,
176-
}
177-
return nil
178-
})
179-
if err != nil {
180-
return "", fmt.Errorf("failed to save agent tool checkpoint to state: %w", err)
181-
}
182-
return "", compose.InterruptAndRerun
164+
165+
return "", compose.CompositeInterrupt(ctx, "agent tool interrupt", data,
166+
lastEvent.Action.internalInterrupted)
183167
}
184168

185169
if lastEvent == nil {

adk/agent_tool_test.go

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,278 @@ func TestGetReactHistory(t *testing.T) {
217217
schema.UserMessage("For context: [MyAgent] `transfer_to_agent` tool returned result: successfully transferred to agent [DestAgentName]."),
218218
}, result)
219219
}
220+
221+
// mockAgentWithInputCapture implements the Agent interface for testing and captures the input it receives
222+
type mockAgentWithInputCapture struct {
223+
name string
224+
description string
225+
capturedInput []Message
226+
responses []*AgentEvent
227+
}
228+
229+
func (a *mockAgentWithInputCapture) Name(_ context.Context) string {
230+
return a.name
231+
}
232+
233+
func (a *mockAgentWithInputCapture) Description(_ context.Context) string {
234+
return a.description
235+
}
236+
237+
func (a *mockAgentWithInputCapture) Run(_ context.Context, input *AgentInput, _ ...AgentRunOption) *AsyncIterator[*AgentEvent] {
238+
a.capturedInput = input.Messages
239+
240+
iterator, generator := NewAsyncIteratorPair[*AgentEvent]()
241+
242+
go func() {
243+
defer generator.Close()
244+
245+
for _, event := range a.responses {
246+
generator.Send(event)
247+
248+
// If the event has an Exit action, stop sending events
249+
if event.Action != nil && event.Action.Exit {
250+
break
251+
}
252+
}
253+
}()
254+
255+
return iterator
256+
}
257+
258+
func newMockAgentWithInputCapture(name, description string, responses []*AgentEvent) *mockAgentWithInputCapture {
259+
return &mockAgentWithInputCapture{
260+
name: name,
261+
description: description,
262+
responses: responses,
263+
}
264+
}
265+
266+
func TestAgentToolWithOptions(t *testing.T) {
267+
// Test Case 1: WithFullChatHistoryAsInput
268+
t.Run("WithFullChatHistoryAsInput", func(t *testing.T) {
269+
ctx := context.Background()
270+
271+
// 1. Set up a mock agent that will capture the input it receives
272+
mockAgent := newMockAgentWithInputCapture("test-agent", "a test agent", []*AgentEvent{
273+
{
274+
AgentName: "test-agent",
275+
Output: &AgentOutput{
276+
MessageOutput: &MessageVariant{
277+
IsStreaming: false,
278+
Message: schema.AssistantMessage("done", nil),
279+
Role: schema.Assistant,
280+
},
281+
},
282+
},
283+
})
284+
285+
// 2. Create an agentTool with the option
286+
agentTool := NewAgentTool(ctx, mockAgent, WithFullChatHistoryAsInput())
287+
288+
// 3. Set up a context with a chat history using a graph
289+
history := []Message{
290+
schema.UserMessage("first user message"),
291+
schema.AssistantMessage("first assistant response", nil),
292+
}
293+
294+
g := compose.NewGraph[string, string](compose.WithGenLocalState(func(ctx context.Context) (state *State) {
295+
return &State{
296+
AgentName: "react-agent",
297+
Messages: append(history, schema.AssistantMessage("tool call msg", nil)),
298+
}
299+
}))
300+
301+
assert.NoError(t, g.AddLambdaNode("1", compose.InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
302+
// Run the tool within the graph context that has the state
303+
_, err = agentTool.(tool.InvokableTool).InvokableRun(ctx, `{"request":"some ignored input"}`)
304+
return "done", err
305+
})))
306+
assert.NoError(t, g.AddEdge(compose.START, "1"))
307+
assert.NoError(t, g.AddEdge("1", compose.END))
308+
309+
runner, err := g.Compile(ctx)
310+
assert.NoError(t, err)
311+
312+
// 4. Run the graph which will execute the tool with the state
313+
_, err = runner.Invoke(ctx, "")
314+
assert.NoError(t, err)
315+
316+
// 5. Assert that the agent received the full history
317+
// The agent should receive: history (minus last assistant message) + transfer messages
318+
assert.Len(t, mockAgent.capturedInput, 4) // 2 from history + 2 transfer messages
319+
assert.Equal(t, "first user message", mockAgent.capturedInput[0].Content)
320+
assert.Equal(t, "For context: [react-agent] said: first assistant response.", mockAgent.capturedInput[1].Content)
321+
assert.Equal(t, "For context: [react-agent] called tool: `transfer_to_agent` with arguments: test-agent.", mockAgent.capturedInput[2].Content)
322+
assert.Equal(t, "For context: [react-agent] `transfer_to_agent` tool returned result: successfully transferred to agent [test-agent].", mockAgent.capturedInput[3].Content)
323+
})
324+
325+
// Test Case 2: WithAgentInputSchema
326+
t.Run("WithAgentInputSchema", func(t *testing.T) {
327+
ctx := context.Background()
328+
329+
// 1. Define a custom schema
330+
customSchema := schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
331+
"custom_arg": {
332+
Desc: "a custom argument",
333+
Required: true,
334+
Type: schema.String,
335+
},
336+
})
337+
338+
// 2. Set up a mock agent to capture input
339+
mockAgent := newMockAgentWithInputCapture("schema-agent", "agent with custom schema", []*AgentEvent{
340+
{
341+
AgentName: "schema-agent",
342+
Output: &AgentOutput{
343+
MessageOutput: &MessageVariant{
344+
IsStreaming: false,
345+
Message: schema.AssistantMessage("schema processed", nil),
346+
Role: schema.Assistant,
347+
},
348+
},
349+
},
350+
})
351+
352+
// 3. Create agentTool with the custom schema option
353+
agentTool := NewAgentTool(ctx, mockAgent, WithAgentInputSchema(customSchema))
354+
355+
// 4. Verify the Info() method returns the custom schema
356+
info, err := agentTool.Info(ctx)
357+
assert.NoError(t, err)
358+
assert.Equal(t, customSchema, info.ParamsOneOf)
359+
360+
// 5. Run the tool with arguments matching the custom schema
361+
_, err = agentTool.(tool.InvokableTool).InvokableRun(ctx, `{"custom_arg":"hello world"}`)
362+
assert.NoError(t, err)
363+
364+
// 6. Assert that the agent received the correctly parsed argument
365+
// With custom schema, the agent should receive the raw JSON as input
366+
assert.Len(t, mockAgent.capturedInput, 1)
367+
assert.Equal(t, `{"custom_arg":"hello world"}`, mockAgent.capturedInput[0].Content)
368+
})
369+
370+
// Test Case 3: WithAgentInputSchema with complex schema
371+
t.Run("WithAgentInputSchema_ComplexSchema", func(t *testing.T) {
372+
ctx := context.Background()
373+
374+
// 1. Define a complex custom schema with multiple parameters
375+
complexSchema := schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
376+
"name": {
377+
Desc: "user name",
378+
Required: true,
379+
Type: schema.String,
380+
},
381+
"age": {
382+
Desc: "user age",
383+
Required: false,
384+
Type: schema.Integer,
385+
},
386+
"active": {
387+
Desc: "user status",
388+
Required: false,
389+
Type: schema.Boolean,
390+
},
391+
})
392+
393+
// 2. Set up a mock agent
394+
mockAgent := newMockAgentWithInputCapture("complex-agent", "agent with complex schema", []*AgentEvent{
395+
{
396+
AgentName: "complex-agent",
397+
Output: &AgentOutput{
398+
MessageOutput: &MessageVariant{
399+
IsStreaming: false,
400+
Message: schema.AssistantMessage("complex processed", nil),
401+
Role: schema.Assistant,
402+
},
403+
},
404+
},
405+
})
406+
407+
// 3. Create agentTool with the complex schema option
408+
agentTool := NewAgentTool(ctx, mockAgent, WithAgentInputSchema(complexSchema))
409+
410+
// 4. Verify the Info() method returns the complex schema
411+
info, err := agentTool.Info(ctx)
412+
assert.NoError(t, err)
413+
assert.Equal(t, complexSchema, info.ParamsOneOf)
414+
415+
// 5. Run the tool with complex arguments
416+
_, err = agentTool.(tool.InvokableTool).InvokableRun(ctx, `{"name":"John","age":30,"active":true}`)
417+
assert.NoError(t, err)
418+
419+
// 6. Assert that the agent received the complex JSON
420+
assert.Len(t, mockAgent.capturedInput, 1)
421+
assert.Equal(t, `{"name":"John","age":30,"active":true}`, mockAgent.capturedInput[0].Content)
422+
})
423+
424+
// Test Case 4: Both options together
425+
t.Run("BothOptionsTogether", func(t *testing.T) {
426+
ctx := context.Background()
427+
428+
// 1. Define a custom schema
429+
customSchema := schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
430+
"query": {
431+
Desc: "search query",
432+
Required: true,
433+
Type: schema.String,
434+
},
435+
})
436+
437+
// 2. Set up a mock agent
438+
mockAgent := newMockAgentWithInputCapture("combined-agent", "agent with both options", []*AgentEvent{
439+
{
440+
AgentName: "combined-agent",
441+
Output: &AgentOutput{
442+
MessageOutput: &MessageVariant{
443+
IsStreaming: false,
444+
Message: schema.AssistantMessage("combined processed", nil),
445+
Role: schema.Assistant,
446+
},
447+
},
448+
},
449+
})
450+
451+
// 3. Create agentTool with both options
452+
agentTool := NewAgentTool(ctx, mockAgent, WithAgentInputSchema(customSchema), WithFullChatHistoryAsInput())
453+
454+
// 4. Set up a context with chat history using a graph
455+
history := []Message{
456+
schema.UserMessage("previous conversation"),
457+
schema.AssistantMessage("previous response", nil),
458+
}
459+
460+
g := compose.NewGraph[string, string](compose.WithGenLocalState(func(ctx context.Context) (state *State) {
461+
return &State{
462+
AgentName: "react-agent",
463+
Messages: append(history, schema.AssistantMessage("tool call", nil)),
464+
}
465+
}))
466+
467+
assert.NoError(t, g.AddLambdaNode("1", compose.InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
468+
// Run the tool within the graph context that has the state
469+
_, err = agentTool.(tool.InvokableTool).InvokableRun(ctx, `{"query":"current query"}`)
470+
return "done", err
471+
})))
472+
assert.NoError(t, g.AddEdge(compose.START, "1"))
473+
assert.NoError(t, g.AddEdge("1", compose.END))
474+
475+
runner, err := g.Compile(ctx)
476+
assert.NoError(t, err)
477+
478+
// 5. Run the graph which will execute the tool with the state
479+
_, err = runner.Invoke(ctx, "")
480+
assert.NoError(t, err)
481+
482+
// 6. Verify both options work together
483+
info, err := agentTool.Info(ctx)
484+
assert.NoError(t, err)
485+
assert.Equal(t, customSchema, info.ParamsOneOf)
486+
487+
// The agent should receive full history + the custom query
488+
assert.Len(t, mockAgent.capturedInput, 4) // 2 history + 2 transfer messages
489+
assert.Equal(t, "previous conversation", mockAgent.capturedInput[0].Content)
490+
assert.Equal(t, "For context: [react-agent] said: previous response.", mockAgent.capturedInput[1].Content)
491+
assert.Equal(t, "For context: [react-agent] called tool: `transfer_to_agent` with arguments: combined-agent.", mockAgent.capturedInput[2].Content)
492+
assert.Equal(t, "For context: [react-agent] `transfer_to_agent` tool returned result: successfully transferred to agent [combined-agent].", mockAgent.capturedInput[3].Content)
493+
})
494+
}

0 commit comments

Comments
 (0)