Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 9 additions & 48 deletions cmd/root/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,14 @@ import (
"github.com/docker/cagent/pkg/telemetry"
)

// shouldMonitorStdin determines if we should monitor stdin for parent process death.
// This is only meaningful when:
// 1. We're not running as PID 1 or direct child of init (ppid > 1)
// 2. stdin is a pipe (indicating we were spawned by a parent with piped stdio)
//
// In containers, stdin is typically /dev/null or closed, so we skip monitoring
// to avoid immediate shutdown.
func shouldMonitorStdin(ppid int, stdin *os.File) bool {
// Skip if running as PID 1 or direct child of init (common in containers/systemd)
if ppid <= 1 {
slog.Debug("Skipping stdin monitor: running as init or direct child of init", "ppid", ppid)
return false
}

if stdin == nil {
return false
}

// Check if stdin is a pipe
fi, err := stdin.Stat()
if err != nil {
slog.Debug("Skipping stdin monitor: cannot stat stdin", "error", err)
return false
}

// Only monitor if stdin is a pipe (parent process has piped stdio to us)
if fi.Mode()&os.ModeNamedPipe == 0 {
slog.Debug("Skipping stdin monitor: stdin is not a pipe", "mode", fi.Mode())
return false
}

slog.Debug("Enabling stdin monitor: stdin is a pipe from parent process", "ppid", ppid)
return true
}

type apiFlags struct {
listenAddr string
sessionDB string
pullIntervalMins int
fakeResponses string
recordPath string
connectRPC bool
exitOnStdinEOF bool
runConfig config.RuntimeConfig
}

