
Commit 53b19cc

[Core] Allow disabling TP sharding for parallel Linear layer (vllm-project#23024)
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 6432739 commit 53b19cc

File tree: 7 files changed (+205 −282 lines)


vllm/model_executor/layers/linear.py

Lines changed: 70 additions & 105 deletions
Large diffs are not rendered by default.
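The diff for this file is not rendered, but the call sites in the files below indicate its core change: the parallel Linear layers gain a disable_tp flag that keeps full-size (unsharded) weights and skips TP communication, so callers no longer need to swap in the Replicated* classes. A minimal sketch of the intended usage, assuming a vLLM build that includes this commit and an initialized TP group; the sizes, prefixes, and the use_data_parallel flag are illustrative placeholders only:

# Sketch only: parallel layers with TP sharding turned off via disable_tp.
# Assumes this commit's vLLM; sizes and prefixes are made up, and construction
# must happen inside an initialized vLLM worker (the layers query the TP group).
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)

use_data_parallel = True  # e.g. a vision tower replicated per DP rank

gate_up_proj = MergedColumnParallelLinear(
    input_size=1024,
    output_sizes=[4096] * 2,       # [gate_proj, up_proj]
    bias=False,
    prefix="mlp.gate_up_proj",
    disable_tp=use_data_parallel,  # full-size weights, no column sharding
)
down_proj = RowParallelLinear(
    4096,
    1024,
    bias=False,
    prefix="mlp.down_proj",
    disable_tp=use_data_parallel,  # presumably also skips the TP all-reduce
)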

vllm/model_executor/model_loader/bitsandbytes_loader.py

Lines changed: 19 additions & 4 deletions
@@ -69,6 +69,7 @@ def __init__(self, load_config: LoadConfig):
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: list[str] = []
+        self.tp_disabled_modules: list[str] = []
         # Store the mapping of expert parameters for MoE models.
         self.expert_params_mapping: list[tuple[str, str, int, str]] = []
         # mapping weight names from transformers to vllm.
@@ -322,14 +323,24 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                                quant_state_dict) -> Generator:
         from bitsandbytes.functional import quantize_4bit
 
-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
+        global_tp_size = get_tensor_model_parallel_world_size()
+        global_tp_rank = get_tensor_model_parallel_rank()
 
         for (
                 org_weight_name,
                 mapped_weight_name,
                 weight_tensor,
         ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+
+            # override tp_size and tp_rank if the module has disabled TP
+            if any(tp_disabled_module in mapped_weight_name
+                   for tp_disabled_module in self.tp_disabled_modules):
+                tp_size = 1
+                tp_rank = 0
+            else:
+                tp_size = global_tp_size
+                tp_rank = global_tp_rank
+
             if any(target_module in mapped_weight_name
                    for target_module in self.target_modules
                    ) and mapped_weight_name.endswith(".weight"):
@@ -418,12 +429,16 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
                 # Map vllm's names to transformers's names.
                 rep_name, sub_modules = modules_info
                 for sub_name in sub_modules:
-                    self.target_modules.append(
-                        name.replace(rep_name, sub_name))
+                    new_name = name.replace(rep_name, sub_name)
+                    self.target_modules.append(new_name)
+                    if module.disable_tp:
+                        self.tp_disabled_modules.append(new_name)
                 # Add original module name even if the module has stacked map,
                 # in case model has a mixture of disk-merged and disk-split
                 # weights with same last name.
                 self.target_modules.append(name)
+                if module.disable_tp:
+                    self.tp_disabled_modules.append(name)
             elif isinstance(module, FusedMoE) and hasattr(
                     module.quant_method, "quant_config"):
                 # TODO: support FusedMoE with prequant and 8bit.
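The loader change above decides, per weight, whether to slice it for TP by substring-matching the mapped weight name against tp_disabled_modules. A self-contained sketch of that matching rule, independent of vLLM and using hypothetical module names, just to illustrate the behaviour:

# Standalone sketch of the per-weight override logic added above.
# Module names are hypothetical; only the matching pattern mirrors the diff.
def resolve_tp(mapped_weight_name: str,
               tp_disabled_modules: list[str],
               global_tp_size: int,
               global_tp_rank: int) -> tuple[int, int]:
    """Return the (tp_size, tp_rank) to use when slicing this weight."""
    if any(m in mapped_weight_name for m in tp_disabled_modules):
        # Module opted out of TP sharding: load the full tensor on every rank.
        return 1, 0
    return global_tp_size, global_tp_rank


tp_disabled = ["model.layers.0.self_attn.fused_qkv_a_proj"]
print(resolve_tp("model.layers.0.self_attn.fused_qkv_a_proj.weight",
                 tp_disabled, 4, 2))  # -> (1, 0)
print(resolve_tp("model.layers.0.mlp.down_proj.weight",
                 tp_disabled, 4, 2))  # -> (4, 2)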

vllm/model_executor/models/deepseek_v2.py

Lines changed: 3 additions & 3 deletions
@@ -43,7 +43,6 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -435,12 +434,13 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
 
         if self.q_lora_rank is not None:
-            self.fused_qkv_a_proj = MergedReplicatedLinear(
+            self.fused_qkv_a_proj = MergedColumnParallelLinear(
                 self.hidden_size,
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                 bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.fused_qkv_a_proj")
+                prefix=f"{prefix}.fused_qkv_a_proj",
+                disable_tp=True)
         else:
             self.kv_a_proj_with_mqa = ReplicatedLinear(
                 self.hidden_size,

vllm/model_executor/models/glm4_1v.py

Lines changed: 50 additions & 78 deletions
@@ -51,14 +51,10 @@
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.layernorm import RMSNorm
-# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
-# yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -174,20 +170,22 @@ def __init__(
         use_data_parallel: bool = False,
     ):
         super().__init__()
-        cls_gate_up = (MergedReplicatedLinear
-                       if use_data_parallel else MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up(input_size=in_features,
-                                        output_sizes=[hidden_features] * 2,
-                                        bias=bias,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.gate_up_proj")
-        cls_down = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down(hidden_features,
-                                  in_features,
-                                  bias=bias,
-                                  quant_config=quant_config,
-                                  prefix=f"{prefix}.down_proj")
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel,
+        )
+        self.down_proj = RowParallelLinear(
+            hidden_features,
+            in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+            disable_tp=use_data_parallel,
+        )
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor):
@@ -234,48 +232,32 @@ def __init__(
         # Per attention head and per partition values.
         self.tp_size = (1 if use_data_parallel else
                         get_tensor_model_parallel_world_size())
-        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        self.tp_rank = (0 if use_data_parallel else
+                        parallel_state.get_tensor_model_parallel_rank())
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
-        if use_data_parallel:
-            self.qkv = ReplicatedLinear(
-                input_size=embed_dim,
-                output_size=3 * projection_size,
-                bias=False,
-                quant_config=quant_config,
-                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
-                prefix=f"{prefix}.qkv_proj"
-                if quant_config else f"{prefix}.qkv",
-            )
-            self.proj = ReplicatedLinear(
-                input_size=projection_size,
-                output_size=embed_dim,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-                bias=False,
-            )
-        else:
-            self.qkv = QKVParallelLinear(
-                hidden_size=embed_dim,
-                head_size=self.hidden_size_per_attention_head,
-                total_num_heads=num_heads,
-                total_num_kv_heads=num_heads,
-                bias=False,
-                quant_config=quant_config,
-                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
-                prefix=f"{prefix}.qkv_proj"
-                if quant_config else f"{prefix}.qkv",
-            )
-            self.proj = RowParallelLinear(
-                input_size=projection_size,
-                output_size=embed_dim,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-                bias=False,
-            )
+        self.qkv = QKVParallelLinear(
+            hidden_size=embed_dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            total_num_kv_heads=num_heads,
+            bias=False,
+            quant_config=quant_config,
+            # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
+            disable_tp=use_data_parallel,
+        )
+        self.proj = RowParallelLinear(
+            input_size=projection_size,
+            output_size=embed_dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.proj",
+            bias=False,
+            disable_tp=use_data_parallel,
+        )
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
@@ -494,41 +476,31 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = d_model
-        if use_data_parallel:
-            self.proj = ReplicatedLinear(
-                input_size=self.hidden_size,
-                output_size=self.hidden_size,
-                bias=bias,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-            )
-        else:
-            self.proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.hidden_size,
-                bias=bias,
-                gather_output=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-            )
+        self.proj = ColumnParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=bias,
+            gather_output=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.proj",
+            disable_tp=use_data_parallel,
+        )
         self.post_projection_norm = nn.LayerNorm(self.hidden_size)
-        cls_gate_up = (MergedReplicatedLinear
-                       if use_data_parallel else MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up(
+        self.gate_up_proj = MergedColumnParallelLinear(
             input_size=self.hidden_size,
             output_sizes=[context_dim] * 2,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel,
         )
-        cls_down = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down(
+        self.down_proj = RowParallelLinear(
            context_dim,
            self.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
+           disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()
        self.extra_activation_func = nn.GELU()

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 25 additions & 37 deletions
@@ -48,7 +48,6 @@
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -178,22 +177,20 @@ def __init__(self,
                  prefix: str = "",
                  use_data_parallel: bool = False):
         super().__init__()
-        cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else
-                            MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up_proj(
+        self.gate_up_proj = MergedColumnParallelLinear(
             input_size=in_features,
             output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-
-        cls_down_proj = (ReplicatedLinear
-                         if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down_proj(hidden_features,
-                                       in_features,
-                                       bias=bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.down_proj")
+            prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel)
+
+        self.down_proj = RowParallelLinear(hidden_features,
+                                           in_features,
+                                           bias=bias,
+                                           quant_config=quant_config,
+                                           prefix=f"{prefix}.down_proj",
+                                           disable_tp=use_data_parallel)
         self.act_fn = act_fn
 
     def forward(self, x: torch.Tensor):
@@ -243,30 +240,21 @@ def __init__(
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
-        if use_data_parallel:
-            self.qkv = ReplicatedLinear(embed_dim,
-                                        self.hidden_size_per_attention_head *
-                                        3 * num_heads,
-                                        bias=True,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.qkv")
-
-        else:
-            self.qkv = QKVParallelLinear(
-                hidden_size=embed_dim,
-                head_size=self.hidden_size_per_attention_head,
-                total_num_heads=num_heads,
-                total_num_kv_heads=num_heads,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv")
-
-        cls_proj = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.proj = cls_proj(input_size=projection_size,
-                             output_size=embed_dim,
-                             quant_config=quant_config,
-                             prefix=f"{prefix}.proj")
+        self.qkv = QKVParallelLinear(
+            hidden_size=embed_dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            total_num_kv_heads=num_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+            disable_tp=use_data_parallel)
+
+        self.proj = RowParallelLinear(input_size=projection_size,
+                                      output_size=embed_dim,
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.proj",
+                                      disable_tp=use_data_parallel)
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
