@@ -650,9 +650,12 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         load_weights_vanilla_helper(module, weights)

         scale_name = self._get_scale_name(weights)
-        weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
-                                         module.tp_rank,
-                                         module.tp_mode).squeeze()
+        full_weight_scale = weights[0][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_weight_scale.dim() == 4:
+            full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
+        weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
+                                         module.tp_rank, module.tp_mode)
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])
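For reference, the shape handling this hunk adds can be sketched in isolation: a modelopt fp8_pb_wo checkpoint may store the per-block weight scale with two extra singleton dimensions (a 4-D tensor rather than the usual 2-D [out_blocks, in_blocks] layout), and the new code squeezes dims 1 and -1 before sharding. The helper name and tensor shapes below are illustrative assumptions, not code from the repository:

import torch

def normalize_block_scale(scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the hunk: drop the two singleton
    # dims that a modelopt fp8_pb_wo scale tensor may carry.
    if scale.dim() == 4:
        scale = scale.squeeze(1).squeeze(-1)
    return scale

# Illustrative shapes only: 32 x 16 scale blocks stored as [32, 1, 16, 1].
print(normalize_block_scale(torch.rand(32, 1, 16, 1)).shape)  # torch.Size([32, 16])
# An already-2-D scale passes through unchanged.
print(normalize_block_scale(torch.rand(32, 16)).shape)        # torch.Size([32, 16])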
@@ -665,13 +668,23 @@ def load_weights_fused_qkv_linear(self, module: Linear,
         fused_weight = torch.cat((q_weight, k_weight, v_weight))

         scale_name = self._get_scale_name(weights)
-        q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_q_scale = weights[0][scale_name]
+        full_k_scale = weights[1][scale_name]
+        full_v_scale = weights[2][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_q_scale.dim() == 4:
+            full_q_scale = full_q_scale.squeeze(1).squeeze(-1)
+        if full_k_scale.dim() == 4:
+            full_k_scale = full_k_scale.squeeze(1).squeeze(-1)
+        if full_v_scale.dim() == 4:
+            full_v_scale = full_v_scale.squeeze(1).squeeze(-1)
+        q_scale = load_weight_shard(full_q_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        k_scale = load_weight_shard(full_k_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
+        v_scale = load_weight_shard(full_v_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
+        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))

         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_fp8_block_scale)
@@ -683,11 +696,18 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         fused_weight = torch.cat((gate_weight, up_weight))

         scale_name = self._get_scale_name(weights)
-        left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_left_scale = weights[0][scale_name]
+        full_right_scale = weights[1][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_left_scale.dim() == 4:
+            full_left_scale = full_left_scale.squeeze(1).squeeze(-1)
+        if full_right_scale.dim() == 4:
+            full_right_scale = full_right_scale.squeeze(1).squeeze(-1)
+        left_scale = load_weight_shard(full_left_scale, module.tp_size,
                                        module.tp_rank, module.tp_mode)
-        right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        right_scale = load_weight_shard(full_right_scale, module.tp_size,
                                        module.tp_rank, module.tp_mode)
-        fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
+        fused_scale = torch.cat([left_scale, right_scale], dim=0)
         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_scale)
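The fused-QKV and fused-gate/up paths follow the same pattern: each projection's scale is normalized up front, sharded with load_weight_shard, and the per-projection scales are then concatenated along dim 0, so the trailing .squeeze() on the fused result is dropped. A rough sketch of the resulting shapes, using hypothetical block counts and omitting the tensor-parallel sharding step:

import torch

# Hypothetical per-projection block-scale shapes after the 4-D -> 2-D
# normalization (tensor-parallel sharding omitted for brevity).
q_scale = torch.rand(32, 16)
k_scale = torch.rand(8, 16)
v_scale = torch.rand(8, 16)

# The fused scale stacks the output-block dimension, mirroring the
# torch.cat over the q/k/v weights themselves.
fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
print(fused_fp8_block_scale.shape)  # torch.Size([48, 16])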