Expand All @@ -80,6 +46,8 @@ func newAPICmd() *cobra.Command {
cmd.PersistentFlags().StringVar(&flags.fakeResponses, "fake", "", "Replay AI responses from cassette file (for testing)")
cmd.PersistentFlags().StringVar(&flags.recordPath, "record", "", "Record AI API interactions to cassette file")
cmd.PersistentFlags().BoolVar(&flags.connectRPC, "connect-rpc", false, "Use Connect-RPC protocol instead of HTTP/JSON API")
cmd.PersistentFlags().BoolVar(&flags.exitOnStdinEOF, "exit-on-stdin-eof", false, "Exit when stdin is closed (for integration with parent processes)")
_ = cmd.PersistentFlags().MarkHidden("exit-on-stdin-eof")
cmd.MarkFlagsMutuallyExclusive("fake", "record")
addRuntimeConfigFlags(cmd, &flags.runConfig)

Expand Down Expand Up @@ -116,26 +84,19 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) error {

ctx := cmd.Context()

// Save stdin before redirecting it, so we can optionally monitor for parent death
// Save stdin before clearing it, so we can optionally monitor for parent process death
stdin := os.Stdin

out := cli.NewPrinter(cmd.OutOrStdout())
agentsPath := args[0]

// Redirect stdin to /dev/null to prevent interactive prompts in API mode.
// We use /dev/null instead of nil to avoid panics in code that calls os.Stdin.Fd().
devNull, err := os.Open(os.DevNull)
if err != nil {
slog.Warn("Failed to open /dev/null, setting stdin to nil", "error", err)
} else {
os.Stdin = devNull
defer devNull.Close()
}
// Make sure no question is ever asked to the user in api mode.
os.Stdin = nil

// Monitor stdin for EOF to detect parent process death.
// Only enabled when stdin is a pipe (indicating we were spawned by a parent process).
// In containers, stdin is typically /dev/null or closed, so we skip monitoring.
if shouldMonitorStdin(os.Getppid(), stdin) {
// Only enabled when --exit-on-stdin-eof flag is passed.
// When spawned with piped stdio, stdin closes when the parent process dies.
if f.exitOnStdinEOF && stdin != nil {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer cancel()
Expand Down
91 changes: 17 additions & 74 deletions cmd/root/api_test.go
Original file line number Diff line number Diff line change
@@ -1,102 +1,45 @@
package root

import (
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestShouldMonitorStdin(t *testing.T) {
func TestAPICommand_ExitOnStdinEOFFlag(t *testing.T) {
t.Parallel()

t.Run("returns false when ppid is 0", func(t *testing.T) {
t.Run("flag exists and defaults to false", func(t *testing.T) {
t.Parallel()

// Create a pipe to simulate stdin from parent
r, w, err := os.Pipe()
require.NoError(t, err)
defer r.Close()
defer w.Close()

// ppid=0 means we're init or something weird - should not monitor
result := shouldMonitorStdin(0, r)
assert.False(t, result, "should not monitor stdin when ppid is 0")
})

t.Run("returns false when ppid is 1", func(t *testing.T) {
t.Parallel()

// Create a pipe to simulate stdin from parent
r, w, err := os.Pipe()
require.NoError(t, err)
defer r.Close()
defer w.Close()

// ppid=1 means parent is init (common in containers) - should not monitor
result := shouldMonitorStdin(1, r)
assert.False(t, result, "should not monitor stdin when ppid is 1 (init)")
})
cmd := newAPICmd()

t.Run("returns false when stdin is nil", func(t *testing.T) {
t.Parallel()

result := shouldMonitorStdin(123, nil)
assert.False(t, result, "should not monitor stdin when stdin is nil")
flag := cmd.PersistentFlags().Lookup("exit-on-stdin-eof")
require.NotNil(t, flag, "exit-on-stdin-eof flag should exist")
assert.Equal(t, "false", flag.DefValue, "exit-on-stdin-eof should default to false")
})

t.Run("returns false when stdin is /dev/null", func(t *testing.T) {
t.Run("flag is hidden", func(t *testing.T) {
t.Parallel()

devNull, err := os.Open(os.DevNull)
require.NoError(t, err)
defer devNull.Close()
cmd := newAPICmd()

// /dev/null is not a pipe, so should not monitor
result := shouldMonitorStdin(123, devNull)
assert.False(t, result, "should not monitor stdin when stdin is /dev/null")
flag := cmd.PersistentFlags().Lookup("exit-on-stdin-eof")
require.NotNil(t, flag, "exit-on-stdin-eof flag should exist")
assert.True(t, flag.Hidden, "exit-on-stdin-eof flag should be hidden")
})

t.Run("returns false when stdin is a regular file", func(t *testing.T) {
t.Run("flag can be set to true", func(t *testing.T) {
t.Parallel()

// Create a temp file
f, err := os.CreateTemp(t.TempDir(), "test-stdin-*")
require.NoError(t, err)
defer f.Close()

// Regular file is not a pipe, so should not monitor
result := shouldMonitorStdin(123, f)
assert.False(t, result, "should not monitor stdin when stdin is a regular file")
})

t.Run("returns true when stdin is a pipe and ppid > 1", func(t *testing.T) {
t.Parallel()

// Create a pipe to simulate stdin from parent
r, w, err := os.Pipe()
require.NoError(t, err)
defer r.Close()
defer w.Close()

// ppid > 1 and stdin is a pipe - should monitor
result := shouldMonitorStdin(123, r)
assert.True(t, result, "should monitor stdin when stdin is a pipe and ppid > 1")
})

t.Run("returns true with various valid ppids", func(t *testing.T) {
t.Parallel()
cmd := newAPICmd()

r, w, err := os.Pipe()
err := cmd.PersistentFlags().Set("exit-on-stdin-eof", "true")
require.NoError(t, err)
defer r.Close()
defer w.Close()

// Test various ppid values > 1
for _, ppid := range []int{2, 100, 1000, 65535} {
result := shouldMonitorStdin(ppid, r)
assert.True(t, result, "should monitor stdin when ppid is %d", ppid)
}
flag := cmd.PersistentFlags().Lookup("exit-on-stdin-eof")
require.NotNil(t, flag)
assert.Equal(t, "true", flag.Value.String())
})
}