Commit 13ae5b9

fix(adk): fix concurrent compile race in ChatModelAgent (#775)
Move chain and inner graph creation inside closures to avoid data races when multiple requests call Compile() concurrently on shared instances.

Changes:
- buildNoToolsRunFunc: create the chain per request instead of sharing it
- buildReactRunFunc: create both the inner graph (g) and the chain per request

Change-Id: I9a44697675c892ac7658909010b72d3a2718c25a
1 parent f887fed commit 13ae5b9
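
The race the commit describes follows a common Go pattern: a builder object is created once when the agent is constructed, and Compile() later mutates it from concurrently running request handlers. The sketch below is not the eino code; it is a minimal, self-contained illustration with hypothetical names (builder, compiledNodes, buildRunFuncShared, buildRunFuncPerRequest) of why moving creation inside the per-request closure removes the shared mutable state.

package main

import (
	"fmt"
	"sync"
)

// builder stands in for a chain/graph builder whose Compile step mutates
// internal state. The eino types are not reproduced here; the names are
// hypothetical and only illustrate the race pattern from the commit message.
type builder struct {
	compiledNodes []string // mutated during Compile
}

func (b *builder) Compile() {
	// Appending to shared state from concurrent callers is a data race.
	b.compiledNodes = append(b.compiledNodes, "node")
}

// Racy variant: the builder is created once and captured by every request closure.
func buildRunFuncShared() func() {
	shared := &builder{}
	return func() { shared.Compile() }
}

// Fixed variant (mirrors the commit's approach): the builder is created inside
// the closure, so each request compiles its own instance.
func buildRunFuncPerRequest() func() {
	return func() {
		perRequest := &builder{}
		perRequest.Compile()
	}
}

func main() {
	run := buildRunFuncPerRequest() // swap in buildRunFuncShared() to observe the race
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			run()
		}()
	}
	wg.Wait()
	fmt.Println("done")
}

Running the shared variant under `go run -race` reports a write/write race on compiledNodes; the per-request variant is quiet because nothing mutable is shared between goroutines.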

File tree: 1 file changed (+33, -32 lines)


adk/chatmodel.go

Lines changed: 33 additions & 32 deletions
@@ -721,22 +721,22 @@ func (a *ChatModelAgent) buildNoToolsRunFunc(_ context.Context) runFunc {
 		instruction string
 	}
 
-	chain := compose.NewChain[noToolsInput, Message](
-		compose.WithGenLocalState(func(ctx context.Context) (state *State) {
-			return &State{}
-		})).
-		AppendLambda(compose.InvokableLambda(func(ctx context.Context, in noToolsInput) ([]Message, error) {
-			messages, err := a.genModelInput(ctx, in.instruction, in.input)
-			if err != nil {
-				return nil, err
-			}
-			return messages, nil
-		})).
-		AppendChatModel(wrappedModel)
-
 	return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent],
 		store *bridgeStore, instruction string, _ map[string]bool, opts ...compose.Option) {
 
+		chain := compose.NewChain[noToolsInput, Message](
+			compose.WithGenLocalState(func(ctx context.Context) (state *State) {
+				return &State{}
+			})).
+			AppendLambda(compose.InvokableLambda(func(ctx context.Context, in noToolsInput) ([]Message, error) {
+				messages, err := a.genModelInput(ctx, in.instruction, in.input)
+				if err != nil {
+					return nil, err
+				}
+				return messages, nil
+			})).
+			AppendChatModel(wrappedModel)
+
 		r, err := chain.Compile(ctx, compose.WithGraphName(a.name),
 			compose.WithCheckPointStore(store),
 			compose.WithSerializer(&gobSerializer{}))
@@ -789,32 +789,33 @@ func (a *ChatModelAgent) buildReactRunFunc(ctx context.Context, bc *execContext)
 		maxIterations: a.maxIterations,
 	}
 
-	g, err := newReact(ctx, conf)
-	if err != nil {
-		return nil, err
-	}
-
 	type reactRunInput struct {
 		input *AgentInput
 		instruction string
 	}
 
-	chain := compose.NewChain[reactRunInput, Message]().
-		AppendLambda(
-			compose.InvokableLambda(func(ctx context.Context, in reactRunInput) (*reactInput, error) {
-				messages, err := a.genModelInput(ctx, in.instruction, in.input)
-				if err != nil {
-					return nil, err
-				}
-				return &reactInput{
-					messages: messages,
-				}, nil
-			}),
-		).
-		AppendGraph(g, compose.WithNodeName("ReAct"), compose.WithGraphCompileOptions(compose.WithMaxRunSteps(math.MaxInt)))
-
 	return func(ctx context.Context, input *AgentInput, generator *AsyncGenerator[*AgentEvent], store *bridgeStore,
 		instruction string, returnDirectly map[string]bool, opts ...compose.Option) {
+		g, err := newReact(ctx, conf)
+		if err != nil {
+			generator.Send(&AgentEvent{Err: err})
+			return
+		}
+
+		chain := compose.NewChain[reactRunInput, Message]().
+			AppendLambda(
+				compose.InvokableLambda(func(ctx context.Context, in reactRunInput) (*reactInput, error) {
+					messages, genErr := a.genModelInput(ctx, in.instruction, in.input)
+					if genErr != nil {
+						return nil, genErr
+					}
+					return &reactInput{
+						messages: messages,
+					}, nil
+				}),
+			).
+			AppendGraph(g, compose.WithNodeName("ReAct"), compose.WithGraphCompileOptions(compose.WithMaxRunSteps(math.MaxInt)))
+
 		var compileOptions []compose.GraphCompileOption
 		compileOptions = append(compileOptions,
 			compose.WithGraphName(a.name),
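
A side effect of moving newReact into the closure, visible in the second hunk, is that construction errors can no longer be returned from buildReactRunFunc; the diff instead reports them through the event generator (generator.Send(&AgentEvent{Err: err}) followed by return). The sketch below uses hypothetical stand-ins (event, generator, buildRunFunc) rather than the eino API, and only illustrates this pattern: when a closure's signature has no error result, failures are surfaced on the stream the caller already consumes.

package main

import (
	"errors"
	"fmt"
)

// event and generator are hypothetical stand-ins for AgentEvent and
// AsyncGenerator; this is a sketch of the error-reporting pattern, not the
// library's actual types.
type event struct{ err error }

type generator struct{ ch chan event }

func (g *generator) Send(e event) { g.ch <- e }

// buildRunFunc returns a closure with no error result, so any construction
// failure inside it must be emitted as an event instead of returned.
func buildRunFunc(build func() error) func(g *generator) {
	return func(g *generator) {
		if err := build(); err != nil {
			g.Send(event{err: err})
			return
		}
		// ... run the successfully built graph ...
		g.Send(event{})
	}
}

func main() {
	gen := &generator{ch: make(chan event, 1)}
	run := buildRunFunc(func() error { return errors.New("newReact failed") })
	run(gen)
	fmt.Println((<-gen.ch).err) // the construction error arrives as an event
}

In the commit, the closure likewise sends the error and returns early, leaving the caller to observe it on the AsyncGenerator.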
