Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 54 additions & 24 deletions lerobot/common/utils/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,68 @@
# limitations under the License.

import logging
import os
import signal
import sys

shutdown_event_counter = 0

class ProcessSignalHandler:
"""Utility class to attach graceful shutdown signal handlers.

def setup_process_handlers(use_threads: bool) -> any:
if use_threads:
from threading import Event
else:
from multiprocessing import Event
The class exposes a shutdown_event attribute that is set when a shutdown
signal is received. A counter tracks how many shutdown signals have been
caught. On the second signal the process exits with status 1.
"""

shutdown_event = Event()
_SUPPORTED_SIGNALS = ("SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT")

# Define signal handler
def signal_handler(signum, frame):
logging.info("Shutdown signal received. Cleaning up...")
shutdown_event.set()
global shutdown_event_counter
shutdown_event_counter += 1
def __init__(self, use_threads: bool, display_pid: bool = False):
# TODO: Check if we can use Event from threading since Event from
# multiprocessing is the a clone of threading.Event.
# https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Event
if use_threads:
from threading import Event
else:
from multiprocessing import Event

if shutdown_event_counter > 1:
logging.info("Force shutdown")
sys.exit(1)
self.shutdown_event = Event()
self._counter: int = 0
self._display_pid = display_pid

signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
self._register_handlers()

def signal_handler(signum, frame):
logging.info("Shutdown signal received. Cleaning up...")
shutdown_event.set()
@property
def counter(self) -> int: # pragma: no cover – simple accessor
"""Number of shutdown signals that have been intercepted."""
return self._counter

return shutdown_event
def _register_handlers(self):
"""Attach the internal _signal_handler to a subset of POSIX signals."""

def _signal_handler(signum, frame):
pid_str = ""
if self._display_pid:
pid_str = f"[PID: {os.getpid()}]"
logging.info(f"{pid_str} Shutdown signal {signum} received. Cleaning up…")
self.shutdown_event.set()
self._counter += 1

# On a second Ctrl-C (or any supported signal) force the exit to
# mimic the previous behaviour while giving the caller one chance to
# shutdown gracefully.
# TODO: Investigate if we need it later
if self._counter > 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any strong reason why you need a second signal to shutdown?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem that scripts sometimes stuck and doesn't stop. Or do that very slow - like minutes. We want to kill process immediately if user send ctrl+c many times.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But then why ignore the first ctlr+c?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't ignore.

The flow is following - you press ctrl+c one time - than the shutdown event is set and the the process stops gracefully. It write logs, stop servers, closing connections and etc. It how it should work. it'show it works now with Actor Learner in 90% of cases. Also, it how we want it to work. The default flow

But sometimes something happen, maybe some bugs, or some issue in pytorch, or issue in multiprocesses setup and some of processes or threads are stuck. Like for minutes, or foreverer. In such case you will be seating and watching locked terminal. So, as every user u will try to press ctrl+c several times. We will recognize it and kill the process hardly.

Without any signal handling - we won't be able to stop the process (processes) at all, without this trick we won't be able to stop processes in case of some bugs or breaking changes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Smelly IMHO

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@imstevenpmwork what is the best from your opinion to solve it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the reason you guys need to ignore the first CTRL+C (but set the shutdown event) is to ensure synchronization between processes. If Process A terminates immediately, other processes (B and C) won't be able to check the shutdown_event status reliably, leading to undefined behavior.

  1. Process A: Registers signals and creates the shutdown event
  2. Processes B & C: Continuously monitor shutdown_event.is_set()

If Process A dies during the first signal, Processes B and C can't check its status. The delayed termination you observe likely occurs because the system relies on garbage collection to clean up these processes.

I think, for example, that the approach below would have the same effect as ignoring the first signal:

self.shutdown_event.set()
time.sleep(10)  # Give Processes B & C time to detect the event and shutdown cleanly
sys.exit(1) # Process B & C had the time to respond to the shutdown signal before Process A terminates

However, in my opinion, this would not be a good solution either. The best way to proceed is to just have a rigorous thread/process synchronization communication which is not straight forward

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point. This could be good too, yes.

Just to give the context how we came up to the current solution:
1 - We started just with shutdown_event.set() in the beginning which was perfectly worked in multithreading version
2 - After this we added multiprocessing version (here is two options - fork, spawning). And at some point when we had bugs in the flow, or some internal implementation of multiprocessing queues stuck in some cases terminals just were frozen. If u were stiing and pressing ctrl+c infinite amount of time. So we cam up with an idea of catching such cases, and closing harddly all processes.

Theoretically - yeah, we can always do sys.exit(1). But in such case we will never know that exist some issue with stuck processes in the system.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it seems that systems will be more complex - as we will add some planners, supervisors, we will add locomotion parts, and probably remote inference as default way to produce actions. So, would be good to have a good basic system.

logging.info("Force shutdown")
sys.exit(1)

