@@ -30,7 +30,6 @@
     GemmaRMSNorm as Qwen3NextRMSNorm)
 # yapf: enable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)

@@ -254,12 +253,20 @@ def __init__(
         # projection of the input hidden states
         self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
         self.projection_size_ba = self.num_v_heads * 2
-        self.in_proj = MergedColumnParallelLinear(
+        self.in_proj_qkvz = ColumnParallelLinear(
             input_size=self.hidden_size,
-            output_sizes=[self.projection_size_qkvz, self.projection_size_ba],
+            output_size=self.projection_size_qkvz,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.in_proj",
+            prefix=f"{prefix}.in_proj_qkvz",
+        )
+        # ba_proj doesn't support blockwise fp8 quantization.
+        self.in_proj_ba = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.projection_size_ba,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_ba",
         )

         query_key_settings = (self.key_dim, 0, False)

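This hunk replaces the fused in_proj (a MergedColumnParallelLinear producing both the qkvz and b/a projections from one weight) with two independent ColumnParallelLinear layers, so the small b/a projection no longer has to share the qkvz weight's quantization scheme (the in-code comment notes it doesn't support blockwise fp8). Below is a minimal sketch of the same before/after pattern using plain torch.nn.Linear instead of vLLM's tensor-parallel layers; the sizes are made up for illustration.

# Minimal sketch (plain torch, no tensor parallelism) of the pattern in this
# hunk: one fused input projection becomes two independent Linear modules so
# each can carry its own quantization treatment.
import torch
import torch.nn as nn

hidden_size, size_qkvz, size_ba = 1024, 4096, 64  # illustrative sizes

# Before: a single fused projection; both outputs share one weight matrix,
# so they must share the same quantization scheme.
fused = nn.Linear(hidden_size, size_qkvz + size_ba, bias=False)

# After: two separate projections; the small b/a projection can use a
# different (or no) quantization without touching the large qkvz weight.
in_proj_qkvz = nn.Linear(hidden_size, size_qkvz, bias=False)
in_proj_ba = nn.Linear(hidden_size, size_ba, bias=False)

x = torch.randn(8, hidden_size)
qkvz, ba = fused(x).split([size_qkvz, size_ba], dim=-1)   # old path
qkvz2, ba2 = in_proj_qkvz(x), in_proj_ba(x)               # new path
assert qkvz.shape == qkvz2.shape and ba.shape == ba2.shape
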
@@ -420,19 +427,14 @@ def _forward(
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
-
-        # 1. Set up dimensions for reshapes later
-        projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
         if spec_token_masks is not None:
             spec_token_masks = spec_token_masks[:num_actual_tokens]
-        projected_states_qkvz, projected_states_ba = torch.split(
-            projected_states,
-            [
-                self.projection_size_qkvz // self.tp_size,
-                self.projection_size_ba // self.tp_size
-            ],
-            dim=-1,
-        )
+
+        # 1. Set up dimensions for reshapes later
+        projected_states_qkvz, _ = self.in_proj_qkvz(
+            hidden_states[:num_actual_tokens])
+        projected_states_ba, _ = self.in_proj_ba(
+            hidden_states[:num_actual_tokens])
         query, key, value, z, b, a = self.fix_query_key_value_ordering(
             projected_states_qkvz, projected_states_ba)
         query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),

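With two separate projections, _forward no longer needs the explicit torch.split or the // self.tp_size arithmetic: a column-parallel layer returns its rank-local shard of the output dimension (which is why the old code divided the split sizes by self.tp_size), so the fused output had to be re-split with per-rank sizes, whereas each new layer already returns exactly its own shard. A shape-only sketch of that bookkeeping, with made-up sizes and random tensors standing in for the projection outputs:

# Why the torch.split and the // tp_size bookkeeping disappear (shapes only).
import torch

tp_size, size_qkvz, size_ba = 2, 4096, 64  # illustrative sizes
tokens = 8

# Old path: one fused column-parallel projection yields this rank's shard of
# the concatenated output, so the caller re-splits it with per-rank sizes.
fused_out = torch.randn(tokens, (size_qkvz + size_ba) // tp_size)
qkvz, ba = torch.split(
    fused_out, [size_qkvz // tp_size, size_ba // tp_size], dim=-1)

# New path: each projection returns exactly its own per-rank output, so there
# is nothing left to split.
qkvz_new = torch.randn(tokens, size_qkvz // tp_size)
ba_new = torch.randn(tokens, size_ba // tp_size)
assert qkvz.shape == qkvz_new.shape and ba.shape == ba_new.shape
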
@@ -976,8 +978,6 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            ("in_proj", "in_proj_qkvz", 0),
-            ("in_proj", "in_proj_ba", 1),
         ]

         params_dict = dict(self.named_parameters())

@@ -1055,7 +1055,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "v_proj",
         ],
         "gate_up_proj": ["gate_proj", "up_proj"],
-        "in_proj": ["in_proj_qkvz", "in_proj_ba"],
     }

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

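The last two hunks drop the in_proj entries from stacked_params_mapping and packed_modules_mapping: the checkpoint already stores in_proj_qkvz and in_proj_ba as separate tensors, and the model now exposes modules with exactly those names, so the weights load by direct lookup in params_dict with no fused-parameter routing or packed-module grouping. A self-contained sketch of that difference (this is not vLLM's real loader; parameter names and shapes are hypothetical):

# Sketch of the loading difference: with the ("in_proj", "in_proj_qkvz"/"in_proj_ba")
# entries removed from the mapping, those tensors load by plain name match.
import torch

def load_weights(params_dict, checkpoint, stacked_params_mapping):
    for ckpt_name, tensor in checkpoint.items():
        for fused, shard, _shard_id in stacked_params_mapping:
            if shard in ckpt_name:
                # Old route: rename onto the fused parameter and copy into the
                # matching shard (shard offset handling elided here).
                ckpt_name = ckpt_name.replace(shard, fused)
                break
        params_dict[ckpt_name].copy_(tensor)

# Hypothetical names and shapes, for illustration only.
params_dict = {"linear_attn.in_proj_qkvz.weight": torch.zeros(8, 4),
               "linear_attn.in_proj_ba.weight": torch.zeros(2, 4)}
checkpoint = {"linear_attn.in_proj_qkvz.weight": torch.ones(8, 4),
              "linear_attn.in_proj_ba.weight": torch.ones(2, 4)}
# No in_proj entries in the mapping: both tensors load by direct name lookup.
load_weights(params_dict, checkpoint, stacked_params_mapping=[])
assert params_dict["linear_attn.in_proj_ba.weight"].sum() == 8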