[TRTLLM-6835][fix] Fix potential hang caused by python multiprocessing when prefetching weights #6927
Conversation
Signed-off-by: Lance Liao <[email protected]>
📝 Walkthrough
Replaced multiprocessing-based prefetching with ThreadPoolExecutor in the HF checkpoint weight loader. Updated the parameter name to max_workers and ensured tasks complete before return via list(executor.map(...)). Worker count calculation and surrounding logic, including barriers, remain unchanged.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~8 minutes
Actionable comments posted: 0
🧹 Nitpick comments (3)
tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py (3)
124-127: Drop the multiprocessing dependency; use os.cpu_count() with a safe fallback. Since processes are no longer used, rely on os.cpu_count() (the os module is already imported) and fall back to 1 in case it returns None. This lets you remove the multiprocessing import entirely and slightly simplifies the dependencies.
Apply this diff within this block:
-        max_workers = min(multiprocessing.cpu_count() * 2, 16,
-                          len(local_file_names))
+        cpu_count = os.cpu_count() or 1
+        max_workers = min(cpu_count * 2, 16, len(local_file_names))
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
             list(executor.map(self._prefetch_one_file, local_file_names))

Additionally (outside this block), remove the now-unused import:
# Remove this at the top of the file
import multiprocessing
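For context, the block after applying this suggestion would look roughly like the sketch below; the free-function wrapper and the prefetch body are illustrative stand-ins for the loader's actual method, not repo code.

```python
import os
from typing import List
from concurrent.futures import ThreadPoolExecutor


def _prefetch_one_file(file_name: str) -> None:
    # Illustrative body: read the file once so the OS page cache is warm
    # before the real weight load happens.
    with open(file_name, "rb") as f:
        while f.read(16 * 1024 * 1024):
            pass


def prefetch_files(local_file_names: List[str]) -> None:
    if not local_file_names:
        return
    # os.cpu_count() can return None, so fall back to 1 before scaling.
    cpu_count = os.cpu_count() or 1
    max_workers = min(cpu_count * 2, 16, len(local_file_names))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Draining the result iterator waits for every prefetch to finish and
        # re-raises any worker exception before the executor shuts down.
        list(executor.map(_prefetch_one_file, local_file_names))
```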
1-1: Add the NVIDIA copyright header (current year). Per the coding guidelines, prepend the header:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 import glob
27-27: Python 3.8 compatibility: avoid PEP 585 built-in generics without postponed evaluation. The return annotations use dict[str, Any], which requires Python 3.9 unless annotations are postponed. If Python 3.8 is in scope (per the guidelines), either add from __future__ import annotations at the top or switch to typing.Dict[str, Any].
Two options:
- Preferred: Postpone annotations at file top (place before other imports):
from __future__ import annotations
- Or, change annotations to:
def load_weights(self, checkpoint_dir: str) -> Dict[str, Any]: ...

def _load_weights_in_parallel(self, weight_files: List[str], load_func,
                              description: str) -> Dict[str, Any]: ...

Please confirm the minimum supported Python version for this module. If it’s >=3.9, you can ignore this.
Also applies to: 61-63
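A minimal sketch of the preferred option, assuming signatures like the ones quoted above (the class name here is hypothetical):

```python
# Sketch only: with postponed evaluation, the dict[str, Any] annotations are
# stored as strings and never evaluated at runtime, so the module still
# imports on Python 3.8.
from __future__ import annotations

from typing import Any, List


class HfWeightLoader:  # hypothetical class name for illustration
    def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
        ...

    def _load_weights_in_parallel(self, weight_files: List[str], load_func,
                                  description: str) -> dict[str, Any]:
        ...
```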
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
🔇 Additional comments (1)
tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py (1)
4-4: Good call switching to threads for the I/O-bound prefetch; this avoids fork/MPI pitfalls. Using ThreadPoolExecutor here is safer in MPI environments and preserves I/O performance (file reads release the GIL). Consuming the iterator via list(...) ensures all tasks complete before the context exits.
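A standalone illustration (not repo code) of that last point: executor.map() returns results lazily, so wrapping it in list() waits for every task and surfaces any worker exception at the call site.

```python
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor


def read_size(path: str) -> int:
    with open(path, "rb") as f:
        return len(f.read())


# Throwaway files just for the demo.
paths = []
for payload in (b"abc", b"defgh"):
    tmp = tempfile.NamedTemporaryFile(delete=False)
    tmp.write(payload)
    tmp.close()
    paths.append(tmp.name)

with ThreadPoolExecutor(max_workers=2) as executor:
    # list(...) blocks until both results are ready; if read_size raised,
    # the exception would be re-raised right here instead of being lost.
    sizes = list(executor.map(read_size, paths))

print(sizes)  # [3, 5]

for p in paths:
    os.unlink(p)
```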
/bot run

PR_Github #15391 [ run ] triggered by Bot
djns99
left a comment
LGTM, I have confirmed this appears to fix the issue
PR_Github #15391 [ run ] completed with state

/bot run

PR_Github #15410 [ run ] triggered by Bot

PR_Github #15410 [ run ] completed with state
…g when prefetching weights (NVIDIA#6927) Signed-off-by: Lance Liao <[email protected]>
…g when prefetching weights (#6927) Signed-off-by: Lance Liao <[email protected]> Signed-off-by: Wangshanshan <[email protected]>
This PR replaces multiprocessing with multithreading when prefetching weights, since the workload is I/O-bound. Python’s default fork start method for multiprocessing carries a risk of hangs (see python/cpython#84559), and when used together with MPI it can deadlock. In theory, multithreading is the safer choice; we have verified on an 8-GPU node that loading DeepSeek-R1 takes no longer than before. See the following pics:

