Skip to content

Commit f4b756c

Browse files
committed
feat: agent callbacks for adk
1 parent 452b852 commit f4b756c

File tree

8 files changed

+1774
-247
lines changed

8 files changed

+1774
-247
lines changed

adk/agent_middleware.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
/*
2+
* Copyright 2025 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package adk
18+
19+
import (
20+
"context"
21+
22+
"github.com/cloudwego/eino/components/tool"
23+
"github.com/cloudwego/eino/compose"
24+
)
25+
26+
// AgentMiddleware provides hooks to customize agent behavior at various stages of execution.
27+
type AgentMiddleware struct {
28+
// Name of the middleware, default empty. This will be used for middleware deduplication.
29+
Name string
30+
31+
// AdditionalInstruction adds supplementary text to the agent's system instruction.
32+
// This instruction is concatenated with the base instruction before each chat model call.
33+
AdditionalInstruction string
34+
35+
// AdditionalTools adds supplementary tools to the agent's available toolset.
36+
// These tools are combined with the tools configured for the agent.
37+
AdditionalTools []tool.BaseTool
38+
39+
// BeforeChatModel is called before each ChatModel invocation, allowing modification of the agent state.
40+
BeforeChatModel func(context.Context, *ChatModelAgentState) error
41+
42+
// AfterChatModel is called after each ChatModel invocation, allowing modification of the agent state.
43+
AfterChatModel func(context.Context, *ChatModelAgentState) error
44+
45+
// WrapToolCall wraps tool calls with custom middleware logic.
46+
// Each middleware contains Invokable and/or Streamable functions for tool calls.
47+
WrapToolCall compose.ToolMiddleware
48+
49+
// BeforeAgent is called before the agent starts executing. It allows modifying the context
50+
// or performing any setup actions before the agent begins processing.
51+
// When an error is returned:
52+
// 1. The framework will immediately return an AsyncIterator containing only this error
53+
// 2. Subsequent BeforeAgent steps in other middlewares will be interrupted
54+
// 3. The OnEvents handlers in previously executed middlewares will be invoked
55+
BeforeAgent func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error)
56+
57+
// OnEvents is called to handle events generated by the agent during execution.
58+
// - iter: The iterator contains the original output from the agent or the processed output from the previous middlewares.
59+
// - gen: The generator is used to send the processed events to the next middleware or directly as output.
60+
// This allows for filtering, transforming, or adding events in the middleware chain.
61+
OnEvents func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent])
62+
}
63+
64+
// AgentMiddlewareChecker is an interface that agents can implement to indicate
65+
// whether they support and enable middleware functionality.
66+
// Agents implementing this interface will execute middlewares internally;
67+
// otherwise, middlewares will be executed outside the agent by Runner.
68+
type AgentMiddlewareChecker interface {
69+
IsAgentMiddlewareEnabled() bool
70+
}
71+
72+
// ChatModelAgentState represents the state of a chat model agent during conversation.
73+
type ChatModelAgentState struct {
74+
// Messages contains all messages in the current conversation session.
75+
Messages []Message
76+
}
77+
78+
type EntranceType string
79+
80+
const (
81+
// EntranceTypeRun indicates the agent is starting a new execution from scratch.
82+
EntranceTypeRun EntranceType = "Run"
83+
// EntranceTypeResume indicates the agent is resuming a previously interrupted execution.
84+
EntranceTypeResume EntranceType = "Resume"
85+
)
86+
87+
// AgentContext contains the context information for an agent's execution.
88+
// It provides access to input data, resume information, and execution options.
89+
type AgentContext struct {
90+
// AgentInput contains the input data for the agent's execution.
91+
AgentInput *AgentInput
92+
// ResumeInfo contains information needed to resume a previously interrupted execution.
93+
ResumeInfo *ResumeInfo
94+
// AgentRunOptions contains options for configuring the agent's execution.
95+
AgentRunOptions []AgentRunOption
96+
97+
// internal properties, read only
98+
agentName string
99+
entrance EntranceType
100+
}
101+
102+
func (a *AgentContext) AgentName() string {
103+
return a.agentName
104+
}
105+
106+
func (a *AgentContext) EntranceType() EntranceType {
107+
return a.entrance
108+
}
109+
110+
// GetGlobalAgentMiddlewares try to get agent middlewares set by RunnerMiddleware
111+
func GetGlobalAgentMiddlewares(ctx context.Context) []AgentMiddleware {
112+
if v, ok := ctx.Value(globalAgentMiddlewareCtxKey{}).([]AgentMiddleware); ok && v != nil {
113+
return v
114+
}
115+
return nil
116+
}
117+
118+
type globalAgentMiddlewareCtxKey struct{}
119+
120+
func isAgentMiddlewareEnabled(a Agent) bool {
121+
if c, ok := a.(AgentMiddlewareChecker); ok && c.IsAgentMiddlewareEnabled() {
122+
return true
123+
}
124+
return false
125+
}
126+
127+
type agentMWHelper struct {
128+
beforeAgentFns []func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error)
129+
onEventsFns []func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent])
130+
}
131+
132+
func (a *agentMWHelper) execBeforeAgents(ctx context.Context, ac *AgentContext) (context.Context, *AsyncIterator[*AgentEvent]) {
133+
var err error
134+
for i, beforeAgent := range a.beforeAgentFns {
135+
if beforeAgent == nil {
136+
continue
137+
}
138+
ctx, err = beforeAgent(ctx, ac)
139+
if err != nil {
140+
iter, gen := NewAsyncIteratorPair[*AgentEvent]()
141+
gen.Send(&AgentEvent{Err: err})
142+
gen.Close()
143+
return ctx, a.execOnEventsFromIndex(ctx, ac, i-1, iter)
144+
}
145+
}
146+
return ctx, nil
147+
}
148+
149+
func (a *agentMWHelper) execOnEvents(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent]) *AsyncIterator[*AgentEvent] {
150+
return a.execOnEventsFromIndex(ctx, ac, len(a.onEventsFns)-1, iter)
151+
}
152+
153+
func (a *agentMWHelper) execOnEventsFromIndex(ctx context.Context, ac *AgentContext, fromIdx int, iter *AsyncIterator[*AgentEvent]) *AsyncIterator[*AgentEvent] {
154+
for idx := fromIdx; idx >= 0; idx-- {
155+
onEvents := a.onEventsFns[idx]
156+
if onEvents == nil {
157+
continue
158+
}
159+
i, g := NewAsyncIteratorPair[*AgentEvent]()
160+
onEvents(ctx, ac, iter, g)
161+
iter = i
162+
}
163+
return iter
164+
}

0 commit comments

Comments
 (0)