[nvbug 5380101][fix] Fix nemotronNAS loading for TP>1 #6447
Conversation
- …d uses head_dim to get number of kv heads for the specific module (Signed-off-by: Tomer Asida <[email protected]>)
- …en3 the special case. Now VGQA works for the general case and no need for special code for NemotronNAS (Signed-off-by: Tomer Asida <[email protected]>)
- Signed-off-by: Tomer Asida <[email protected]>
- …mpute num_kv_heads from head_dim and weight shape in standard flow (Signed-off-by: Tomer Asida <[email protected]>)
- Signed-off-by: Tomer Asida <[email protected]>
📝 Walkthrough

The changes update type annotations and initialization logic in weight mapper classes, refine key-value head handling, and simplify key-value weight duplication logic. Additionally, a waiver for a specific integration test is removed, enabling its execution.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Tester
    participant WaiverList
    Tester->>WaiverList: Check if test_auto_dtype_tp8 is skipped
    WaiverList-->>Tester: No longer skipped (entry removed)
    Tester->>Tester: Run test_auto_dtype_tp8 as part of integration suite
```
Pull Request Overview
This PR fixes weight loading for NemotronNAS models when tensor parallelization (TP) is greater than 1, specifically addressing an issue where KV heads with fewer heads than TP size weren't being properly duplicated.
- Fixes KV head duplication logic by computing heads from weight shapes instead of using a flawed list-based approach
- Moves Qwen3MOE-specific logic into its dedicated weight mapper class
- Corrects type hints for BaseWeightMapper configuration
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/integration/test_lists/waives.txt | Removes test skip for NemotronNas TP8 test since it now passes |
| tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py | Simplifies KV weight duplication by computing heads from weight shape |
| tensorrt_llm/_torch/models/checkpoints/hf/qwen3_moe_weight_mapper.py | Adds Qwen3MOE-specific KV heads initialization logic |
| tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py | Updates type hints and replaces num_kv_heads with head_dim calculation |
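The duplication the review describes can be sketched as follows. This is a minimal, hypothetical helper (not the actual TRT-LLM implementation): when a layer has fewer KV heads than the tensor-parallel world size, each head's weight rows are replicated so every TP rank can hold one full KV head. The `[num_kv_heads * head_dim, hidden]` row layout assumed here is the common HF convention; the function name and row-list representation are invented for illustration.

```python
# Hypothetical sketch of KV-head duplication for TP sharding.
# weight is modeled as a plain list of rows, len == num_kv_heads * head_dim.

def duplicate_kv_weight(weight, head_dim, tp_size):
    # Infer the per-module KV head count from the tensor itself, as the fix
    # does, rather than trusting a config field that may be a per-layer list.
    num_kv_heads = len(weight) // head_dim
    if num_kv_heads >= tp_size:
        return weight  # enough heads; ordinary row-wise TP sharding applies
    reps = tp_size // num_kv_heads  # copies needed per head
    out = []
    for h in range(num_kv_heads):
        head_rows = weight[h * head_dim:(h + 1) * head_dim]
        out.extend(head_rows * reps)  # repeat the whole head, head-contiguous
    return out
```

For example, with 2 KV heads, `head_dim=2`, and `tp_size=8`, each head is copied 4 times, yielding 16 rows so that each of the 8 ranks receives one full head.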
Signed-off-by: Tomer Asida <[email protected]>
/bot run

PR_Github #13364 [ run ] triggered by Bot
Short and sweet.
One little ask, approved anyway
PR_Github #13364 [ run ] completed with state

Signed-off-by: Tomer Asida <[email protected]>

/bot run

PR_Github #13373 [ run ] triggered by Bot

PR_Github #13373 [ run ] completed with state

/bot run

PR_Github #13481 [ run ] triggered by Bot

PR_Github #13481 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #13491 [ run ] triggered by Bot

PR_Github #13491 [ run ] completed with state
Signed-off-by: Tomer Asida <[email protected]> Signed-off-by: Lanyu Liao <[email protected]>
Signed-off-by: Tomer Asida <[email protected]>
Description

This PR fixes weight loading for NemotronNAS with TP>1.

The issue was in layers where the number of KV heads is smaller than the TP size. In this case, KV heads should be duplicated to match the TP size, so each TP rank can hold a full KV head. The logic to call `_duplicate_kv` where `_num_kv_heads` is a list was flawed. The fix is to compute the number of KV heads from the shape of the K or V weight and the head dimension. The list logic that was present in `_duplicate_kv_weights` was added in a PR relevant only to Qwen3MOE, so it is now encapsulated in `Qwen3MoeHfWeightMapper`.

Other than this, this PR also fixes the type hint for `BaseWeightMapper._config`.

Test Coverage

`test_llm_api_pytorch.TestNemotronNas.test_auto_dtype_tp8` now passes.
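The core idea of the fix can be illustrated with a small sketch. In variable-GQA models like NemotronNAS, the KV head count differs per layer, so a config field may hold a list rather than a single integer; deriving each module's head count from its K/V weight shape and the head dimension sidesteps list-aware special cases in the generic loading path. All names and values below are invented for illustration and are not the actual TRT-LLM code.

```python
# Hypothetical sketch: per-module KV head count derived from checkpoint shape.

def kv_heads_from_shape(k_weight_rows, head_dim):
    # The K (or V) projection has num_kv_heads * head_dim output rows,
    # so the head count falls out of the tensor shape directly.
    assert k_weight_rows % head_dim == 0, "rows must be a multiple of head_dim"
    return k_weight_rows // head_dim

head_dim = 128
per_layer_kv_heads = [8, 4, 1]          # what a VGQA config might declare
k_proj_rows = [n * head_dim for n in per_layer_kv_heads]  # checkpoint shapes

# Shape-derived counts agree with the config, layer by layer.
derived = [kv_heads_from_shape(rows, head_dim) for rows in k_proj_rows]
assert derived == per_layer_kv_heads
```

Because the computation is purely local to each module's weight, it works unchanged whether the model uses a uniform KV head count or a per-layer list.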