Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions cmd/root/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,11 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadRes

// Apply any stored model overrides from the session
if len(sess.AgentModelOverrides) > 0 {
for agentName, modelRef := range sess.AgentModelOverrides {
if err := localRt.SetAgentModel(ctx, agentName, modelRef); err != nil {
slog.Warn("Failed to apply stored model override", "agent", agentName, "model", modelRef, "error", err)
if modelSwitcher, ok := localRt.(runtime.ModelSwitcher); ok {
for agentName, modelRef := range sess.AgentModelOverrides {
if err := modelSwitcher.SetAgentModel(ctx, agentName, modelRef); err != nil {
slog.Warn("Failed to apply stored model override", "agent", agentName, "model", modelRef, "error", err)
}
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion e2e/cagent_exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ func cagentExec(t *testing.T, moreArgs ...string) string {

// Start a recording AI proxy to record and replay traffic.
svr, _ := startRecordingAIProxy(t)
args = append(args, "--models-gateway", svr.URL, "--session-db", "/tmp/session.db")
// Use a unique session DB path per test to avoid conflicts when tests run in parallel
sessionDB := filepath.Join(t.TempDir(), "session.db")
args = append(args, "--models-gateway", svr.URL, "--session-db", sessionDB)

// Run cagent exec
var stdout bytes.Buffer
Expand Down
53 changes: 48 additions & 5 deletions pkg/runtime/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/docker/cagent/pkg/chat"
"github.com/docker/cagent/pkg/config/types"
"github.com/docker/cagent/pkg/session"
"github.com/docker/cagent/pkg/tools"
)

Expand All @@ -22,16 +23,18 @@ func (a AgentContext) GetAgentName() string { return a.AgentName }

// UserMessageEvent is sent when a user message is received
type UserMessageEvent struct {
Type string `json:"type"`
Message string `json:"message"`
Type string `json:"type"`
Message string `json:"message"`
SessionID string `json:"session_id"`
}

func (e *UserMessageEvent) GetAgentName() string { return "" }

func UserMessage(message string) Event {
func UserMessage(message, sessionID string) Event {
return &UserMessageEvent{
Type: "user_message",
Message: message,
Type: "user_message",
Message: message,
SessionID: sessionID,
}
}

Expand Down Expand Up @@ -526,3 +529,43 @@ func HookBlocked(toolCall tools.ToolCall, toolDefinition tools.Tool, message, ag
AgentContext: AgentContext{AgentName: agentName},
}
}

// MessageAddedEvent is emitted when a message is added to the session.
// This event is used by the PersistentRuntime wrapper to persist messages.
type MessageAddedEvent struct {
Type string `json:"type"`
SessionID string `json:"session_id"`
Message *session.Message `json:"-"`
AgentContext
}

func (e *MessageAddedEvent) GetAgentName() string { return e.AgentName }

func MessageAdded(sessionID string, msg *session.Message, agentName string) Event {
return &MessageAddedEvent{
Type: "message_added",
SessionID: sessionID,
Message: msg,
AgentContext: AgentContext{AgentName: agentName},
}
}

// SubSessionCompletedEvent is emitted when a sub-session completes and is added to parent.
// This event is used by the PersistentRuntime wrapper to persist sub-sessions.
type SubSessionCompletedEvent struct {
Type string `json:"type"`
ParentSessionID string `json:"parent_session_id"`
SubSession any `json:"sub_session"` // *session.Session
AgentContext
}

func (e *SubSessionCompletedEvent) GetAgentName() string { return e.AgentName }

func SubSessionCompleted(parentSessionID string, subSession any, agentName string) Event {
return &SubSessionCompletedEvent{
Type: "sub_session_completed",
ParentSessionID: parentSessionID,
SubSession: subSession,
AgentContext: AgentContext{AgentName: agentName},
}
}
185 changes: 185 additions & 0 deletions pkg/runtime/persistent_runtime.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package runtime

import (
"context"
"fmt"
"log/slog"
"strings"

"github.com/docker/cagent/pkg/chat"
"github.com/docker/cagent/pkg/session"
"github.com/docker/cagent/pkg/team"
)

// PersistentRuntime wraps a LocalRuntime and persists session changes to a store
// based on emitted events.
type PersistentRuntime struct {
*LocalRuntime
}

// streamingState tracks the accumulated content for a streaming assistant message
type streamingState struct {
content strings.Builder
reasoningContent strings.Builder
agentName string
messageID int64 // ID of the current streaming message (0 if none)
}

// New creates a new runtime for an agent and its team.
// The runtime automatically persists session changes to the configured store.
// Returns a Runtime interface which wraps LocalRuntime with persistence handling.
func New(agents *team.Team, opts ...Opt) (Runtime, error) {
r, err := NewLocalRuntime(agents, opts...)
if err != nil {
return nil, err
}

return &PersistentRuntime{
LocalRuntime: r,
}, nil
}

// RunStream wraps the inner runtime's RunStream and intercepts events
// to persist session changes to the store.
func (r *PersistentRuntime) RunStream(ctx context.Context, sess *session.Session) <-chan Event {
if !sess.IsSubSession() {
if err := r.sessionStore.UpdateSession(ctx, sess); err != nil {
slog.Warn("Failed to persist initial session", "session_id", sess.ID, "error", err)
}
}

innerEvents := r.LocalRuntime.RunStream(ctx, sess)
events := make(chan Event, 128)

go func() {
defer close(events)

streaming := &streamingState{}

for event := range innerEvents {
r.handleEvent(ctx, sess, event, streaming)
events <- event
}
}()

return events
}

func (r *PersistentRuntime) handleEvent(ctx context.Context, sess *session.Session, event Event, streaming *streamingState) {
// Skip persistence for sub-sessions (they're persisted when added to parent)
if sess.IsSubSession() {
return
}

switch e := event.(type) {
case *AgentChoiceEvent:
// Accumulate streaming content
streaming.content.WriteString(e.Content)
streaming.agentName = e.AgentName

r.persistStreamingContent(ctx, sess.ID, streaming)

case *AgentChoiceReasoningEvent:
// Accumulate streaming reasoning content
streaming.reasoningContent.WriteString(e.Content)
streaming.agentName = e.AgentName

r.persistStreamingContent(ctx, sess.ID, streaming)

case *UserMessageEvent:
// Reset streaming state when a user message is received
streaming.content.Reset()
streaming.reasoningContent.Reset()
streaming.agentName = ""
streaming.messageID = 0

if _, err := r.sessionStore.AddMessage(ctx, e.SessionID, session.UserMessage(e.Message)); err != nil {
slog.Warn("Failed to persist user message", "session_id", e.SessionID, "error", err)
}

case *MessageAddedEvent:
// Finalize the streaming message with complete metadata
if streaming.messageID != 0 {
// Update the existing streaming message with final content
if err := r.sessionStore.UpdateMessage(ctx, streaming.messageID, e.Message); err != nil {
slog.Warn("Failed to finalize streaming message", "session_id", e.SessionID, "message_id", streaming.messageID, "error", err)
}
} else {
// No streaming message exists, create a new one
if _, err := r.sessionStore.AddMessage(ctx, e.SessionID, e.Message); err != nil {
slog.Warn("Failed to persist message", "session_id", e.SessionID, "error", err)
}
}

// Reset streaming state after message is finalized
streaming.content.Reset()
streaming.reasoningContent.Reset()
streaming.agentName = ""
streaming.messageID = 0

case *SubSessionCompletedEvent:
if subSess, ok := e.SubSession.(*session.Session); ok {
if err := r.sessionStore.AddSubSession(ctx, e.ParentSessionID, subSess); err != nil {
slog.Warn("Failed to persist sub-session", "parent_id", e.ParentSessionID, "error", err)
}
}

case *SessionSummaryEvent:
if err := r.sessionStore.AddSummary(ctx, e.SessionID, e.Summary); err != nil {
slog.Warn("Failed to persist summary", "session_id", e.SessionID, "error", err)
}

case *TokenUsageEvent:
if e.Usage != nil {
if err := r.sessionStore.UpdateSessionTokens(ctx, sess.ID, e.Usage.InputTokens, e.Usage.OutputTokens, e.Usage.Cost); err != nil {
slog.Warn("Failed to persist token usage", "session_id", sess.ID, "error", err)
}
}

case *SessionTitleEvent:
if err := r.sessionStore.UpdateSessionTitle(ctx, sess.ID, e.Title); err != nil {
slog.Warn("Failed to persist session title", "session_id", sess.ID, "error", err)
}
}
}

// persistStreamingContent creates or updates the streaming assistant message
func (r *PersistentRuntime) persistStreamingContent(ctx context.Context, sessionID string, streaming *streamingState) {
msg := &session.Message{
AgentName: streaming.agentName,
Message: chat.Message{
Role: chat.MessageRoleAssistant,
Content: streaming.content.String(),
ReasoningContent: streaming.reasoningContent.String(),
},
}

if streaming.messageID == 0 {
// Create new streaming message
id, err := r.sessionStore.AddMessage(ctx, sessionID, msg)
if err != nil {
slog.Warn("Failed to create streaming message", "session_id", sessionID, "error", err)
return
}
streaming.messageID = id
slog.Debug("[PERSIST] Created streaming message", "session_id", sessionID, "message_id", id, "agent", streaming.agentName)
} else {
// Update existing streaming message
if err := r.sessionStore.UpdateMessage(ctx, streaming.messageID, msg); err != nil {
slog.Warn("Failed to update streaming message", "session_id", sessionID, "message_id", streaming.messageID, "error", err)
}
}
}

// Run wraps the inner runtime's Run method
func (r *PersistentRuntime) Run(ctx context.Context, sess *session.Session) ([]session.Message, error) {
eventsChan := r.RunStream(ctx, sess)

for event := range eventsChan {
if errEvent, ok := event.(*ErrorEvent); ok {
return nil, fmt.Errorf("%s", errEvent.Error)
}
}

return sess.GetAllMessages(), nil
}
40 changes: 17 additions & 23 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,9 @@ func WithEnv(env []string) Opt {
}
}

// New creates a new runtime for an agent and its team
func New(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
// NewLocalRuntime creates a new LocalRuntime without the persistence wrapper.
// This is useful for testing or when persistence is handled externally.
func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
modelsStore, err := modelsdev.NewStore()
if err != nil {
return nil, err
Expand Down Expand Up @@ -288,7 +289,7 @@ func New(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
}

r.titleGen = newTitleGenerator(model)
r.sessionCompactor = newSessionCompactor(model, r.sessionStore)
r.sessionCompactor = newSessionCompactor(model)

slog.Debug("Creating new runtime", "agent", r.currentAgent, "available_agents", agents.Size())

Expand Down Expand Up @@ -724,7 +725,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c

messages := sess.GetMessages(a)
if sess.SendUserMessage {
events <- UserMessage(messages[len(messages)-1].Content)
events <- UserMessage(messages[len(messages)-1].Content, sess.ID)
}

events <- StreamStarted(sess.ID, a.Name())
Expand Down Expand Up @@ -784,8 +785,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
CreatedAt: time.Now().Format(time.RFC3339),
}

sess.AddMessage(session.NewAgentMessage(a, &assistantMessage))
r.saveSession(ctx, sess)
addAgentMessage(sess, a, &assistantMessage, events)
return
}

Expand Down Expand Up @@ -951,8 +951,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
msgUsage.RateLimit = *res.RateLimit
}

sess.AddMessage(session.NewAgentMessage(a, &assistantMessage))
r.saveSession(ctx, sess)
addAgentMessage(sess, a, &assistantMessage, events)
slog.Debug("Added assistant message to session", "agent", a.Name(), "total_messages", len(sess.GetAllMessages()))
} else {
slog.Debug("Skipping empty assistant message (no content and no tool calls)", "agent", a.Name())
Expand Down Expand Up @@ -1511,8 +1510,7 @@ func (r *LocalRuntime) executeToolWithHandler(
ToolCallID: toolCall.ID,
CreatedAt: time.Now().Format(time.RFC3339),
}
sess.AddMessage(session.NewAgentMessage(a, &toolResponseMsg))
r.saveSession(ctx, sess)
addAgentMessage(sess, a, &toolResponseMsg, events)
}

// runTool executes agent tools from toolsets (MCP, filesystem, etc.).
Expand Down Expand Up @@ -1591,9 +1589,15 @@ func (r *LocalRuntime) runAgentTool(ctx context.Context, handler ToolHandlerFunc
})
}

