
Commit 98d10df

tomeras91 authored and lancelly committed
[nvbug 5380101][fix] Fix nemotronNAS loading for TP>1 (NVIDIA#6447)
Signed-off-by: Tomer Asida <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent 0cb1a69 commit 98d10df

File tree

4 files changed: +23 -14 lines changed


tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py

Lines changed: 7 additions & 7 deletions
@@ -3,7 +3,7 @@
 
 from torch import nn
 
-from tensorrt_llm._torch.model_config import ModelConfig, TConfig
+from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.modeling_utils import DecoderModelForCausalLM
 
 
@@ -14,11 +14,11 @@ def __init__(self):
         self._mapping: dict = {}
         self._skip_modules = []
         self._model: Union[nn.Module, DecoderModelForCausalLM] | None = None
-        self._config: TConfig | None = None
+        self._config: ModelConfig | None = None
 
     def init_model_and_config(self, model: Union[nn.Module,
                                                   DecoderModelForCausalLM],
-                              config: TConfig):
+                              config: ModelConfig):
         self._model = model
         self._config = config
 
@@ -29,9 +29,9 @@ def init_model_and_config(self, model: Union[nn.Module,
             raise ValueError("model must have a config attribute")
 
         self._tp_size = 1 if model.model_config.mapping.enable_attention_dp else model.model_config.mapping.tp_size
-        self._num_kv_heads = model.config.num_key_value_heads if hasattr(
-            model.config, 'num_key_value_heads'
-        ) and model.config.num_key_value_heads is not None else model.config.num_attention_heads
+        self._head_dim = model.config.head_dim if hasattr(
+            model.config, 'head_dim'
+        ) and model.config.head_dim is not None else model.config.hidden_size // model.config.num_attention_heads
 
         self.map_weights()
 
@@ -153,7 +153,7 @@ def mapping(self) -> dict:
         return self._mapping
 
     @property
-    def config(self) -> TConfig:
+    def config(self) -> ModelConfig:
         if self._config is None:
             raise RuntimeError("Weight mapper is not initialized")
         return self._config
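The switch from caching a single num_kv_heads to caching head_dim is the core of the fix: NemotronNAS-style checkpoints can use a different number of KV heads per layer, so a per-layer value has to be recovered later from each weight tensor, while head_dim stays constant across layers. A minimal sketch of the fallback above, assuming a Hugging Face-style config object (the standalone helper and its name are illustrative, not TensorRT-LLM API):

# Sketch of the head_dim fallback introduced above (illustrative helper,
# not part of the TensorRT-LLM API). `hf_config` stands for any HF-style
# config object exposing these attributes.
def resolve_head_dim(hf_config) -> int:
    head_dim = getattr(hf_config, "head_dim", None)
    if head_dim is not None:
        return head_dim
    # Common convention: hidden_size divided evenly across attention heads.
    return hf_config.hidden_size // hf_config.num_attention_heads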

tensorrt_llm/_torch/models/checkpoints/hf/qwen3_moe_weight_mapper.py

Lines changed: 12 additions & 0 deletions
@@ -1,13 +1,25 @@
+from typing import Union
+
 from torch import nn
 
+from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \
     Qwen2MoeHfWeightMapper
 from tensorrt_llm._torch.models.modeling_utils import register_mapper
+from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM
 
 
 @register_mapper("HF", "Qwen3MoeForCausalLM")
 class Qwen3MoeHfWeightMapper(Qwen2MoeHfWeightMapper):
 
+    def init_model_and_config(self, model: Union[nn.Module,
+                                                  DecoderModelForCausalLM],
+                              config: ModelConfig):
+        super().init_model_and_config(model, config)
+        self._num_kv_heads = model.config.num_key_value_heads if hasattr(
+            model.config, 'num_key_value_heads'
+        ) and model.config.num_key_value_heads is not None else model.config.num_attention_heads
+
     def should_skip_module(self, module_name: str) -> bool:
         if module_name.startswith("draft_model"):
             return True
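Because the base mapper no longer computes self._num_kv_heads, the Qwen3-MoE mapper restores it in its own init_model_and_config, presumably because the Qwen-MoE weight handling still relies on a model-wide KV-head count. The fallback follows the usual Hugging Face convention; a standalone sketch (helper name is illustrative, not TensorRT-LLM API):

# Sketch of the num_key_value_heads fallback used in the override above.
def resolve_num_kv_heads(hf_config) -> int:
    num_kv_heads = getattr(hf_config, "num_key_value_heads", None)
    # Configs without the field describe plain multi-head attention,
    # where every attention head has its own K/V head.
    return num_kv_heads if num_kv_heads is not None else hf_config.num_attention_heads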

tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py

Lines changed: 4 additions & 6 deletions
@@ -59,17 +59,15 @@ def should_skip_module(self, module_name: str) -> bool:
     def _duplicate_kv_weights(self, module: nn.Module, new_name: str,
                               weights: dict):
         if new_name in ['k_proj', 'v_proj']:
-            num_kv_heads_list = [self._num_kv_heads
-                                 ] * len(weights) if isinstance(
-                                     self._num_kv_heads,
-                                     int) else self._num_kv_heads
+            # k_proj and v_proj shape is [num_kv_heads*head_dim, hidden_dim]
+            num_kv_heads = weights['weight'].shape[0] // self._head_dim
             processed_weights = {
                 k:
                 self._duplicate_kv(weight=v[:],
-                                   num_kv_heads=num_kv_heads_list[i],
+                                   num_kv_heads=num_kv_heads,
                                    tensor_parallel_size=self._tp_size)
                 if k in ["weight", "bias"] else v
-                for i, (k, v) in enumerate(weights.items())
+                for k, v in weights.items()
             }
             return processed_weights
 
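With head_dim cached, the number of KV heads is now derived per module from the weight itself (rows // head_dim), which is what lets heterogeneous-GQA checkpoints such as NemotronNAS load correctly when TP > 1. The _duplicate_kv helper is unchanged and not shown in this diff; the sketch below only illustrates the usual replication scheme it implies, where KV heads are repeated until they can be sharded evenly across tensor-parallel ranks (function name and exact layout are assumptions, not the library implementation):

import torch

# Sketch of KV-head replication for tensor parallelism (illustrative; the real
# _duplicate_kv implementation is not part of this diff). Assumes k/v projection
# weights laid out as [num_kv_heads * head_dim, hidden_size], as noted in the
# comment added in the hunk above.
def duplicate_kv_for_tp(weight: torch.Tensor, head_dim: int,
                        tp_size: int) -> torch.Tensor:
    num_kv_heads = weight.shape[0] // head_dim
    if tp_size <= num_kv_heads:
        return weight  # enough KV heads to shard one or more per rank
    reps = tp_size // num_kv_heads  # copies needed so every rank owns a KV head
    heads = weight.reshape(num_kv_heads, head_dim, -1)
    return heads.repeat_interleave(reps, dim=0).reshape(-1, weight.shape[-1])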

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
@@ -400,7 +400,6 @@ full:GH200/disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_l
 accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5375620)
 test_e2e.py::test_ptp_quickstart_advanced[Mixtral-8x7B-NVFP4-nvfp4-quantized/Mixtral-8x7B-Instruct-v0.1] SKIP (https://nvbugs/5377465)
 test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-70B-FP8-llama-3.1-model/Llama-3.1-70B-Instruct-FP8] SKIP (https://nvbugs/5377465)
-accuracy/test_llm_api_pytorch.py::TestNemotronNas::test_auto_dtype_tp8 SKIP (https://nvbugs/5380101)
 test_e2e.py::test_ptp_quickstart_advanced_8gpus[Llama3.1-405B-FP8-llama-3.1-model/Llama-3.1-405B-Instruct-FP8] SKIP (https://nvbugs/5380570)
 test_e2e.py::test_ptp_quickstart_advanced_8gpus[Nemotron-Ultra-253B-nemotron-nas/Llama-3_1-Nemotron-Ultra-253B-v1] SKIP (https://nvbugs/5380570)
 examples/test_multimodal.py::test_llm_fp8_multimodal_general[fp8-fp8-cnn_dailymail-Qwen2-VL-7B-Instruct-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False] SKIP (https://nvbugs/5385987)