for sig_name in self._SUPPORTED_SIGNALS:
sig = getattr(signal, sig_name, None)
if sig is None:
# The signal is not available on this platform (Windows for
# instance does not provide SIGHUP, SIGQUIT…). Skip it.
continue
try:
signal.signal(sig, _signal_handler)
except (ValueError, OSError): # pragma: no cover – unlikely but safe
# Signal not supported or we are in a non-main thread.
continue
13 changes: 5 additions & 8 deletions lerobot/scripts/rl/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.common.utils.process import setup_process_handlers
from lerobot.common.utils.process import ProcessSignalHandler
from lerobot.common.utils.queue import get_last_item_from_queue
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.robot_utils import busy_wait
Expand Down Expand Up @@ -139,7 +139,8 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
init_logging(log_file=log_file, display_pid=display_pid)
logging.info(f"Actor logging initialized, writing to {log_file}")

shutdown_event = setup_process_handlers(use_threads(cfg))
is_threaded = use_threads(cfg)
shutdown_event = ProcessSignalHandler(is_threaded, display_pid=display_pid).shutdown_event

learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
Expand Down Expand Up @@ -491,7 +492,7 @@ def receive_policy(

# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(use_threads=False)
_ = ProcessSignalHandler(use_threads=False, display_pid=True)

if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
Expand Down Expand Up @@ -544,10 +545,6 @@ def send_transitions(
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor transitions process logging initialized")

# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(False)

if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
Expand Down Expand Up @@ -599,7 +596,7 @@ def send_interactions(

# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(False)
_ = ProcessSignalHandler(use_threads=False, display_pid=True)

if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
Expand Down
8 changes: 5 additions & 3 deletions lerobot/scripts/rl/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
state_to_bytes,
)
from lerobot.common.utils.buffer import ReplayBuffer, concatenate_batch_transitions
from lerobot.common.utils.process import setup_process_handlers
from lerobot.common.utils.process import ProcessSignalHandler
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,
Expand Down Expand Up @@ -203,7 +203,8 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

shutdown_event = setup_process_handlers(use_threads(cfg))
is_threaded = use_threads(cfg)
shutdown_event = ProcessSignalHandler(is_threaded, display_pid=display_pid).shutdown_event

start_learner_threads(
cfg=cfg,
Expand Down Expand Up @@ -673,7 +674,8 @@ def start_learner(
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
# Return back for MP
setup_process_handlers(False)
# TODO: Check if its useful
_ = ProcessSignalHandler(False, display_pid=True)

service = learner_service.LearnerService(
shutdown_event=shutdown_event,
Expand Down
28 changes: 13 additions & 15 deletions tests/utils/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import pytest

from lerobot.common.utils.process import setup_process_handlers
from lerobot.common.utils.process import ProcessSignalHandler


# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
Expand All @@ -34,30 +34,26 @@ def reset_globals_and_handlers():
for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT]
if hasattr(signal, sig.name)
}
# Reset counter from the module
import lerobot.common.utils.process

lerobot.common.utils.process.shutdown_event_counter = 0

yield

# Restore original signal handlers
for sig, handler in original_handlers.items():
signal.signal(sig, handler)
# Reset counter again to be safe
lerobot.common.utils.process.shutdown_event_counter = 0


def test_setup_process_handlers_event_with_threads():
"""Test that setup_process_handlers returns the correct event type."""
shutdown_event = setup_process_handlers(use_threads=True)
handler = ProcessSignalHandler(use_threads=True)
shutdown_event = handler.shutdown_event
assert isinstance(shutdown_event, threading.Event), "Should be a threading.Event"
assert not shutdown_event.is_set(), "Event should initially be unset"


def test_setup_process_handlers_event_with_processes():
"""Test that setup_process_handlers returns the correct event type."""
shutdown_event = setup_process_handlers(use_threads=False)
handler = ProcessSignalHandler(use_threads=False)
shutdown_event = handler.shutdown_event
assert isinstance(shutdown_event, type(multiprocessing.Event())), "Should be a multiprocessing.Event"
assert not shutdown_event.is_set(), "Event should initially be unset"

Expand All @@ -81,10 +77,10 @@ def test_setup_process_handlers_event_with_processes():
)
def test_signal_handler_sets_event(use_threads, sig):
"""Test that the signal handler sets the event on receiving a signal."""
shutdown_event = setup_process_handlers(use_threads=use_threads)
import lerobot.common.utils.process
handler = ProcessSignalHandler(use_threads=use_threads)
shutdown_event = handler.shutdown_event

assert lerobot.common.utils.process.shutdown_event_counter == 0
assert handler.counter == 0

os.kill(os.getpid(), sig)

Expand All @@ -93,13 +89,15 @@ def test_signal_handler_sets_event(use_threads, sig):

assert shutdown_event.is_set(), f"Event should be set after receiving signal {sig}"

# Ensure the internal counter was incremented
assert handler.counter == 1


@pytest.mark.parametrize("use_threads", [True, False])
@patch("sys.exit")
def test_force_shutdown_on_second_signal(mock_sys_exit, use_threads):
"""Test that a second signal triggers a force shutdown."""
setup_process_handlers(use_threads=use_threads)
import lerobot.common.utils.process
handler = ProcessSignalHandler(use_threads=use_threads)

os.kill(os.getpid(), signal.SIGINT)
# Give a moment for the first signal to be processed
Expand All @@ -110,5 +108,5 @@ def test_force_shutdown_on_second_signal(mock_sys_exit, use_threads):

time.sleep(0.1)

assert lerobot.common.utils.process.shutdown_event_counter == 2
assert handler.counter == 2
mock_sys_exit.assert_called_once_with(1)
Loading