51 | 51 | from vllm.logger import init_logger
52 | 52 | from vllm.model_executor import SamplingMetadata
53 | 53 | from vllm.model_executor.layers.layernorm import RMSNorm
54 |    | -# yapf: disable
55 | 54 | from vllm.model_executor.layers.linear import (ColumnParallelLinear,
56 | 55 |                                                 MergedColumnParallelLinear,
57 |    | -                                               MergedReplicatedLinear,
58 | 56 |                                                 QKVParallelLinear,
59 |    | -                                               ReplicatedLinear,
60 | 57 |                                                 RowParallelLinear)
61 |    | -# yapf: enable
62 | 58 | from vllm.model_executor.layers.quantization import QuantizationConfig
63 | 59 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader
64 | 60 | from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -174,20 +170,22 @@ def __init__(
174 | 170 |         use_data_parallel: bool = False,
175 | 171 |     ):
176 | 172 |         super().__init__()
177 |     | -        cls_gate_up = (MergedReplicatedLinear
178 |     | -                       if use_data_parallel else MergedColumnParallelLinear)
179 |     | -        self.gate_up_proj = cls_gate_up(input_size=in_features,
180 |     | -                                        output_sizes=[hidden_features] * 2,
181 |     | -                                        bias=bias,
182 |     | -                                        quant_config=quant_config,
183 |     | -                                        prefix=f"{prefix}.gate_up_proj")
184 |     | -        cls_down = (ReplicatedLinear
185 |     | -                    if use_data_parallel else RowParallelLinear)
186 |     | -        self.down_proj = cls_down(hidden_features,
187 |     | -                                  in_features,
188 |     | -                                  bias=bias,
189 |     | -                                  quant_config=quant_config,
190 |     | -                                  prefix=f"{prefix}.down_proj")
    | 173 | +        self.gate_up_proj = MergedColumnParallelLinear(
    | 174 | +            input_size=in_features,
    | 175 | +            output_sizes=[hidden_features] * 2,
    | 176 | +            bias=bias,
    | 177 | +            quant_config=quant_config,
    | 178 | +            prefix=f"{prefix}.gate_up_proj",
    | 179 | +            disable_tp=use_data_parallel,
    | 180 | +        )
    | 181 | +        self.down_proj = RowParallelLinear(
    | 182 | +            hidden_features,
    | 183 | +            in_features,
    | 184 | +            bias=bias,
    | 185 | +            quant_config=quant_config,
    | 186 | +            prefix=f"{prefix}.down_proj",
    | 187 | +            disable_tp=use_data_parallel,
    | 188 | +        )
191 | 189 |         self.act_fn = SiluAndMul()
192 | 190 |
193 | 191 |     def forward(self, x: torch.Tensor):
@@ -234,48 +232,32 @@ def __init__(
234 | 232 |         # Per attention head and per partition values.
235 | 233 |         self.tp_size = (1 if use_data_parallel else
236 | 234 |                         get_tensor_model_parallel_world_size())
237 |     | -        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
    | 235 | +        self.tp_rank = (0 if use_data_parallel else
    | 236 | +                        parallel_state.get_tensor_model_parallel_rank())
238 | 237 |         self.hidden_size_per_attention_head = dist_utils.divide(
239 | 238 |             projection_size, num_heads)
240 | 239 |         self.num_attention_heads_per_partition = dist_utils.divide(
241 | 240 |             num_heads, self.tp_size)
242 | 241 |
243 |     | -        if use_data_parallel:
244 |     | -            self.qkv = ReplicatedLinear(
245 |     | -                input_size=embed_dim,
246 |     | -                output_size=3 * projection_size,
247 |     | -                bias=False,
248 |     | -                quant_config=quant_config,
249 |     | -                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
250 |     | -                prefix=f"{prefix}.qkv_proj"
251 |     | -                if quant_config else f"{prefix}.qkv",
252 |     | -            )
253 |     | -            self.proj = ReplicatedLinear(
254 |     | -                input_size=projection_size,
255 |     | -                output_size=embed_dim,
256 |     | -                quant_config=quant_config,
257 |     | -                prefix=f"{prefix}.proj",
258 |     | -                bias=False,
259 |     | -            )
260 |     | -        else:
261 |     | -            self.qkv = QKVParallelLinear(
262 |     | -                hidden_size=embed_dim,
263 |     | -                head_size=self.hidden_size_per_attention_head,
264 |     | -                total_num_heads=num_heads,
265 |     | -                total_num_kv_heads=num_heads,
266 |     | -                bias=False,
267 |     | -                quant_config=quant_config,
268 |     | -                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
269 |     | -                prefix=f"{prefix}.qkv_proj"
270 |     | -                if quant_config else f"{prefix}.qkv",
271 |     | -            )
272 |     | -            self.proj = RowParallelLinear(
273 |     | -                input_size=projection_size,
274 |     | -                output_size=embed_dim,
275 |     | -                quant_config=quant_config,
276 |     | -                prefix=f"{prefix}.proj",
277 |     | -                bias=False,
278 |     | -            )
    | 242 | +        self.qkv = QKVParallelLinear(
    | 243 | +            hidden_size=embed_dim,
    | 244 | +            head_size=self.hidden_size_per_attention_head,
    | 245 | +            total_num_heads=num_heads,
    | 246 | +            total_num_kv_heads=num_heads,
    | 247 | +            bias=False,
    | 248 | +            quant_config=quant_config,
    | 249 | +            # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
    | 250 | +            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
    | 251 | +            disable_tp=use_data_parallel,
    | 252 | +        )
    | 253 | +        self.proj = RowParallelLinear(
    | 254 | +            input_size=projection_size,
    | 255 | +            output_size=embed_dim,
    | 256 | +            quant_config=quant_config,
    | 257 | +            prefix=f"{prefix}.proj",
    | 258 | +            bias=False,
    | 259 | +            disable_tp=use_data_parallel,
    | 260 | +        )
279 | 261 |
280 | 262 |         # Detect attention implementation.
281 | 263 |         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
@@ -494,41 +476,31 @@ def __init__(
494 | 476 |     ) -> None:
495 | 477 |         super().__init__()
496 | 478 |         self.hidden_size = d_model
497 |     | -        if use_data_parallel:
498 |     | -            self.proj = ReplicatedLinear(
499 |     | -                input_size=self.hidden_size,
500 |     | -                output_size=self.hidden_size,
501 |     | -                bias=bias,
502 |     | -                quant_config=quant_config,
503 |     | -                prefix=f"{prefix}.proj",
504 |     | -            )
505 |     | -        else:
506 |     | -            self.proj = ColumnParallelLinear(
507 |     | -                self.hidden_size,
508 |     | -                self.hidden_size,
509 |     | -                bias=bias,
510 |     | -                gather_output=True,
511 |     | -                quant_config=quant_config,
512 |     | -                prefix=f"{prefix}.proj",
513 |     | -            )
    | 479 | +        self.proj = ColumnParallelLinear(
    | 480 | +            self.hidden_size,
    | 481 | +            self.hidden_size,
    | 482 | +            bias=bias,
    | 483 | +            gather_output=True,
    | 484 | +            quant_config=quant_config,
    | 485 | +            prefix=f"{prefix}.proj",
    | 486 | +            disable_tp=use_data_parallel,
    | 487 | +        )
514 | 488 |         self.post_projection_norm = nn.LayerNorm(self.hidden_size)
515 |     | -        cls_gate_up = (MergedReplicatedLinear
516 |     | -                       if use_data_parallel else MergedColumnParallelLinear)
517 |     | -        self.gate_up_proj = cls_gate_up(
    | 489 | +        self.gate_up_proj = MergedColumnParallelLinear(
518 | 490 |             input_size=self.hidden_size,
519 | 491 |             output_sizes=[context_dim] * 2,
520 | 492 |             bias=bias,
521 | 493 |             quant_config=quant_config,
522 | 494 |             prefix=f"{prefix}.gate_up_proj",
    | 495 | +            disable_tp=use_data_parallel,
523 | 496 |         )
524 |     | -        cls_down = (ReplicatedLinear
525 |     | -                    if use_data_parallel else RowParallelLinear)
526 |     | -        self.down_proj = cls_down(
    | 497 | +        self.down_proj = RowParallelLinear(
527 | 498 |             context_dim,
528 | 499 |             self.hidden_size,
529 | 500 |             bias=bias,
530 | 501 |             quant_config=quant_config,
531 | 502 |             prefix=f"{prefix}.down_proj",
    | 503 | +            disable_tp=use_data_parallel,
532 | 504 |         )
533 | 505 |         self.act_fn = SiluAndMul()
534 | 506 |         self.extra_activation_func = nn.GELU()
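Each hunk applies the same simplification: instead of picking between a `Replicated*` class (data-parallel ViT) and a `*Parallel` class (tensor-parallel) at construction time, the layer is always built from the tensor-parallel class and the `disable_tp=use_data_parallel` argument switches tensor parallelism off when the vision tower runs data-parallel. A minimal sketch of the resulting MLP pattern is below; it assumes the `MergedColumnParallelLinear`/`RowParallelLinear` signatures shown in the diff, an already-initialized vLLM model-parallel group, and a conventional gate-up → SiluAndMul → down forward (the forward body is not part of this diff), with `VisionMLPSketch` being a hypothetical name.

```python
import torch
import torch.nn as nn

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)


class VisionMLPSketch(nn.Module):
    """Hypothetical gate-up/down MLP mirroring the pattern adopted in this diff."""

    def __init__(self,
                 in_features: int,
                 hidden_features: int,
                 *,
                 use_data_parallel: bool = False,
                 prefix: str = ""):
        super().__init__()
        # With disable_tp=True the layer keeps its full (unsharded) weight,
        # so the data-parallel path no longer needs a separate Replicated class.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=in_features,
            output_sizes=[hidden_features] * 2,  # gate and up projections fused
            bias=False,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            hidden_features,
            in_features,
            bias=False,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # vLLM linear layers return (output, output_bias); the bias is unused here.
        gate_up, _ = self.gate_up_proj(x)
        return self.down_proj(self.act_fn(gate_up))[0]
```

The design benefit is that weight loading, quantization prefixes, and sharding behavior are all handled by one class per layer, with the parallelism mode reduced to a single constructor flag.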