10 changes: 4 additions & 6 deletions tensorrt_llm/_torch/models/modeling_phi4mm.py
@@ -243,23 +243,21 @@ def forward(
     @staticmethod
     def lora_config(model_dir: str):
         _lora_config = LoraConfig(
             lora_dir=[
                 f"{model_dir}/vision-lora",
                 f"{model_dir}/speech-lora",
             ],
             lora_target_modules=[
                 "attn_qkv",
                 "attn_dense",
-                "mlp_h_to_4h",
+                "mlp_gate_up",
                 "mlp_4h_to_h",
             ],
             trtllm_modules_to_hf_modules={
                 "attn_qkv": "qkv_proj",
                 "attn_dense": "o_proj",
-                "mlp_h_to_4h": "gate_up_proj",
+                "mlp_gate_up": "gate_up_proj",
                 "mlp_4h_to_h": "down_proj",
             },
             max_lora_rank=320,  # Max rank for Phi4MM.
+            swap_gate_up_proj_lora_b_weight=
+            False,  # Disable swap gate_up_proj.lora_B.weight for Phi4MM.
         )
         return _lora_config
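Note on the rename above: per the mapping in this hunk, "mlp_gate_up" targets HF's fused gate_up_proj, whose output stacks the gate half on top of the up half. A minimal sketch of that fused layout, with illustrative sizes and names (not Phi-4-MM's real dimensions):

import torch

# Illustrative sizes only; not Phi-4-MM's real dimensions.
hidden, inter = 8, 16
gate_up_proj = torch.nn.Linear(hidden, 2 * inter, bias=False)  # fused gate+up

x = torch.randn(2, hidden)
gate, up = gate_up_proj(x).chunk(2, dim=-1)  # first half gates, second half projects up
y = torch.nn.functional.silu(gate) * up  # typical gated-MLP activation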

3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/_util.py
@@ -509,7 +509,8 @@ def create_py_executor_instance(
         resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager
         model_engine.set_lora_model_config(
             lora_config.lora_target_modules,
-            lora_config.trtllm_modules_to_hf_modules)
+            lora_config.trtllm_modules_to_hf_modules,
+            lora_config.swap_gate_up_proj_lora_b_weight)
 
     max_num_sequences = executor_config.max_batch_size * mapping.pp_size

9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -438,13 +438,16 @@ def __init__(
         else:
             self.cache_indirection_attention = None
 
-    def set_lora_model_config(self, lora_target_modules: list[str],
-                              trtllm_modules_to_hf_modules: dict[str, str]):
+    def set_lora_model_config(self,
+                              lora_target_modules: list[str],
+                              trtllm_modules_to_hf_modules: dict[str, str],
+                              swap_gate_up_proj_lora_b_weight: bool = True):
         self.lora_model_config = LoraModelConfig(
             lora_target_modules=lora_target_modules,
             trtllm_modules_to_hf_modules=trtllm_modules_to_hf_modules,
             hidden_size=self.model.config.hidden_size,
-            dtype=torch_dtype_to_str(self.model.config.torch_dtype))
+            dtype=torch_dtype_to_str(self.model.config.torch_dtype),
+            swap_gate_up_proj_lora_b_weight=swap_gate_up_proj_lora_b_weight)
 
     @property
     def use_mrope(self):
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -1206,7 +1206,8 @@ def __init__(self,
         self._lora_model_config = LoraModelConfig(
             lora_config.lora_target_modules,
             lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
-            binding_to_str_dtype(model_config.data_type))
+            binding_to_str_dtype(model_config.data_type),
+            lora_config.swap_gate_up_proj_lora_b_weight)
         self._lora_manager = LoraManager()
 
     def add_request_peft(self, request: LlmRequest):
23 changes: 13 additions & 10 deletions tensorrt_llm/lora_manager.py
@@ -241,6 +241,7 @@ class LoraConfig(DictConversion):
     trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
     max_loras: int | None = None
     max_cpu_loras: int | None = None
+    swap_gate_up_proj_lora_b_weight: bool = True
 
     def __post_init__(self):
         assert self.lora_ckpt_source in ["hf", "nemo"], (
@@ -258,6 +259,7 @@ class LoraModelConfig:
     trtllm_modules_to_hf_modules: dict[str, str]
     hidden_size: int
     dtype: str
+    swap_gate_up_proj_lora_b_weight: bool = True
 
 
 class HfLoraLoader:
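A hedged usage sketch for the new LoraConfig field; the module names follow this PR, while the adapter path is a placeholder:

from tensorrt_llm.lora_manager import LoraConfig

# Opt out of the gate_up_proj half-swap for adapters whose lora_B rows
# are already in the expected order (the path is a placeholder).
cfg = LoraConfig(
    lora_dir=["/path/to/speech-lora"],
    lora_target_modules=["attn_qkv", "mlp_gate_up"],
    trtllm_modules_to_hf_modules={
        "attn_qkv": "qkv_proj",
        "mlp_gate_up": "gate_up_proj",
    },
    max_lora_rank=320,
    swap_gate_up_proj_lora_b_weight=False,  # new field; defaults to True
)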
@@ -1026,16 +1028,17 @@ def load_from_hf(
     )
     hf_modules = set(hf_modules_to_trtllm_modules.keys())
 
-    def preprocess_lora_weights(lora_model):
+    def preprocess_lora_weights(lora_model, model_config):
         # Swap weights of gate_up_proj
-        for key, value in lora_model.items():
-            if "gate_up_proj.lora_B.weight" in key:
-                original_weights = value.contiguous().clone()
-                half_split = original_weights.shape[0] // 2
-                first_half = original_weights[:half_split, :]
-                second_half = original_weights[half_split:, :]
-                value = torch.cat((second_half, first_half), dim=0)
-                lora_model[key] = value
+        if getattr(model_config, "swap_gate_up_proj_lora_b_weight", True):
+            for key, value in lora_model.items():
+                if "gate_up_proj.lora_B.weight" in key:
+                    original_weights = value.contiguous().clone()
+                    half_split = original_weights.shape[0] // 2
+                    first_half = original_weights[:half_split, :]
+                    second_half = original_weights[half_split:, :]
+                    value = torch.cat((second_half, first_half), dim=0)
+                    lora_model[key] = value
         return lora_model
 
     def load_from_model_dir(uid, model_dir, hf_config):
@@ -1047,7 +1050,7 @@ def load_from_model_dir(uid, model_dir, hf_config):
         lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
         if lora_model is None:
             raise ValueError(f"Failed to load adapter_model from {model_dir}")
-        lora_model = preprocess_lora_weights(lora_model)
+        lora_model = preprocess_lora_weights(lora_model, model_config)
         all_weights = get_all_hf_lora_weights(lora_model, hf_modules, component)
         rank = int(hf_config["r"])
         rs_lora = bool(hf_config.get("use_rslora", False))
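For reference, the transform that preprocess_lora_weights now gates on the flag, replayed on a dummy tensor (shapes are illustrative, not from any real adapter):

import torch

# A rank-2 lora_B for a fused gate_up_proj: rows 0-1 stand in for the
# "gate" half, rows 2-3 for the "up" half (illustrative shapes only).
lora_b = torch.arange(8.0).reshape(4, 2)
half_split = lora_b.shape[0] // 2
swapped = torch.cat((lora_b[half_split:, :], lora_b[:half_split, :]), dim=0)

# The two halves have traded places; with swap_gate_up_proj_lora_b_weight
# set to False, lora_b would be left untouched instead.
assert torch.equal(swapped[:half_split], lora_b[half_split:])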
6 changes: 4 additions & 2 deletions tests/integration/defs/perf/pytorch_model_config.py
@@ -198,15 +198,17 @@ def get_model_yaml_config(model_label: str,
         }
         if 'phi_4_multimodal_instruct' in model_label:
             lora_config['lora_config']['lora_target_modules'] = [
-                "attn_qkv", "attn_dense", "mlp_h_to_4h", "mlp_4h_to_h"
+                "attn_qkv", "attn_dense", "mlp_gate_up", "mlp_4h_to_h"
             ]
             lora_config['lora_config']['trtllm_modules_to_hf_modules'] = {
                 "attn_qkv": "qkv_proj",
                 "attn_dense": "o_proj",
-                "mlp_h_to_4h": "gate_up_proj",
+                "mlp_gate_up": "gate_up_proj",
                 "mlp_4h_to_h": "down_proj"
             }
             lora_config['lora_config']['max_lora_rank'] = 320
+            lora_config['lora_config'][
+                'swap_gate_up_proj_lora_b_weight'] = False
         base_config.update(lora_config)
 
     kv_cache_config = base_config.get('kv_cache_config', {})
6 changes: 3 additions & 3 deletions tests/integration/defs/test_e2e.py
@@ -2486,15 +2486,15 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
     }
     expected_keywords = {
         "image": [
-            ["image", "depicts", "mountain", "half", "rock"],
-            ["road", "car", "lane", "traffic", "bus"],
+            ["object", "mountain", "weather", "clear", "clouds"],
+            ["traffic", "road", "vehicles", "cars", "bus"],
         ],
         "audio": [
             ["what", "is", "the", "traffic", "sign", "in", "image"],
             ["what", "is", "shown", "in", "this", "image"],
         ],
         "image_audio": [
-            ["image", "depicts", "Grand", "rock", "scene"],
+            ["image", "depicts", "scenic", "famous", "landmark"],
         ],
     }
