Commit f4adb25

[TRTLLM-6825][fix] Update lora for phi4-mm (#6817)
Signed-off-by: Wanli Jiang <[email protected]>
1 parent: 1e5a6be

8 files changed, +134 -25 lines changed

tensorrt_llm/_torch/models/modeling_phi4mm.py

Lines changed: 4 additions & 6 deletions
@@ -243,23 +243,21 @@ def forward(
     @staticmethod
     def lora_config(model_dir: str):
         _lora_config = LoraConfig(
-            lora_dir=[
-                f"{model_dir}/vision-lora",
-                f"{model_dir}/speech-lora",
-            ],
             lora_target_modules=[
                 "attn_qkv",
                 "attn_dense",
-                "mlp_h_to_4h",
+                "mlp_gate_up",
                 "mlp_4h_to_h",
             ],
             trtllm_modules_to_hf_modules={
                 "attn_qkv": "qkv_proj",
                 "attn_dense": "o_proj",
-                "mlp_h_to_4h": "gate_up_proj",
+                "mlp_gate_up": "gate_up_proj",
                 "mlp_4h_to_h": "down_proj",
             },
             max_lora_rank=320,  # Max rank for Phi4MM.
+            swap_gate_up_proj_lora_b_weight=
+            False,  # Disable swap gate_up_proj.lora_B.weight for Phi4MM.
         )
         return _lora_config

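In short: the Phi4-MM config drops the hard-coded lora_dir entries, retargets the fused gate/up projection as mlp_gate_up instead of mlp_h_to_4h, and opts out of the gate_up_proj.lora_B weight swap. A minimal sketch of the equivalent config object, assuming LoraConfig is imported from the new tensorrt_llm/lora_helper.py added later in this commit:

from tensorrt_llm.lora_helper import LoraConfig

# Sketch of what the updated lora_config() above builds (not the verbatim method).
phi4mm_lora_config = LoraConfig(
    lora_target_modules=["attn_qkv", "attn_dense", "mlp_gate_up", "mlp_4h_to_h"],
    trtllm_modules_to_hf_modules={
        "attn_qkv": "qkv_proj",
        "attn_dense": "o_proj",
        "mlp_gate_up": "gate_up_proj",
        "mlp_4h_to_h": "down_proj",
    },
    max_lora_rank=320,  # Max rank for Phi4MM.
    swap_gate_up_proj_lora_b_weight=False,  # Skip the lora_B half-swap for Phi4MM.
)
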
tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 2 additions & 1 deletion
@@ -509,7 +509,8 @@ def create_py_executor_instance(
         resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager
         model_engine.set_lora_model_config(
             lora_config.lora_target_modules,
-            lora_config.trtllm_modules_to_hf_modules)
+            lora_config.trtllm_modules_to_hf_modules,
+            lora_config.swap_gate_up_proj_lora_b_weight)

         max_num_sequences = executor_config.max_batch_size * mapping.pp_size

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 4 additions & 2 deletions
@@ -438,12 +438,14 @@ def __init__(
         self.cache_indirection_attention = None

     def set_lora_model_config(self, lora_target_modules: list[str],
-                              trtllm_modules_to_hf_modules: dict[str, str]):
+                              trtllm_modules_to_hf_modules: dict[str, str],
+                              swap_gate_up_proj_lora_b_weight: bool = True):
         self.lora_model_config = LoraModelConfig(
             lora_target_modules=lora_target_modules,
             trtllm_modules_to_hf_modules=trtllm_modules_to_hf_modules,
             hidden_size=self.model.config.hidden_size,
-            dtype=torch_dtype_to_str(self.model.config.torch_dtype))
+            dtype=torch_dtype_to_str(self.model.config.torch_dtype),
+            swap_gate_up_proj_lora_b_weight=swap_gate_up_proj_lora_b_weight)

     @property
     def use_mrope(self):

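The new argument defaults to True, so existing call sites that omit it keep the current swap behavior; only the Phi4-MM path passes False. A hypothetical pair of calls to illustrate (engine, modules and mapping are placeholder names):

# Placeholder names; only the third argument differs between the two calls.
engine.set_lora_model_config(modules, mapping)         # default: swap enabled
engine.set_lora_model_config(modules, mapping, False)  # Phi4-MM path: swap disabled
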
tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 2 additions & 1 deletion
@@ -1206,7 +1206,8 @@ def __init__(self,
         self._lora_model_config = LoraModelConfig(
             lora_config.lora_target_modules,
             lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
-            binding_to_str_dtype(model_config.data_type))
+            binding_to_str_dtype(model_config.data_type),
+            lora_config.swap_gate_up_proj_lora_b_weight)
         self._lora_manager = LoraManager()

     def add_request_peft(self, request: LlmRequest):

tensorrt_llm/lora_helper.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional
+
+from ._utils import DictConversion
+
+
+def get_missing_qkv_modules_from_lora_modules(
+        lora_target_modules: List[str]) -> List[str]:
+    """Get missing QKV modules from LoRA target modules.
+
+    In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
+    all disabled at the same time. However, some lora checkpoints (e.g. BART) only contain two of them,
+    so we use zero tensor to fill the missing ones.
+    """
+    missing_qkv_modules = []
+    if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
+        for lora_module in ["attn_q", "attn_k", "attn_v"]:
+            if lora_module not in lora_target_modules:
+                missing_qkv_modules.append(lora_module)
+    if any(x in lora_target_modules
+           for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
+        for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
+            if lora_module not in lora_target_modules:
+                missing_qkv_modules.append(lora_module)
+    return missing_qkv_modules
+
+
+def get_default_trtllm_modules_to_hf_modules():
+    """Get default mapping from TensorRT-LLM module names to HuggingFace module names."""
+    return {
+        "attn_q": "q_proj",
+        "attn_k": "k_proj",
+        "attn_v": "v_proj",
+        "attn_dense": "o_proj",
+        "mlp_h_to_4h": "gate_proj",
+        "mlp_4h_to_h": "down_proj",
+        "mlp_gate": "up_proj",
+        "mlp_gate_up": "gate_up_proj",
+        "moe_h_to_4h": "w1",
+        "moe_4h_to_h": "w2",
+        "moe_gate": "w3",
+        "moe_router": "gate",
+    }
+
+
+def use_lora(
+    model,
+    lora_config: "LoraConfig",
+    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
+):
+    """Use LoRA with the given model and configuration.
+
+    This function is a wrapper that delegates to the appropriate loading function
+    based on the LoRA checkpoint source.
+    """
+    if lora_config.lora_ckpt_source == "nemo":
+        from .lora_manager import load_nemo_lora
+        load_nemo_lora(model, lora_config)
+    elif lora_config.lora_ckpt_source == "hf":
+        from .lora_manager import load_hf_lora
+        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
+    else:
+        raise ValueError(
+            f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
+
+
+@dataclass
+class LoraConfig(DictConversion):
+    lora_dir: List[str] = field(default_factory=list)
+    lora_ckpt_source: str = "hf"
+    max_lora_rank: int = 64
+    lora_target_modules: List[str] = field(default_factory=list)
+    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
+    max_loras: Optional[int] = None
+    max_cpu_loras: Optional[int] = None
+    swap_gate_up_proj_lora_b_weight: bool = True
+
+    def __post_init__(self):
+        assert self.lora_ckpt_source in [
+            "hf", "nemo"
+        ], (f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
+            )
+
+    @property
+    def missing_qkv_modules(self) -> List[str]:
+        return get_missing_qkv_modules_from_lora_modules(
+            self.lora_target_modules)

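A short usage sketch of the new helpers, with illustrative inputs (the commented results follow from the code above):

from tensorrt_llm.lora_helper import (LoraConfig,
                                      get_missing_qkv_modules_from_lora_modules)

# A BART-style adapter that only ships q and v: attn_k is reported as missing
# so the loader can fill it with a zero tensor.
print(get_missing_qkv_modules_from_lora_modules(["attn_q", "attn_v"]))  # ['attn_k']

# The same list is exposed as a property, and the new swap flag defaults to
# True (the pre-existing behavior) unless a model opts out.
cfg = LoraConfig(lora_target_modules=["attn_q", "attn_v"])
print(cfg.missing_qkv_modules)              # ['attn_k']
print(cfg.swap_gate_up_proj_lora_b_weight)  # True
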
tensorrt_llm/lora_manager.py

Lines changed: 13 additions & 10 deletions
@@ -241,6 +241,7 @@ class LoraConfig(DictConversion):
     trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
     max_loras: int | None = None
     max_cpu_loras: int | None = None
+    swap_gate_up_proj_lora_b_weight: bool = True

     def __post_init__(self):
         assert self.lora_ckpt_source in ["hf", "nemo"], (

@@ -258,6 +259,7 @@ class LoraModelConfig:
     trtllm_modules_to_hf_modules: dict[str, str]
     hidden_size: int
     dtype: str
+    swap_gate_up_proj_lora_b_weight: bool = True


 class HfLoraLoader:

@@ -1026,16 +1028,17 @@ def load_from_hf(
     )
     hf_modules = set(hf_modules_to_trtllm_modules.keys())

-    def preprocess_lora_weights(lora_model):
+    def preprocess_lora_weights(lora_model, model_config):
         # Swap weights of gate_up_proj
-        for key, value in lora_model.items():
-            if "gate_up_proj.lora_B.weight" in key:
-                original_weights = value.contiguous().clone()
-                half_split = original_weights.shape[0] // 2
-                first_half = original_weights[:half_split, :]
-                second_half = original_weights[half_split:, :]
-                value = torch.cat((second_half, first_half), dim=0)
-                lora_model[key] = value
+        if getattr(model_config, "swap_gate_up_proj_lora_b_weight", True):
+            for key, value in lora_model.items():
+                if "gate_up_proj.lora_B.weight" in key:
+                    original_weights = value.contiguous().clone()
+                    half_split = original_weights.shape[0] // 2
+                    first_half = original_weights[:half_split, :]
+                    second_half = original_weights[half_split:, :]
+                    value = torch.cat((second_half, first_half), dim=0)
+                    lora_model[key] = value
         return lora_model

     def load_from_model_dir(uid, model_dir, hf_config):

@@ -1047,7 +1050,7 @@ def load_from_model_dir(uid, model_dir, hf_config):
         lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
         if lora_model is None:
             raise ValueError(f"Failed to load adapter_model from {model_dir}")
-        lora_model = preprocess_lora_weights(lora_model)
+        lora_model = preprocess_lora_weights(lora_model, model_config)
         all_weights = get_all_hf_lora_weights(lora_model, hf_modules, component)
         rank = int(hf_config["r"])
         rs_lora = bool(hf_config.get("use_rslora", False))

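The guarded loop above is the behavioral core of the commit: for a fused gate_up_proj adapter, lora_B stacks two halves along dim 0, and the HF loader swaps them (presumably to match the order TensorRT-LLM expects); setting swap_gate_up_proj_lora_b_weight=False, as Phi4-MM now does, leaves the checkpoint order untouched. A self-contained sketch of the swap, with a toy tensor standing in for a real lora_B weight:

import torch

# Toy gate_up_proj.lora_B.weight: rows 0-1 are one half, rows 2-3 the other.
lora_b = torch.arange(8.0).reshape(4, 2)
half_split = lora_b.shape[0] // 2
# Same transform as preprocess_lora_weights: exchange the two halves along dim 0.
swapped = torch.cat((lora_b[half_split:, :], lora_b[:half_split, :]), dim=0)
print(swapped[:half_split].equal(lora_b[half_split:]))  # True
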
tests/integration/defs/perf/pytorch_model_config.py

Lines changed: 4 additions & 2 deletions
@@ -198,15 +198,17 @@ def get_model_yaml_config(model_label: str,
         }
         if 'phi_4_multimodal_instruct' in model_label:
             lora_config['lora_config']['lora_target_modules'] = [
-                "attn_qkv", "attn_dense", "mlp_h_to_4h", "mlp_4h_to_h"
+                "attn_qkv", "attn_dense", "mlp_gate_up", "mlp_4h_to_h"
             ]
             lora_config['lora_config']['trtllm_modules_to_hf_modules'] = {
                 "attn_qkv": "qkv_proj",
                 "attn_dense": "o_proj",
-                "mlp_h_to_4h": "gate_up_proj",
+                "mlp_gate_up": "gate_up_proj",
                 "mlp_4h_to_h": "down_proj"
             }
             lora_config['lora_config']['max_lora_rank'] = 320
+            lora_config['lora_config'][
+                'swap_gate_up_proj_lora_b_weight'] = False
         base_config.update(lora_config)

         kv_cache_config = base_config.get('kv_cache_config', {})

tests/integration/defs/test_e2e.py

Lines changed: 3 additions & 3 deletions
@@ -2486,15 +2486,15 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
     }
     expected_keywords = {
         "image": [
-            ["image", "depicts", "mountain", "half", "rock"],
-            ["road", "car", "lane", "traffic", "bus"],
+            ["object", "mountain", "weather", "clear", "clouds"],
+            ["traffic", "road", "vehicles", "cars", "bus"],
         ],
         "audio": [
             ["what", "is", "the", "traffic", "sign", "in", "image"],
             ["what", "is", "shown", "in", "this", "image"],
         ],
         "image_audio": [
-            ["image", "depicts", "Grand", "rock", "scene"],
+            ["image", "depicts", "scenic", "famous", "landmark"],
         ],
     }