func addAgentMessage(sess *session.Session, a *agent.Agent, msg *chat.Message, events chan Event) {
agentMsg := session.NewAgentMessage(a, msg)
sess.AddMessage(agentMsg)
events <- MessageAdded(sess.ID, agentMsg, a.Name())
}

// addToolErrorResponse adds a tool error response to the session and emits the event.
// This consolidates the common pattern used by validation, rejection, and cancellation responses.
func (r *LocalRuntime) addToolErrorResponse(ctx context.Context, sess *session.Session, toolCall tools.ToolCall, tool tools.Tool, events chan Event, a *agent.Agent, errorMsg string) {
func (r *LocalRuntime) addToolErrorResponse(_ context.Context, sess *session.Session, toolCall tools.ToolCall, tool tools.Tool, events chan Event, a *agent.Agent, errorMsg string) {
events <- ToolCallResponse(toolCall, tool, tools.ResultError(errorMsg), errorMsg, a.Name())

toolResponseMsg := chat.Message{
Expand All @@ -1602,18 +1606,7 @@ func (r *LocalRuntime) addToolErrorResponse(ctx context.Context, sess *session.S
ToolCallID: toolCall.ID,
CreatedAt: time.Now().Format(time.RFC3339),
}
sess.AddMessage(session.NewAgentMessage(a, &toolResponseMsg))
r.saveSession(ctx, sess)
}

// saveSession persists the session to the store, but only for root sessions.
// Sub-sessions (those with a ParentID) are not persisted as standalone entries;
// they are embedded within the parent session's Messages array.
func (r *LocalRuntime) saveSession(ctx context.Context, sess *session.Session) {
if sess.IsSubSession() {
return
}
_ = r.sessionStore.UpdateSession(ctx, sess)
addAgentMessage(sess, a, &toolResponseMsg, events)
}

// startSpan wraps tracer.Start, returning a no-op span if the tracer is nil.
Expand Down Expand Up @@ -1707,6 +1700,7 @@ func (r *LocalRuntime) handleTaskTransfer(ctx context.Context, sess *session.Ses
sess.Thinking = s.Thinking

sess.AddSubSession(s)
evts <- SubSessionCompleted(sess.ID, s, a.Name())

slog.Debug("Task transfer completed", "agent", params.Agent, "task", params.Task)

Expand Down
Loading