
Commit 0c42f54

Bugfix/fix nemotron nas lora support (#6380)
Signed-off-by: Shahar Mor <[email protected]>
1 parent: baece56 · commit: 0c42f54

File tree

tensorrt_llm/_torch/model_config.py
tensorrt_llm/_torch/pyexecutor/_util.py
tests/unittest/llmapi/test_llm_pytorch.py

3 files changed: +18 -54 lines changed

tensorrt_llm/_torch/model_config.py

Lines changed: 10 additions & 43 deletions

@@ -299,48 +299,6 @@ def get_bindings_model_config(self,
         num_heads = self.pretrained_config.num_attention_heads // (
             self.mapping.tp_size * self.mapping.cp_size)

-        # Handle both uniform and per-layer KV heads
-        num_kv_heads_per_layer = getattr(self.pretrained_config,
-                                         'num_kv_heads_per_layer', None)
-        if num_kv_heads_per_layer is not None:
-            # For models with per-layer KV heads, like nemotron-nas
-            kv_heads_per_layer_raw = num_kv_heads_per_layer
-            use_per_layer_kv_heads = True
-        else:
-            # Check if num_key_value_heads is a list (per-layer) or scalar (uniform)
-            num_kv_heads_raw = getattr(self.pretrained_config,
-                                       'num_key_value_heads', None)
-
-            if num_kv_heads_raw is not None and isinstance(
-                    num_kv_heads_raw, list):
-                # num_key_value_heads is a list - treat as per-layer KV heads
-                kv_heads_per_layer_raw = num_kv_heads_raw
-                use_per_layer_kv_heads = True
-            else:
-                # num_key_value_heads is scalar or None - treat as uniform KV heads
-                if num_kv_heads_raw is None:
-                    # For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads
-                    num_kv_heads_raw = getattr(
-                        self.pretrained_config, 'num_query_groups',
-                        self.pretrained_config.num_attention_heads)
-
-                num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size *
-                                                    self.mapping.cp_size)
-                use_per_layer_kv_heads = False
-
-        if use_per_layer_kv_heads:
-            # TRT-LLM LoRA requires uniform KV heads across layers
-            if self.lora_config is not None and len(
-                    set(kv_heads_per_layer_raw)) > 1:
-                raise ValueError(
-                    f"TRT-LLM LoRA requires uniform KV heads across layers, "
-                    f"got: {kv_heads_per_layer_raw}")
-            # Apply TP/CP scaling to each layer
-            num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
-                for kv_heads in kv_heads_per_layer_raw
-            ]
-
         hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size

         model_config_cpp = ModelConfigCpp(
@@ -361,9 +319,18 @@ def get_bindings_model_config(self,
         else:
             model_config_cpp.tokens_per_block = tokens_per_block

-        if use_per_layer_kv_heads:
+        num_key_value_heads = getattr(self.pretrained_config,
+                                      "num_key_value_heads", num_heads)
+        if isinstance(num_key_value_heads, (list, tuple)):
+            # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
+            num_kv_heads_per_layer = [
+                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                for kv_heads in num_key_value_heads
+            ]
             model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
         else:
+            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
+                                                   self.mapping.cp_size)
             model_config_cpp.set_num_kv_heads(num_kv_heads)

         mlp_hidden_size = None
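
To see the new selection logic in isolation, here is a minimal standalone sketch of the KV-head derivation the updated get_bindings_model_config performs. It is not code from the commit; the helper name derive_kv_heads, the SimpleNamespace stand-in for pretrained_config, and all head counts and tp/cp sizes below are made up for illustration.

from types import SimpleNamespace

def derive_kv_heads(pretrained_config, num_heads, tp_size, cp_size):
    # Mirrors the updated logic: a list/tuple of num_key_value_heads means
    # per-layer (variable GQA) KV heads; anything else is a uniform count,
    # defaulting to num_heads when the attribute is absent.
    num_key_value_heads = getattr(pretrained_config, "num_key_value_heads",
                                  num_heads)
    if isinstance(num_key_value_heads, (list, tuple)):
        return [kv // (tp_size * cp_size) for kv in num_key_value_heads]
    return num_key_value_heads // (tp_size * cp_size)

# Per-layer KV heads, as in Nemotron-NAS style configs (illustrative values).
nas_cfg = SimpleNamespace(num_key_value_heads=[8, 8, 4, 4])
print(derive_kv_heads(nas_cfg, num_heads=32, tp_size=2, cp_size=1))  # [4, 4, 2, 2]

# Uniform KV heads: the attribute is a plain int.
gqa_cfg = SimpleNamespace(num_key_value_heads=8)
print(derive_kv_heads(gqa_cfg, num_heads=32, tp_size=2, cp_size=1))  # 4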

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 8 additions & 10 deletions

@@ -451,18 +451,16 @@ def create_py_executor_instance(

     num_experts = _try_infer_num_experts(model_engine.model.model_config)

-    num_attn_layers = model_binding_config.num_attention_layers()
-    per_layer_kv_heads = [
-        model_binding_config.num_kv_heads(i) for i in range(num_attn_layers)
-    ]
-    num_kv_attention_heads = max(per_layer_kv_heads)
-    if len(set(per_layer_kv_heads)) > 1:
-        # NOTE: This code-path is currently untested and not validated. Can fail!
-        # This support is tracked in TRTLLM-6561
+    num_kv_attention_heads_per_layer = model_binding_config.num_kv_heads_per_layer
+    if max(num_kv_attention_heads_per_layer) != min(
+            num_kv_attention_heads_per_layer):
         logger.warning(
-            f"Non-uniform KV heads per layer detected, using max ({num_kv_attention_heads}) for LoRA. "
-            "This code-path is currently untested and not validated. May fail!"
+            "Defining LORA with per-layer KV heads is not supported for LORA, using the max number of KV heads per layer"
         )
+        num_kv_attention_heads = max(num_kv_attention_heads_per_layer)
+    else:
+        # all layers have the same number of KV heads
+        num_kv_attention_heads = num_kv_attention_heads_per_layer[0]

     lora_modules = LoraModule.create_lora_modules(
         lora_module_names=lora_config.lora_target_modules,
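
The executor-side fallback can likewise be sketched on its own. This is not code from the diff; the helper name pick_lora_kv_heads and the per-layer lists passed in below are invented, whereas the real list comes from model_binding_config.num_kv_heads_per_layer.

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

def pick_lora_kv_heads(num_kv_heads_per_layer):
    # LoRA module creation needs a single KV head count. When layers disagree
    # (variable GQA), warn and fall back to the maximum; otherwise every
    # entry is the same, so take the first.
    if max(num_kv_heads_per_layer) != min(num_kv_heads_per_layer):
        logger.warning(
            "Per-layer KV heads are not supported for LoRA; using the max "
            "number of KV heads per layer")
        return max(num_kv_heads_per_layer)
    return num_kv_heads_per_layer[0]

print(pick_lora_kv_heads([8, 8, 4, 2]))  # 8, with a warning
print(pick_lora_kv_heads([4, 4, 4]))     # 4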

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 0 additions & 1 deletion

@@ -350,7 +350,6 @@ def test_llama_7b_lora_config_overrides_peft_cache_config():

 # TODO smor: currently Nemotron-Super-49B-v1 with LoRA memory consumption is overly high
 # https://jirasw.nvidia.com/browse/TRTLLM-5045
-@pytest.mark.skip(reason="https://nvbugs/5401210")
 @skip_gpu_memory_less_than_138gb
 def test_nemotron_nas_lora() -> None:
     lora_config = LoraConfig(lora_dir=[
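
The re-enabled test is not shown in full here. For orientation only, the sketch below shows how a LoRA run is typically wired up through the PyTorch LLM API; the import paths, argument names, rank, and all file paths are assumptions for illustration and may not match what test_nemotron_nas_lora actually does.

from tensorrt_llm import LLM
from tensorrt_llm.executor import LoRARequest     # assumed import path
from tensorrt_llm.lora_manager import LoraConfig  # assumed import path

# Placeholder paths; substitute a real Nemotron-NAS checkpoint and adapter.
adapter_dir = "/path/to/nemotron-nas-lora-adapter"
lora_config = LoraConfig(lora_dir=[adapter_dir], max_lora_rank=64)

llm = LLM(model="/path/to/nemotron-nas-checkpoint", lora_config=lora_config)
outputs = llm.generate(
    ["The future of AI is"],
    lora_request=[LoRARequest("nemotron-nas-adapter", 1, adapter_dir)])
print(outputs[0].outputs[0].text)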
