Skip to content

Commit 4b2d937

Browse files
dongjiyingdjy authored and dominicshanshan committed
feat: support duplicate_kv_weight for qwen3 blockwise scale (NVIDIA#5459)
Signed-off-by: Jiying Dong <[email protected]>
1 parent 20a2213 commit 4b2d937

File tree

6 files changed

+36
-30
lines changed

6 files changed

+36
-30
lines changed

examples/pytorch/out_of_tree_example/modeling_opt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def __init__(
239239

240240
def load_weights(self, weights: dict):
241241
tp_size = self.model_config.mapping.tp_size
242-
head_dim = self.config.hidden_size // self.config.num_attention_heads
242+
num_kv_heads = self.model_config.pretrained_config.num_attention_heads
243243

244244
def filter_weights(prefix: str, weights: dict):
245245
result = {}
@@ -280,7 +280,7 @@ def filter_weights(prefix: str, weights: dict):
280280
k:
281281
duplicate_kv_weight(
282282
weight=v[:],
283-
head_dim=head_dim,
283+
num_kv_heads=num_kv_heads,
284284
tensor_parallel_size=tp_size)
285285
if k in ['weight', 'bias'] else v
286286
for k, v in fw.items()

tensorrt_llm/_torch/models/modeling_gemma3.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,7 @@ def forward(
329329
# minor change for Gemma3 RMSNorm.
330330
def load_weights(self, weights: Dict):
331331
tp_size = self.model_config.mapping.tp_size
332-
head_dim = getattr(
333-
self.config, "head_dim",
334-
self.config.hidden_size // self.config.num_attention_heads)
332+
num_kv_heads = self.config.num_key_value_heads
335333

336334
params_map = {
337335
'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -365,7 +363,7 @@ def load_weights(self, weights: Dict):
365363
k:
366364
duplicate_kv_weight(
367365
weight=v[:],
368-
head_dim=head_dim,
366+
num_kv_heads=num_kv_heads,
369367
tensor_parallel_size=tp_size)
370368
if k in ["weight", "bias"] else v
371369
for k, v in fw.items()

tensorrt_llm/_torch/models/modeling_mllama.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,8 @@ def load_weights(self, weights: Dict):
333333
tp_size = self.config.mapping.tp_size
334334
vision_config = self.config.pretrained_config.vision_config
335335
text_config = self.config.pretrained_config.text_config
336-
text_head_dim = text_config.hidden_size // text_config.num_attention_heads
337-
vision_head_dim = vision_config.hidden_size // vision_config.attention_heads
336+
text_config.hidden_size // text_config.num_attention_heads
337+
vision_config.hidden_size // vision_config.attention_heads
338338

339339
params_map = {
340340
'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -347,7 +347,7 @@ def load_weights(self, weights: Dict):
347347
# skip load weights if tie word embeddings is enabled and layer is lm_head
348348
if text_config.tie_word_embeddings and "lm_head" in name:
349349
continue
350-
head_dim = vision_head_dim if "vision_model" in name else text_head_dim
350+
num_kv_heads = vision_config.num_key_value_heads if "vision_model" in name else text_config.num_key_value_heads
351351

352352
names = name.split('.')
353353
if names[-1] in params_map:
@@ -360,7 +360,7 @@ def load_weights(self, weights: Dict):
360360
k:
361361
duplicate_kv_weight(
362362
weight=v[:],
363-
head_dim=head_dim,
363+
num_kv_heads=num_kv_heads,
364364
tensor_parallel_size=tp_size)
365365
if k in ["weight", "bias"] else v
366366
for k, v in fw.items()

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -394,9 +394,7 @@ def load_weights(self, weights: Dict):
394394
tp_size = self.model_config.mapping.tp_size
395395
enable_attention_dp = self.model_config.mapping.enable_attention_dp
396396

397-
head_dim = getattr(
398-
self.config, "head_dim",
399-
self.config.hidden_size // self.config.num_attention_heads)
397+
num_kv_heads = self.config.num_key_value_heads
400398

401399
params_map = {
402400
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -419,11 +417,14 @@ def load_weights(self, weights: Dict):
419417
tensors_need_duplication = ["weight", "bias"]
420418
if module.quant_config.quant_mode.has_nvfp4():
421419
tensors_need_duplication.append("weight_scale")
420+
if module.quant_config.quant_mode.has_fp8_block_scales(
421+
):
422+
tensors_need_duplication.append("weight_scale_inv")
422423
if new_name in ["k_proj", "v_proj"]:
423424
fw = {
424425
k: (duplicate_kv_weight(
425426
weight=v[:],
426-
head_dim=head_dim,
427+
num_kv_heads=num_kv_heads,
427428
tensor_parallel_size=tp_size
428429
if not enable_attention_dp else 1)
429430
if k in tensors_need_duplication else v)

tensorrt_llm/_torch/models/modeling_qwen_moe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def __init__(
256256

257257
def load_weights(self, weights: Dict):
258258
tp_size = self.model_config.mapping.tp_size
259-
head_dim = self.config.hidden_size // self.config.num_attention_heads
259+
num_kv_heads = self.config.num_key_value_heads
260260

261261
params_map = {
262262
'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -281,7 +281,7 @@ def load_weights(self, weights: Dict):
281281
k:
282282
duplicate_kv_weight(
283283
weight=v[:],
284-
head_dim=head_dim,
284+
num_kv_heads=num_kv_heads,
285285
tensor_parallel_size=tp_size)
286286
if k in ["weight", "bias"] else v
287287
for k, v in fw.items()

tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
9393
return func(*args, **kwargs)
9494

9595

96-
def duplicate_kv_weight(weight: torch.Tensor, head_dim: int,
96+
def duplicate_kv_weight(weight: torch.Tensor, num_kv_heads: int,
9797
tensor_parallel_size: int):
9898

99-
num_kv_heads = weight.shape[0] // head_dim
100-
10199
if num_kv_heads >= tensor_parallel_size:
102100
assert num_kv_heads % tensor_parallel_size == 0
103101
return weight
@@ -109,11 +107,15 @@ def duplicate_kv_weight(weight: torch.Tensor, head_dim: int,
109107
if weight.ndim == 1:
110108
return weight.repeat_interleave(reps)
111109

112-
# weight
113-
weight = weight.reshape(num_kv_heads, head_dim,
110+
# weight and scale
111+
assert weight.shape[0] % num_kv_heads == 0
112+
size_per_kv_head = weight.shape[0] // num_kv_heads
113+
weight = weight.reshape(num_kv_heads, size_per_kv_head,
114114
-1)[:, None, :, :].expand(num_kv_heads, reps,
115-
head_dim, weight.shape[1])
116-
return weight.reshape(num_kv_heads * reps * head_dim, -1).clone().detach()
115+
size_per_kv_head,
116+
weight.shape[1])
117+
return weight.reshape(num_kv_heads * reps * size_per_kv_head,
118+
-1).clone().detach()
117119

118120

119121
def iter_modules(
@@ -648,9 +650,9 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
648650
logger.info(f"Renamed weights with params_map: {params_map}")
649651

650652
tp_size = 1 if model.model_config.mapping.enable_attention_dp else model.model_config.mapping.tp_size
651-
head_dim = getattr(
652-
model.config, "head_dim",
653-
model.config.hidden_size // model.config.num_attention_heads)
653+
num_kv_heads = model.config.num_key_value_heads if hasattr(
654+
model.config, 'num_key_value_heads'
655+
) and model.config.num_key_value_heads is not None else model.config.num_attention_heads
654656

655657
params_map = {
656658
'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -687,13 +689,18 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
687689
fw = filter_weights('.'.join(names[:-1] + [new_name]),
688690
weights)
689691
if new_name in ['k_proj', 'v_proj']:
692+
num_kv_heads_list = [num_kv_heads
693+
] * len(fw) if isinstance(
694+
num_kv_heads,
695+
int) else num_kv_heads
690696
fw = {
691697
k:
692-
duplicate_kv_weight(weight=v[:],
693-
head_dim=head_dim,
694-
tensor_parallel_size=tp_size)
698+
duplicate_kv_weight(
699+
weight=v[:],
700+
num_kv_heads=num_kv_heads_list[i],
701+
tensor_parallel_size=tp_size)
695702
if k in ["weight", "bias"] else v
696-
for k, v in fw.items()
703+
for i, (k, v) in enumerate(fw.items())
697704
}
698705

699706
module_weights.append(fw)

0 commit comments

Comments (0)