Commit ef7eefe

[Qwen] Add fp8 checkpoint support for qwen3-next. (vllm-project#25079)
Signed-off-by: Tao He <[email protected]>
1 parent 350c94d commit ef7eefe
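
In short: the qwen3-next linear-attention block previously built one
MergedColumnParallelLinear for its input projection, but the small b/a
projection cannot share a blockwise-fp8 weight with the large qkvz
projection, so the commit splits them into two ColumnParallelLinear
layers; it also threads quant_config (and checkpoint-matching prefixes)
through the MTP head. A minimal usage sketch (the fp8 model id below is
an assumption, not part of this commit):

    # Hedged sketch: serve a blockwise-fp8 qwen3-next checkpoint.
    # The model id is illustrative; substitute any fp8 qwen3-next build.
    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen3-Next-80B-A3B-Instruct-FP8",
              tensor_parallel_size=4)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)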

File tree

2 files changed: +22 -21 lines changed

  vllm/model_executor/models/qwen3_next.py
  vllm/model_executor/models/qwen3_next_mtp.py

vllm/model_executor/models/qwen3_next.py

Lines changed: 17 additions & 18 deletions
@@ -30,7 +30,6 @@
     GemmaRMSNorm as Qwen3NextRMSNorm)
 # yapf: enable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -254,12 +253,20 @@ def __init__(
         # projection of the input hidden states
         self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
         self.projection_size_ba = self.num_v_heads * 2
-        self.in_proj = MergedColumnParallelLinear(
+        self.in_proj_qkvz = ColumnParallelLinear(
             input_size=self.hidden_size,
-            output_sizes=[self.projection_size_qkvz, self.projection_size_ba],
+            output_size=self.projection_size_qkvz,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.in_proj",
+            prefix=f"{prefix}.in_proj_qkvz",
+        )
+        # ba_proj doesn't support blockwise fp8 quantization.
+        self.in_proj_ba = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.projection_size_ba,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_ba",
         )
 
         query_key_settings = (self.key_dim, 0, False)
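
Why the split: MergedColumnParallelLinear fuses both sub-projections into
a single weight tensor, so a blockwise-fp8 scheme would have to quantize
the narrow b/a head (two scalars per value head) on the same block grid
as the wide qkvz weight, while the fp8 checkpoint stores the two tensors
separately. Two independent ColumnParallelLinear layers let the quant
method treat each weight, and its scales, on its own. A shape sketch with
plain nn.Linear stand-ins (sizes are illustrative, but the formulas match
the code above):

    # Illustrative only: nn.Linear stands in for vLLM's parallel linears.
    import torch
    import torch.nn as nn

    hidden_size = 2048
    key_dim, value_dim, num_v_heads = 1024, 2048, 16
    projection_size_qkvz = key_dim * 2 + value_dim * 2  # q, k, v, z
    projection_size_ba = num_v_heads * 2                # b and a per head

    # Before: one fused weight of shape (qkvz + ba, hidden).
    in_proj = nn.Linear(hidden_size,
                        projection_size_qkvz + projection_size_ba,
                        bias=False)
    # After: two weights, each quantizable independently.
    in_proj_qkvz = nn.Linear(hidden_size, projection_size_qkvz, bias=False)
    in_proj_ba = nn.Linear(hidden_size, projection_size_ba, bias=False)

    x = torch.randn(4, hidden_size)
    qkvz, ba = torch.split(in_proj(x),
                           [projection_size_qkvz, projection_size_ba],
                           dim=-1)
    assert qkvz.shape == in_proj_qkvz(x).shape
    assert ba.shape == in_proj_ba(x).shape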
@@ -420,19 +427,14 @@ def _forward(
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
-
-        # 1. Set up dimensions for reshapes later
-        projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
         if spec_token_masks is not None:
             spec_token_masks = spec_token_masks[:num_actual_tokens]
-        projected_states_qkvz, projected_states_ba = torch.split(
-            projected_states,
-            [
-                self.projection_size_qkvz // self.tp_size,
-                self.projection_size_ba // self.tp_size
-            ],
-            dim=-1,
-        )
+
+        # 1. Set up dimensions for reshapes later
+        projected_states_qkvz, _ = self.in_proj_qkvz(
+            hidden_states[:num_actual_tokens])
+        projected_states_ba, _ = self.in_proj_ba(
+            hidden_states[:num_actual_tokens])
         query, key, value, z, b, a = self.fix_query_key_value_ordering(
             projected_states_qkvz, projected_states_ba)
         query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
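
A side effect visible in this hunk: the forward pass no longer needs to
know the tensor-parallel degree. The merged layer returned one fused
shard whose sub-widths had to be recomputed as size // tp_size before
torch.split; each separate layer now returns an already-correctly-sized
shard. The arithmetic the old code did by hand, as a sketch with made-up
sizes:

    # Illustrative numbers; vLLM's TP machinery does this internally.
    projection_size_qkvz = 6144
    projection_size_ba = 32
    tp_size = 4

    # Old path: split the fused per-rank output by hand-derived widths.
    merged_width = (projection_size_qkvz + projection_size_ba) // tp_size
    split_sizes = [projection_size_qkvz // tp_size,
                   projection_size_ba // tp_size]
    assert sum(split_sizes) == merged_width  # 1536 + 8 == 1544

    # New path: in_proj_qkvz and in_proj_ba each emit their own shard,
    # so the caller never references tp_size.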
@@ -976,8 +978,6 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            ("in_proj", "in_proj_qkvz", 0),
-            ("in_proj", "in_proj_ba", 1),
         ]
 
         params_dict = dict(self.named_parameters())
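
The two dropped mapping entries existed only to route checkpoint tensors
named in_proj_qkvz / in_proj_ba into shards of the fused in_proj
parameter; with separately named modules, the generic direct-name path
applies. A simplified sketch of what such a mapping does during loading
(the resolve helper is hypothetical, standing in for vLLM's load_weights
loop):

    # Hypothetical helper illustrating stacked-parameter remapping.
    stacked_params_mapping = [
        # (param name in the model, weight name in the checkpoint, shard)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    def resolve(ckpt_name: str):
        """Map a checkpoint tensor name to (model param name, shard id)."""
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name in ckpt_name:
                return ckpt_name.replace(weight_name, param_name), shard_id
        return ckpt_name, None  # direct match: the new in_proj_* path

    print(resolve("model.layers.0.self_attn.q_proj.weight"))
    # ('model.layers.0.self_attn.qkv_proj.weight', 'q')
    print(resolve("model.layers.0.linear_attn.in_proj_ba.weight"))
    # ('model.layers.0.linear_attn.in_proj_ba.weight', None)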
@@ -1055,7 +1055,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "v_proj",
         ],
         "gate_up_proj": ["gate_proj", "up_proj"],
-        "in_proj": ["in_proj_qkvz", "in_proj_ba"],
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
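
packed_modules_mapping is the module-level counterpart: it tells LoRA and
quantization code which checkpoint modules are fused inside one vLLM
module, so in_proj must leave the map once it no longer exists. What
remains after this commit:

    # The surviving entries, copied from the context above: only modules
    # that are still physically fused keep a mapping.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }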

vllm/model_executor/models/qwen3_next_mtp.py

Lines changed: 5 additions & 3 deletions
@@ -63,7 +63,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                        self.config.hidden_size,
                                        gather_output=True,
                                        bias=False,
-                                       return_bias=False)
+                                       return_bias=False,
+                                       quant_config=quant_config,
+                                       prefix=f'{prefix}.fc')
 
         self.layers = torch.nn.ModuleList(
             Qwen3NextDecoderLayer(
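
Until now the MTP's fc projection was built without quant_config, so it
always materialized as an unquantized linear layer even for fp8
checkpoints; passing quant_config plus a prefix (which vLLM uses to match
the layer against the checkpoint's per-layer quantization metadata) lets
it load fp8 weights like every other linear layer. The pattern, as a
hedged sketch (simplified arguments, mirroring the diff):

    # Sketch of the pattern this hunk applies; kwargs mirror the diff.
    from vllm.model_executor.layers.linear import ColumnParallelLinear

    def make_fc(hidden_size: int, quant_config, prefix: str):
        return ColumnParallelLinear(hidden_size * 2,  # assumed input width
                                    hidden_size,
                                    gather_output=True,
                                    bias=False,
                                    return_bias=False,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.fc")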
@@ -72,7 +74,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 model_config=model_config,
                 cache_config=cache_config,
                 quant_config=quant_config,
-                prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
+                prefix=f'{prefix}.layers.{idx}',
             ) for idx in range(self.num_mtp_layers))
 
         self.make_empty_intermediate_tensors = (
@@ -233,7 +235,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config,
                                                   prefix=maybe_prefix(
-                                                      prefix, "model"))
+                                                      prefix, "mtp"))
         self.unpadded_vocab_size = config.vocab_size
         self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
                                       config.hidden_size,
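
The prefix edits in this file all serve weight loading: a module's prefix
is the name vLLM looks up in the checkpoint (and in its per-layer quant
metadata), so the MTP tower must build mtp.layers.{idx} rather than the
old model.layers.{mtp_start_layer_idx + idx}, which no longer matched the
quantized checkpoint's tensor names. An illustrative sketch of the names
now constructed:

    # Illustrative: the prefixes this file now builds must match the
    # tensor names stored in the quantized MTP checkpoint.
    from vllm.model_executor.models.utils import maybe_prefix

    model_prefix = maybe_prefix("", "mtp")      # -> "mtp"
    layer_prefix = f"{model_prefix}.layers.0"   # -> "mtp.layers.0"
    fc_prefix = f"{model_prefix}.fc"            # -> "mtp.fc"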
