
Commit 93e623b

achartier authored and joyang-nv committed
[https://nvbugs/5449155][fix] Fix DeepSeek R1 weight loading for TP16 (#6913)
Signed-off-by: Aurelien Chartier <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent: 21291f3


tensorrt_llm/_torch/modules/linear.py

Lines changed: 30 additions & 10 deletions
@@ -650,9 +650,12 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         load_weights_vanilla_helper(module, weights)

         scale_name = self._get_scale_name(weights)
-        weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
-                                         module.tp_rank,
-                                         module.tp_mode).squeeze()
+        full_weight_scale = weights[0][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_weight_scale.dim() == 4:
+            full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
+        weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
+                                         module.tp_rank, module.tp_mode)
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])
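To make the change above concrete, here is a minimal, self-contained sketch of the new vanilla path. It is not the TensorRT-LLM implementation: the example shapes, the toy shard_dim0 helper, and the assumption that the scale is split along dim 0 (a stand-in for load_weight_shard) are illustrative only.

import torch

def shard_dim0(t: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Toy stand-in for load_weight_shard: take this rank's slice along dim 0.
    return torch.chunk(t, tp_size, dim=0)[tp_rank]

# Assumed modelopt fp8_pb_wo-style block scale with two singleton padding dims.
full_weight_scale = torch.rand(128, 1, 56, 1)
if full_weight_scale.dim() == 4:
    # Drop only the two known singleton dims, as the fix does, instead of a
    # bare .squeeze() that could also remove a real size-1 block dimension.
    full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)

weight_scale = shard_dim0(full_weight_scale, tp_size=16, tp_rank=3)
print(weight_scale.shape)  # torch.Size([8, 56]) for these toy shapes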
@@ -665,13 +668,23 @@ def load_weights_fused_qkv_linear(self, module: Linear,
         fused_weight = torch.cat((q_weight, k_weight, v_weight))

         scale_name = self._get_scale_name(weights)
-        q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_q_scale = weights[0][scale_name]
+        full_k_scale = weights[1][scale_name]
+        full_v_scale = weights[2][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_q_scale.dim() == 4:
+            full_q_scale = full_q_scale.squeeze(1).squeeze(-1)
+        if full_k_scale.dim() == 4:
+            full_k_scale = full_k_scale.squeeze(1).squeeze(-1)
+        if full_v_scale.dim() == 4:
+            full_v_scale = full_v_scale.squeeze(1).squeeze(-1)
+        q_scale = load_weight_shard(full_q_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        k_scale = load_weight_shard(full_k_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
+        v_scale = load_weight_shard(full_v_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
+        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))

         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_fp8_block_scale)
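One plausible reading of why the old path could break at TP16, sketched below with assumed shapes and a dim-0 split as a stand-in for load_weight_shard: once a per-rank shard's block dimension shrinks to 1, an unconditional .squeeze() collapses it along with the padding dims, so the scale no longer matches the 2-D buffer it is copied into. The fix squeezes only the two known singleton dims before sharding and drops the trailing .squeeze() after the Q/K/V concatenation.

import torch

# Assumed 4-D fp8_pb_wo-style scale: (out_blocks, 1, in_blocks, 1).
full_scale = torch.rand(16, 1, 56, 1)
tp_size, tp_rank = 16, 0

# Pre-fix style: shard first, then squeeze everything. At TP16 the per-rank
# shard is (1, 1, 56, 1) and .squeeze() flattens it to a 1-D (56,) tensor.
old_shard = torch.chunk(full_scale, tp_size, dim=0)[tp_rank].squeeze()

# Post-fix style: squeeze only the known singleton dims up front, then shard.
# The per-rank shard keeps its 2-D (1, 56) layout.
new_shard = torch.chunk(full_scale.squeeze(1).squeeze(-1), tp_size, dim=0)[tp_rank]

print(old_shard.shape, new_shard.shape)  # torch.Size([56]) torch.Size([1, 56])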
@@ -683,11 +696,18 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         fused_weight = torch.cat((gate_weight, up_weight))

         scale_name = self._get_scale_name(weights)
-        left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_left_scale = weights[0][scale_name]
+        full_right_scale = weights[1][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_left_scale.dim() == 4:
+            full_left_scale = full_left_scale.squeeze(1).squeeze(-1)
+        if full_right_scale.dim() == 4:
+            full_right_scale = full_right_scale.squeeze(1).squeeze(-1)
+        left_scale = load_weight_shard(full_left_scale, module.tp_size,
                                        module.tp_rank, module.tp_mode)
-        right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        right_scale = load_weight_shard(full_right_scale, module.tp_size,
                                         module.tp_rank, module.tp_mode)
-        fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
+        fused_scale = torch.cat([left_scale, right_scale], dim=0)
         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_scale)

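The gate/up path applies the same pattern as the other two hunks: both scales get the targeted squeeze(1).squeeze(-1) before sharding, so the per-rank shards are already 2-D and the fused torch.cat([...], dim=0) result can be copied into module.weight_scale without the previous trailing .squeeze(). Keeping the squeeze conditional on dim() == 4 also means checkpoints whose scales are already 2-D pass through unchanged; per the in-diff comment, only modelopt fp8_pb_wo checkpoints carry the two extra singleton dimensions.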
