Skip to content

Commit 6eb0d0b

Browse files
committed
finish all send requests before quitting pp event-loop to avoid mpi deadlock; synchronize sampler right after async calls to avoid hang
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent b76c987 commit 6eb0d0b

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,12 @@ def _process_iter_stats(self, finished_requests: list[LlmRequest],
627627
batch_state.sample_state.scheduled_requests), req_stats)
628628

629629
def _executor_loop_cleanup(self):
630+
# Unblock receiving processes. When second-last rank quits before last rank,
631+
# last rank will never return from recv_object.
632+
for req in self.send_handles:
633+
if req is not None:
634+
req.wait()
635+
630636
with self.response_cv:
631637
self.is_shutdown = True
632638
self.response_cv.notify_all()
@@ -750,8 +756,10 @@ def _executor_loop_pp(self):
750756

751757
sample_state = self._sample_async(
752758
scheduled_batch, batch_outputs)
759+
assert sample_state is not None, "Sampling failed"
753760
sample_state.host.logits = logits_host
754761
self._update_request_states(scheduled_batch)
762+
sample_state.sampler_event.synchronize()
755763

756764
if self.enable_iter_perf_stats:
757765
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,22 @@ def run_disaggregated_test(example_dir,
159159
run_env = env.copy()
160160
run_env["UCX_TLS"] = "^ib"
161161

162+
nsys_path = os.getenv("NSYS_PATH", None)
163+
nsys_file = os.getenv("NSYS_FILE", None)
164+
nsys_cmd = [
165+
"nsys",
166+
"profile",
167+
"--trace",
168+
"cuda,cublas,nvtx",
169+
"--output",
170+
nsys_file,
171+
"--force-overwrite=true",
172+
] if nsys_path and nsys_file else []
173+
162174
num_ranks, config_file = get_test_config(test_desc, example_dir,
163175
os.path.dirname(__file__))
164176

165-
workers_cmd = [
177+
workers_cmd = nsys_cmd + [
166178
'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
167179
str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
168180
config_file

0 commit comments

Comments
 (0)