
Commit abf1402

First draft
Signed-off-by: Michal Guzek <[email protected]>
1 parent: 55f4f2d

3 files changed: +101 −2 lines

tensorrt_llm/_torch/model_config.py

Lines changed: 14 additions & 1 deletion
@@ -442,7 +442,20 @@ def get_bindings_model_config(self,
 
         mlp_hidden_size = None
         if self.pretrained_config.intermediate_size is not None:
-            mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size
+            if isinstance(self.pretrained_config.intermediate_size,
+                          (list, tuple)):
+                # Per-layer MLP dimensions (e.g., Nemotron-NAS, variable MLP models)
+                mlp_hidden_size_per_layer = [
+                    intermediate_size // self.mapping.tp_size
+                    for intermediate_size in
+                    self.pretrained_config.intermediate_size
+                ]
+                model_config_cpp.mlp_hidden_size_per_layer = mlp_hidden_size_per_layer
+                # For LoRA compatibility, use the maximum MLP dimension
+                mlp_hidden_size = max(mlp_hidden_size_per_layer)
+            else:
+                # Uniform MLP dimensions across all layers
+                mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size
         else:
             # TODO: once tensorrt_llm._torch.AutoConfig is implemented, the following logic
             # should be moved to tensorrt_llm._torch.AutoConfig of the relevant modeling_xxx file
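
Note: the branch added above normalizes intermediate_size into TP-sharded per-layer sizes and keeps the widest one for LoRA sizing. A minimal standalone sketch of that logic; the helper name and the sample sizes below are illustrative only and not part of this commit:

def shard_mlp_hidden_sizes(intermediate_size, tp_size):
    # Accepts either a single int (uniform MLP width) or a per-layer
    # list/tuple (e.g. Nemotron-NAS); returns (per_layer_sizes, lora_size).
    if isinstance(intermediate_size, (list, tuple)):
        per_layer = [size // tp_size for size in intermediate_size]
        # LoRA sizing falls back to the widest layer.
        return per_layer, max(per_layer)
    return None, intermediate_size // tp_size

# Example: a 3-layer model with variable MLP widths and tp_size=2.
per_layer, lora_size = shard_mlp_hidden_sizes([16384, 8192, 16384], 2)
assert per_layer == [8192, 4096, 8192] and lora_size == 8192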

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 13 additions & 1 deletion
@@ -467,10 +467,22 @@ def create_py_executor_instance(
         # all layers have the same number of KV heads
         num_kv_attention_heads = num_kv_attention_heads_per_layer[0]
 
+        mlp_hidden_size_per_layer = model_binding_config.mlp_hidden_size_per_layer
+        if mlp_hidden_size_per_layer and max(mlp_hidden_size_per_layer) != min(
+                mlp_hidden_size_per_layer):
+            logger.warning(
+                "Per-layer MLP dimensions are not supported for LoRA; using the maximum MLP hidden size across layers"
+            )
+            mlp_hidden_size = max(mlp_hidden_size_per_layer)
+        else:
+            # all layers have the same MLP hidden size
+            mlp_hidden_size = mlp_hidden_size_per_layer[0]
+
+        # Use the resolved mlp_hidden_size when creating the LoRA modules below.
         lora_modules = LoraModule.create_lora_modules(
             lora_module_names=lora_config.lora_target_modules,
             hidden_size=model_binding_config.hidden_size,
-            mlp_hidden_size=model_binding_config.mlp_hidden_size,
+            mlp_hidden_size=mlp_hidden_size,
             num_attention_heads=model_binding_config.num_heads,
             num_kv_attention_heads=num_kv_attention_heads,
             attention_head_size=model_binding_config.head_size,
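
Note: the sizing rule introduced above can be read in isolation: when layers disagree, warn and size the LoRA modules for the widest MLP. A small sketch under the assumption that the per-layer list is non-empty; the helper name is hypothetical:

def resolve_lora_mlp_hidden_size(mlp_hidden_size_per_layer):
    # Mirrors the fallback above; assumes a non-empty list as produced by
    # get_bindings_model_config().
    if max(mlp_hidden_size_per_layer) != min(mlp_hidden_size_per_layer):
        # Heterogeneous layers: LoRA modules are sized for the largest MLP.
        return max(mlp_hidden_size_per_layer)
    # Uniform layers: every entry is identical, take the first.
    return mlp_hidden_size_per_layer[0]

assert resolve_lora_mlp_hidden_size([8192, 4096, 8192]) == 8192
assert resolve_lora_mlp_hidden_size([4096, 4096]) == 4096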

tests/integration/defs/examples/test_nemotron_nas.py

Lines changed: 74 additions & 0 deletions
@@ -1,10 +1,16 @@
 from pathlib import Path
 
+import defs.ci_profiler
 import pytest
 from defs.common import convert_weights, venv_check_call, venv_mpi_check_call
 from defs.conftest import get_device_memory, get_sm_version
 from defs.trt_test_alternative import check_call
 
+from tensorrt_llm import LLM
+from tensorrt_llm.executor.request import LoRARequest
+from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.sampling_params import SamplingParams
+
 # skip trt flow cases on post-Blackwell-Ultra
 if get_sm_version() >= 103:
     pytest.skip(
@@ -122,3 +128,71 @@ def test_nemotron_nas_summary_2gpu(nemotron_nas_example_root, llm_venv,
     ]
 
     venv_mpi_check_call(llm_venv, mpi_cmd, summary_cmd)
+
+
+@pytest.mark.skip_less_device(4)
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.parametrize("nemotron_nas_model_root", [
+    "Llama-3_3-Nemotron-Super-49B-v1",
+],
+                         indirect=True)
+def test_nemotron_super_49b_real_lora_torch(nemotron_nas_example_root, llm_venv,
+                                            nemotron_nas_model_root,
+                                            llm_datasets_root, llm_rouge_root,
+                                            engine_dir, cmodel_dir):
+    """Run Nemotron Super 49B with real LoRA adapters using the LLM-API Torch backend."""
+
+    print("Testing Nemotron Super 49B with real LoRA adapters...")
+
+    lora_adapter_path = "/code/tensorrt_llm/llama-3.3-nemotron-super-49b-v1/llama-3.3-nemotron-super-49b-v1_vlora-1a2cb80-v2"
+    print(f"Using real LoRA from: {lora_adapter_path}")
+
+    defs.ci_profiler.start("test_nemotron_real_lora_torch")
+
+    lora_config = LoraConfig(
+        lora_dir=[lora_adapter_path],
+        max_lora_rank=32,  # From adapter_config.json: "r": 32
+        max_loras=1,
+        max_cpu_loras=1,
+    )
+
+    with LLM(model=nemotron_nas_model_root,
+             lora_config=lora_config,
+             tensor_parallel_size=4,
+             dtype="bfloat16",
+             max_batch_size=2,
+             max_input_len=512,
+             max_seq_len=1024,
+             max_beam_width=1) as llm:
+
+        prompts = [
+            "What is the capital of France?",
+            "Explain quantum computing in simple terms."
+        ]
+
+        sampling_params = SamplingParams(max_tokens=50,
+                                         temperature=0.7,
+                                         top_p=0.9)
+
+        lora_request = [LoRARequest("nemotron-lora", 0, lora_adapter_path)]
+
+        print("Running inference with real LoRA adapter...")
+        outputs = llm.generate(prompts,
+                               sampling_params,
+                               lora_request=lora_request)
+
+        for i, output in enumerate(outputs):
+            print(f"Prompt {i+1}: {prompts[i]}")
+            print(f"Response {i+1}: {output.outputs[0].text}")
+            print("-" * 50)
+
+        assert len(outputs) == 2
+        assert len(outputs[0].outputs) > 0
+        assert len(outputs[1].outputs) > 0
+        assert len(outputs[0].outputs[0].text) > 0
+        assert len(outputs[1].outputs[0].text) > 0
+
+    defs.ci_profiler.stop("test_nemotron_real_lora_torch")
+    print(
+        f"test_nemotron_real_lora_torch: {defs.ci_profiler.elapsed_time_in_sec('test_nemotron_real_lora_torch')} sec"
+    )
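
Note: outside of the pytest harness, the LLM-API flow that the new test exercises can be sketched as follows; the model and adapter paths are placeholders, not values from this commit:

from tensorrt_llm import LLM
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.sampling_params import SamplingParams

# Placeholder paths: point these at a local Nemotron-NAS checkpoint and a
# LoRA adapter trained for it.
model_dir = "/path/to/Llama-3_3-Nemotron-Super-49B-v1"
adapter_dir = "/path/to/lora_adapter"

lora_config = LoraConfig(lora_dir=[adapter_dir],
                         max_lora_rank=32,
                         max_loras=1,
                         max_cpu_loras=1)

with LLM(model=model_dir,
         lora_config=lora_config,
         tensor_parallel_size=4,
         dtype="bfloat16") as llm:
    outputs = llm.generate(
        ["What is the capital of France?"],
        SamplingParams(max_tokens=50),
        lora_request=[LoRARequest("nemotron-lora", 0, adapter_dir)])
    print(outputs[0].outputs[0].text)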
