|
| 1 | +/* |
| 2 | + * Copyright 2025 CloudWeGo Authors |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +package adk |
| 18 | + |
| 19 | +import ( |
| 20 | + "context" |
| 21 | + |
| 22 | + "github.com/cloudwego/eino/components/tool" |
| 23 | + "github.com/cloudwego/eino/compose" |
| 24 | +) |
| 25 | + |
| 26 | +// AgentMiddleware provides hooks to customize agent behavior at various stages of execution. |
| 27 | +type AgentMiddleware struct { |
| 28 | + // Name of the middleware, default empty. This will be used for middleware deduplication. |
| 29 | + Name string |
| 30 | + |
| 31 | + // AdditionalInstruction adds supplementary text to the agent's system instruction. |
| 32 | + // This instruction is concatenated with the base instruction before each chat model call. |
| 33 | + AdditionalInstruction string |
| 34 | + |
| 35 | + // AdditionalTools adds supplementary tools to the agent's available toolset. |
| 36 | + // These tools are combined with the tools configured for the agent. |
| 37 | + AdditionalTools []tool.BaseTool |
| 38 | + |
| 39 | + // BeforeChatModel is called before each ChatModel invocation, allowing modification of the agent state. |
| 40 | + BeforeChatModel func(context.Context, *ChatModelAgentState) error |
| 41 | + |
| 42 | + // AfterChatModel is called after each ChatModel invocation, allowing modification of the agent state. |
| 43 | + AfterChatModel func(context.Context, *ChatModelAgentState) error |
| 44 | + |
| 45 | + // WrapToolCall wraps tool calls with custom middleware logic. |
| 46 | + // Each middleware contains Invokable and/or Streamable functions for tool calls. |
| 47 | + WrapToolCall compose.ToolMiddleware |
| 48 | + |
| 49 | + // BeforeAgent is called before the agent starts executing. It allows modifying the context |
| 50 | + // or performing any setup actions before the agent begins processing. |
| 51 | + // When an error is returned: |
| 52 | + // 1. The framework will immediately return an AsyncIterator containing only this error |
| 53 | + // 2. Subsequent BeforeAgent steps in other middlewares will be interrupted |
| 54 | + // 3. The OnEvents handlers in previously executed middlewares will be invoked |
| 55 | + BeforeAgent func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) |
| 56 | + |
| 57 | + // OnEvents is called to handle events generated by the agent during execution. |
| 58 | + // - iter: The iterator contains the original output from the agent or the processed output from the previous middlewares. |
| 59 | + // - gen: The generator is used to send the processed events to the next middleware or directly as output. |
| 60 | + // This allows for filtering, transforming, or adding events in the middleware chain. |
| 61 | + OnEvents func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) |
| 62 | +} |
| 63 | + |
| 64 | +// AgentMiddlewareChecker is an interface that agents can implement to indicate |
| 65 | +// whether they support and enable middleware functionality. |
| 66 | +// Agents implementing this interface will execute middlewares internally; |
| 67 | +// otherwise, middlewares will be executed outside the agent by Runner. |
| 68 | +type AgentMiddlewareChecker interface { |
| 69 | + IsAgentMiddlewareEnabled() bool |
| 70 | +} |
| 71 | + |
| 72 | +// ChatModelAgentState represents the state of a chat model agent during conversation. |
| 73 | +type ChatModelAgentState struct { |
| 74 | + // Messages contains all messages in the current conversation session. |
| 75 | + Messages []Message |
| 76 | +} |
| 77 | + |
| 78 | +type EntranceType string |
| 79 | + |
| 80 | +const ( |
| 81 | + // EntranceTypeRun indicates the agent is starting a new execution from scratch. |
| 82 | + EntranceTypeRun EntranceType = "Run" |
| 83 | + // EntranceTypeResume indicates the agent is resuming a previously interrupted execution. |
| 84 | + EntranceTypeResume EntranceType = "Resume" |
| 85 | +) |
| 86 | + |
| 87 | +// AgentContext contains the context information for an agent's execution. |
| 88 | +// It provides access to input data, resume information, and execution options. |
| 89 | +type AgentContext struct { |
| 90 | + // AgentInput contains the input data for the agent's execution. |
| 91 | + AgentInput *AgentInput |
| 92 | + // ResumeInfo contains information needed to resume a previously interrupted execution. |
| 93 | + ResumeInfo *ResumeInfo |
| 94 | + // AgentRunOptions contains options for configuring the agent's execution. |
| 95 | + AgentRunOptions []AgentRunOption |
| 96 | + |
| 97 | + // internal properties, read only |
| 98 | + agentName string |
| 99 | + entrance EntranceType |
| 100 | +} |
| 101 | + |
| 102 | +func (a *AgentContext) AgentName() string { |
| 103 | + return a.agentName |
| 104 | +} |
| 105 | + |
| 106 | +func (a *AgentContext) EntranceType() EntranceType { |
| 107 | + return a.entrance |
| 108 | +} |
| 109 | + |
| 110 | +// GetGlobalAgentMiddlewares try to get agent middlewares set by RunnerMiddleware |
| 111 | +func GetGlobalAgentMiddlewares(ctx context.Context) []AgentMiddleware { |
| 112 | + if v, ok := ctx.Value(globalAgentMiddlewareCtxKey{}).([]AgentMiddleware); ok && v != nil { |
| 113 | + return v |
| 114 | + } |
| 115 | + return nil |
| 116 | +} |
| 117 | + |
| 118 | +type globalAgentMiddlewareCtxKey struct{} |
| 119 | + |
| 120 | +func isAgentMiddlewareEnabled(a Agent) bool { |
| 121 | + if c, ok := a.(AgentMiddlewareChecker); ok && c.IsAgentMiddlewareEnabled() { |
| 122 | + return true |
| 123 | + } |
| 124 | + return false |
| 125 | +} |
| 126 | + |
| 127 | +type agentMWHelper struct { |
| 128 | + beforeAgentFns []func(ctx context.Context, ac *AgentContext) (nextContext context.Context, err error) |
| 129 | + onEventsFns []func(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) |
| 130 | +} |
| 131 | + |
| 132 | +func (a *agentMWHelper) execBeforeAgents(ctx context.Context, ac *AgentContext) (context.Context, *AsyncIterator[*AgentEvent]) { |
| 133 | + var err error |
| 134 | + for i, beforeAgent := range a.beforeAgentFns { |
| 135 | + if beforeAgent == nil { |
| 136 | + continue |
| 137 | + } |
| 138 | + ctx, err = beforeAgent(ctx, ac) |
| 139 | + if err != nil { |
| 140 | + iter, gen := NewAsyncIteratorPair[*AgentEvent]() |
| 141 | + gen.Send(&AgentEvent{Err: err}) |
| 142 | + gen.Close() |
| 143 | + return ctx, a.execOnEventsFromIndex(ctx, ac, i-1, iter) |
| 144 | + } |
| 145 | + } |
| 146 | + return ctx, nil |
| 147 | +} |
| 148 | + |
| 149 | +func (a *agentMWHelper) execOnEvents(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent]) *AsyncIterator[*AgentEvent] { |
| 150 | + return a.execOnEventsFromIndex(ctx, ac, len(a.onEventsFns)-1, iter) |
| 151 | +} |
| 152 | + |
| 153 | +func (a *agentMWHelper) execOnEventsFromIndex(ctx context.Context, ac *AgentContext, fromIdx int, iter *AsyncIterator[*AgentEvent]) *AsyncIterator[*AgentEvent] { |
| 154 | + for idx := fromIdx; idx >= 0; idx-- { |
| 155 | + onEvents := a.onEventsFns[idx] |
| 156 | + if onEvents == nil { |
| 157 | + continue |
| 158 | + } |
| 159 | + i, g := NewAsyncIteratorPair[*AgentEvent]() |
| 160 | + onEvents(ctx, ac, iter, g) |
| 161 | + iter = i |
| 162 | + } |
| 163 | + return iter |
| 164 | +} |
0 commit comments