Skip to content

Commit e685d86

Browse files
authored
feat(adk): Add SendEvent API for emitting custom events from middleware (#791)
1 parent 65f8955 commit e685d86

File tree

3 files changed

+314
-0
lines changed

3 files changed

+314
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,4 @@ CLAUDE.md
4848

4949
# Specs directories
5050
*/specs
51+
/todos

adk/chatmodel.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,44 @@ import (
4141
ub "github.com/cloudwego/eino/utils/callbacks"
4242
)
4343

44+
// SendEvent sends a custom AgentEvent to the event stream during agent execution.
45+
// This allows ChatModelAgentMiddleware implementations to emit custom events that will be
46+
// received by the caller iterating over the agent's event stream.
47+
//
48+
// This function can only be called from within a ChatModelAgentMiddleware during agent execution.
49+
// Returns an error if called outside of an agent execution context.
50+
func SendEvent(ctx context.Context, event *AgentEvent) error {
51+
execCtx := getChatModelAgentExecCtx(ctx)
52+
if execCtx == nil || execCtx.generator == nil {
53+
return fmt.Errorf("SendEvent failed: must be called within a ChatModelAgent Run() or Resume() execution context")
54+
}
55+
execCtx.generator.Send(event)
56+
return nil
57+
}
58+
59+
type chatModelAgentExecCtx struct {
60+
generator *AsyncGenerator[*AgentEvent]
61+
}
62+
63+
func (e *chatModelAgentExecCtx) send(event *AgentEvent) {
64+
if e != nil && e.generator != nil {
65+
e.generator.Send(event)
66+
}
67+
}
68+
69+
type chatModelAgentExecCtxKey struct{}
70+
71+
func withChatModelAgentExecCtx(ctx context.Context, execCtx *chatModelAgentExecCtx) context.Context {
72+
return context.WithValue(ctx, chatModelAgentExecCtxKey{}, execCtx)
73+
}
74+
75+
func getChatModelAgentExecCtx(ctx context.Context) *chatModelAgentExecCtx {
76+
if v := ctx.Value(chatModelAgentExecCtxKey{}); v != nil {
77+
return v.(*chatModelAgentExecCtx)
78+
}
79+
return nil
80+
}
81+
4482
const (
4583
addrDepthChain = 1
4684
addrDepthReactGraph = 2
@@ -799,6 +837,10 @@ func (a *ChatModelAgent) buildRunFunc(ctx context.Context) runFunc {
799837
runOpts = append(runOpts, opts...)
800838
runOpts = append(runOpts, callOpt)
801839

840+
ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{
841+
generator: generator,
842+
})
843+
802844
var msg Message
803845
var msgStream MessageStream
804846
if input.EnableStreaming {
@@ -876,6 +918,10 @@ func (a *ChatModelAgent) buildRunFunc(ctx context.Context) runFunc {
876918
runOpts = append(runOpts, compose.WithToolsNodeOption(compose.WithToolOption(withAgentToolEnableStreaming(true))))
877919
}
878920

921+
ctx = withChatModelAgentExecCtx(ctx, &chatModelAgentExecCtx{
922+
generator: generator,
923+
})
924+
879925
var msg Message
880926
var msgStream MessageStream
881927
if input.EnableStreaming {

adk/chatmodel_test.go

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,3 +1183,270 @@ func (s *simpleToolForMiddlewareTest) InvokableRun(_ context.Context, _ string,
11831183
func (s *simpleToolForMiddlewareTest) StreamableRun(_ context.Context, _ string, _ ...tool.Option) (*schema.StreamReader[string], error) {
11841184
return schema.StreamReaderFromArray([]string{s.result}), nil
11851185
}
1186+
1187+
func TestSendEvent(t *testing.T) {
1188+
t.Run("SendEventWithoutExecCtx", func(t *testing.T) {
1189+
ctx := context.Background()
1190+
event := &AgentEvent{
1191+
Output: &AgentOutput{
1192+
MessageOutput: &MessageVariant{
1193+
Message: schema.AssistantMessage("custom event", nil),
1194+
},
1195+
},
1196+
}
1197+
err := SendEvent(ctx, event)
1198+
assert.Error(t, err)
1199+
assert.Contains(t, err.Error(), "SendEvent failed: must be called within a ChatModelAgent Run() or Resume() execution context")
1200+
})
1201+
1202+
t.Run("SendEventWithNilGenerator", func(t *testing.T) {
1203+
ctx := context.Background()
1204+
execCtx := &chatModelAgentExecCtx{
1205+
generator: nil,
1206+
}
1207+
ctx = withChatModelAgentExecCtx(ctx, execCtx)
1208+
1209+
event := &AgentEvent{
1210+
Output: &AgentOutput{
1211+
MessageOutput: &MessageVariant{
1212+
Message: schema.AssistantMessage("custom event", nil),
1213+
},
1214+
},
1215+
}
1216+
err := SendEvent(ctx, event)
1217+
assert.Error(t, err)
1218+
assert.Contains(t, err.Error(), "SendEvent failed: must be called within a ChatModelAgent Run() or Resume() execution context")
1219+
})
1220+
1221+
t.Run("SendEventInMiddleware", func(t *testing.T) {
1222+
ctx := context.Background()
1223+
1224+
ctrl := gomock.NewController(t)
1225+
cm := mockModel.NewMockToolCallingChatModel(ctrl)
1226+
1227+
cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1228+
Return(schema.AssistantMessage("Hello, I am an AI assistant.", nil), nil).
1229+
Times(1)
1230+
1231+
var customEventReceived bool
1232+
customEventContent := "custom_event_from_middleware"
1233+
1234+
agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
1235+
Name: "TestAgent",
1236+
Description: "Test agent for SendEvent",
1237+
Instruction: "You are a helpful assistant.",
1238+
Model: cm,
1239+
Middlewares: []AgentMiddleware{
1240+
{
1241+
BeforeChatModel: func(ctx context.Context, state *ChatModelAgentState) error {
1242+
customEvent := &AgentEvent{
1243+
Output: &AgentOutput{
1244+
MessageOutput: &MessageVariant{
1245+
Message: schema.AssistantMessage(customEventContent, nil),
1246+
},
1247+
},
1248+
}
1249+
return SendEvent(ctx, customEvent)
1250+
},
1251+
},
1252+
},
1253+
})
1254+
assert.NoError(t, err)
1255+
assert.NotNil(t, agent)
1256+
1257+
input := &AgentInput{
1258+
Messages: []Message{
1259+
schema.UserMessage("Hello"),
1260+
},
1261+
}
1262+
iterator := agent.Run(ctx, input)
1263+
assert.NotNil(t, iterator)
1264+
1265+
for {
1266+
event, ok := iterator.Next()
1267+
if !ok {
1268+
break
1269+
}
1270+
if event.Output != nil && event.Output.MessageOutput != nil &&
1271+
event.Output.MessageOutput.Message != nil &&
1272+
event.Output.MessageOutput.Message.Content == customEventContent {
1273+
customEventReceived = true
1274+
}
1275+
}
1276+
1277+
assert.True(t, customEventReceived, "should receive custom event sent from middleware")
1278+
})
1279+
1280+
t.Run("SendEventInMiddlewareWithTools", func(t *testing.T) {
1281+
ctx := context.Background()
1282+
1283+
ctrl := gomock.NewController(t)
1284+
cm := mockModel.NewMockToolCallingChatModel(ctrl)
1285+
1286+
fakeTool := &fakeToolForTest{
1287+
tarCount: 1,
1288+
}
1289+
info, err := fakeTool.Info(ctx)
1290+
assert.NoError(t, err)
1291+
1292+
cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1293+
Return(schema.AssistantMessage("Using tool",
1294+
[]schema.ToolCall{
1295+
{
1296+
ID: "tool-call-1",
1297+
Function: schema.FunctionCall{
1298+
Name: info.Name,
1299+
Arguments: `{"name": "test user"}`,
1300+
},
1301+
}}), nil).
1302+
Times(1)
1303+
cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
1304+
Return(schema.AssistantMessage("Task completed", nil), nil).
1305+
Times(1)
1306+
cm.EXPECT().WithTools(gomock.Any()).Return(cm, nil).AnyTimes()
1307+
1308+
var customEventReceived bool
1309+
customEventContent := "custom_event_from_middleware_with_tools"
1310+
1311+
agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
1312+
Name: "TestAgent",
1313+
Description: "Test agent for SendEvent with tools",
1314+
Instruction: "You are a helpful assistant.",
1315+
Model: cm,
1316+
ToolsConfig: ToolsConfig{
1317+
ToolsNodeConfig: compose.ToolsNodeConfig{
1318+
Tools: []tool.BaseTool{fakeTool},
1319+
},
1320+
},
1321+
Middlewares: []AgentMiddleware{
1322+
{
1323+
BeforeChatModel: func(ctx context.Context, state *ChatModelAgentState) error {
1324+
customEvent := &AgentEvent{
1325+
Output: &AgentOutput{
1326+
MessageOutput: &MessageVariant{
1327+
Message: schema.AssistantMessage(customEventContent, nil),
1328+
},
1329+
},
1330+
}
1331+
return SendEvent(ctx, customEvent)
1332+
},
1333+
},
1334+
},
1335+
})
1336+
assert.NoError(t, err)
1337+
assert.NotNil(t, agent)
1338+
1339+
input := &AgentInput{
1340+
Messages: []Message{
1341+
schema.UserMessage("Use the test tool"),
1342+
},
1343+
}
1344+
iterator := agent.Run(ctx, input)
1345+
assert.NotNil(t, iterator)
1346+
1347+
customEventCount := 0
1348+
for {
1349+
event, ok := iterator.Next()
1350+
if !ok {
1351+
break
1352+
}
1353+
if event.Output != nil && event.Output.MessageOutput != nil &&
1354+
event.Output.MessageOutput.Message != nil &&
1355+
event.Output.MessageOutput.Message.Content == customEventContent {
1356+
customEventReceived = true
1357+
customEventCount++
1358+
}
1359+
}
1360+
1361+
assert.True(t, customEventReceived, "should receive custom event sent from middleware with tools")
1362+
assert.Equal(t, 2, customEventCount, "middleware should be called twice (once for each ChatModel call)")
1363+
})
1364+
1365+
t.Run("SendEventInMiddlewareStreaming", func(t *testing.T) {
1366+
ctx := context.Background()
1367+
1368+
ctrl := gomock.NewController(t)
1369+
cm := mockModel.NewMockToolCallingChatModel(ctrl)
1370+
1371+
sr := schema.StreamReaderFromArray([]*schema.Message{
1372+
schema.AssistantMessage("Hello", nil),
1373+
schema.AssistantMessage(", streaming", nil),
1374+
})
1375+
cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()).
1376+
Return(sr, nil).
1377+
Times(1)
1378+
1379+
var customEventReceived bool
1380+
customEventContent := "custom_event_streaming"
1381+
1382+
agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
1383+
Name: "TestAgent",
1384+
Description: "Test agent for SendEvent streaming",
1385+
Instruction: "You are a helpful assistant.",
1386+
Model: cm,
1387+
Middlewares: []AgentMiddleware{
1388+
{
1389+
BeforeChatModel: func(ctx context.Context, state *ChatModelAgentState) error {
1390+
customEvent := &AgentEvent{
1391+
Output: &AgentOutput{
1392+
MessageOutput: &MessageVariant{
1393+
Message: schema.AssistantMessage(customEventContent, nil),
1394+
},
1395+
},
1396+
}
1397+
return SendEvent(ctx, customEvent)
1398+
},
1399+
},
1400+
},
1401+
})
1402+
assert.NoError(t, err)
1403+
assert.NotNil(t, agent)
1404+
1405+
input := &AgentInput{
1406+
Messages: []Message{schema.UserMessage("Hello")},
1407+
EnableStreaming: true,
1408+
}
1409+
iterator := agent.Run(ctx, input)
1410+
assert.NotNil(t, iterator)
1411+
1412+
for {
1413+
event, ok := iterator.Next()
1414+
if !ok {
1415+
break
1416+
}
1417+
if event.Output != nil && event.Output.MessageOutput != nil &&
1418+
event.Output.MessageOutput.Message != nil &&
1419+
event.Output.MessageOutput.Message.Content == customEventContent {
1420+
customEventReceived = true
1421+
}
1422+
}
1423+
1424+
assert.True(t, customEventReceived, "should receive custom event in streaming mode")
1425+
})
1426+
}
1427+
1428+
func TestChatModelAgentExecCtx(t *testing.T) {
1429+
t.Run("WithAndGetExecCtx", func(t *testing.T) {
1430+
ctx := context.Background()
1431+
1432+
result := getChatModelAgentExecCtx(ctx)
1433+
assert.Nil(t, result)
1434+
1435+
execCtx := &chatModelAgentExecCtx{}
1436+
ctx = withChatModelAgentExecCtx(ctx, execCtx)
1437+
1438+
result = getChatModelAgentExecCtx(ctx)
1439+
assert.NotNil(t, result)
1440+
assert.Equal(t, execCtx, result)
1441+
})
1442+
1443+
t.Run("ExecCtxSendMethod", func(t *testing.T) {
1444+
var nilExecCtx *chatModelAgentExecCtx
1445+
nilExecCtx.send(&AgentEvent{})
1446+
1447+
execCtxWithNilGenerator := &chatModelAgentExecCtx{
1448+
generator: nil,
1449+
}
1450+
execCtxWithNilGenerator.send(&AgentEvent{})
1451+
})
1452+
}

0 commit comments

Comments
 (0)