diff --git a/tensorrt_llm/_torch/models/modeling_phi4mm.py b/tensorrt_llm/_torch/models/modeling_phi4mm.py
index b5ad4f45203..77573920265 100644
--- a/tensorrt_llm/_torch/models/modeling_phi4mm.py
+++ b/tensorrt_llm/_torch/models/modeling_phi4mm.py
@@ -243,23 +243,21 @@ def forward(
     @staticmethod
     def lora_config(model_dir: str):
         _lora_config = LoraConfig(
-            lora_dir=[
-                f"{model_dir}/vision-lora",
-                f"{model_dir}/speech-lora",
-            ],
             lora_target_modules=[
                 "attn_qkv",
                 "attn_dense",
-                "mlp_h_to_4h",
+                "mlp_gate_up",
                 "mlp_4h_to_h",
             ],
             trtllm_modules_to_hf_modules={
                 "attn_qkv": "qkv_proj",
                 "attn_dense": "o_proj",
-                "mlp_h_to_4h": "gate_up_proj",
+                "mlp_gate_up": "gate_up_proj",
                 "mlp_4h_to_h": "down_proj",
             },
             max_lora_rank=320,  # Max rank for Phi4MM.
+            swap_gate_up_proj_lora_b_weight=
+            False,  # Disable swap gate_up_proj.lora_B.weight for Phi4MM.
         )
         return _lora_config
 
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 37f8e0410b0..5f635f47697 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -509,7 +509,8 @@ def create_py_executor_instance(
         resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager
         model_engine.set_lora_model_config(
             lora_config.lora_target_modules,
-            lora_config.trtllm_modules_to_hf_modules)
+            lora_config.trtllm_modules_to_hf_modules,
+            lora_config.swap_gate_up_proj_lora_b_weight)
 
     max_num_sequences = executor_config.max_batch_size * mapping.pp_size
 
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 4eaedf37d92..2bbd97a3821 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -438,13 +438,16 @@ def __init__(
         else:
             self.cache_indirection_attention = None
 
-    def set_lora_model_config(self, lora_target_modules: list[str],
-                              trtllm_modules_to_hf_modules: dict[str, str]):
+    def set_lora_model_config(self,
+                              lora_target_modules: list[str],
+                              trtllm_modules_to_hf_modules: dict[str, str],
+                              swap_gate_up_proj_lora_b_weight: bool = True):
         self.lora_model_config = LoraModelConfig(
             lora_target_modules=lora_target_modules,
             trtllm_modules_to_hf_modules=trtllm_modules_to_hf_modules,
             hidden_size=self.model.config.hidden_size,
-            dtype=torch_dtype_to_str(self.model.config.torch_dtype))
+            dtype=torch_dtype_to_str(self.model.config.torch_dtype),
+            swap_gate_up_proj_lora_b_weight=swap_gate_up_proj_lora_b_weight)
 
     @property
     def use_mrope(self):
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index eb33f8aa5b9..5b1dd3e2091 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -1206,7 +1206,8 @@ def __init__(self,
         self._lora_model_config = LoraModelConfig(
             lora_config.lora_target_modules,
             lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
-            binding_to_str_dtype(model_config.data_type))
+            binding_to_str_dtype(model_config.data_type),
+            lora_config.swap_gate_up_proj_lora_b_weight)
         self._lora_manager = LoraManager()
 
     def add_request_peft(self, request: LlmRequest):
diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py
index f2e32047162..a8d03496796 100644
--- a/tensorrt_llm/lora_manager.py
+++ b/tensorrt_llm/lora_manager.py
@@ -241,6 +241,7 @@ class LoraConfig(DictConversion):
     trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
     max_loras: int | None = None
     max_cpu_loras: int | None = None
+    swap_gate_up_proj_lora_b_weight: bool = True
 
     def __post_init__(self):
         assert self.lora_ckpt_source in ["hf", "nemo"], (
@@ -258,6 +259,7 @@ class LoraModelConfig:
     trtllm_modules_to_hf_modules: dict[str, str]
     hidden_size: int
     dtype: str
+    swap_gate_up_proj_lora_b_weight: bool = True
 
 
 class HfLoraLoader:
@@ -1026,16 +1028,17 @@ def load_from_hf(
         )
         hf_modules = set(hf_modules_to_trtllm_modules.keys())
 
-        def preprocess_lora_weights(lora_model):
+        def preprocess_lora_weights(lora_model, model_config):
             # Swap weights of gate_up_proj
-            for key, value in lora_model.items():
-                if "gate_up_proj.lora_B.weight" in key:
-                    original_weights = value.contiguous().clone()
-                    half_split = original_weights.shape[0] // 2
-                    first_half = original_weights[:half_split, :]
-                    second_half = original_weights[half_split:, :]
-                    value = torch.cat((second_half, first_half), dim=0)
-                    lora_model[key] = value
+            if getattr(model_config, "swap_gate_up_proj_lora_b_weight", True):
+                for key, value in lora_model.items():
+                    if "gate_up_proj.lora_B.weight" in key:
+                        original_weights = value.contiguous().clone()
+                        half_split = original_weights.shape[0] // 2
+                        first_half = original_weights[:half_split, :]
+                        second_half = original_weights[half_split:, :]
+                        value = torch.cat((second_half, first_half), dim=0)
+                        lora_model[key] = value
             return lora_model
 
         def load_from_model_dir(uid, model_dir, hf_config):
@@ -1047,7 +1050,7 @@ def load_from_model_dir(uid, model_dir, hf_config):
             lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
             if lora_model is None:
                 raise ValueError(f"Failed to load adapter_model from {model_dir}")
-            lora_model = preprocess_lora_weights(lora_model)
+            lora_model = preprocess_lora_weights(lora_model, model_config)
             all_weights = get_all_hf_lora_weights(lora_model, hf_modules, component)
             rank = int(hf_config["r"])
             rs_lora = bool(hf_config.get("use_rslora", False))
diff --git a/tests/integration/defs/perf/pytorch_model_config.py b/tests/integration/defs/perf/pytorch_model_config.py
index 822a2a30c3a..49915d3b479 100644
--- a/tests/integration/defs/perf/pytorch_model_config.py
+++ b/tests/integration/defs/perf/pytorch_model_config.py
@@ -198,15 +198,17 @@ def get_model_yaml_config(model_label: str,
         }
         if 'phi_4_multimodal_instruct' in model_label:
             lora_config['lora_config']['lora_target_modules'] = [
-                "attn_qkv", "attn_dense", "mlp_h_to_4h", "mlp_4h_to_h"
+                "attn_qkv", "attn_dense", "mlp_gate_up", "mlp_4h_to_h"
             ]
             lora_config['lora_config']['trtllm_modules_to_hf_modules'] = {
                 "attn_qkv": "qkv_proj",
                 "attn_dense": "o_proj",
-                "mlp_h_to_4h": "gate_up_proj",
+                "mlp_gate_up": "gate_up_proj",
                 "mlp_4h_to_h": "down_proj"
             }
             lora_config['lora_config']['max_lora_rank'] = 320
+            lora_config['lora_config'][
+                'swap_gate_up_proj_lora_b_weight'] = False
         base_config.update(lora_config)
 
     kv_cache_config = base_config.get('kv_cache_config', {})
diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py
index 0d47cc7c6f7..95be7fb4c12 100644
--- a/tests/integration/defs/test_e2e.py
+++ b/tests/integration/defs/test_e2e.py
@@ -2486,15 +2486,15 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
     }
     expected_keywords = {
         "image": [
-            ["image", "depicts", "mountain", "half", "rock"],
-            ["road", "car", "lane", "traffic", "bus"],
+            ["object", "mountain", "weather", "clear", "clouds"],
+            ["traffic", "road", "vehicles", "cars", "bus"],
         ],
         "audio": [
             ["what", "is", "the", "traffic", "sign", "in", "image"],
             ["what", "is", "shown", "in", "this", "image"],
         ],
         "image_audio": [
-            ["image", "depicts", "Grand", "rock", "scene"],
+            ["image", "depicts", "scenic", "famous", "landmark"],
         ],
     }