import triton
import triton.language as tl

+from tensorrt_llm.math_utils import ceil_div
+
IS_TRITON_KERNELS_AVAILABLE = False
# We expect to find triton_kernels under $TRITON_ROOT/python/triton_kernels
# Triton upstream commit f3067cd3bd0c29065fa4ecdb724b6f29cbabea5f has been verified.
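Note for readers of the shape math below: `ceil_div` from `tensorrt_llm.math_utils` is presumably the usual integer ceiling division. A minimal stand-in sketch (an assumption, not the library source):

```python
def ceil_div(a: int, b: int) -> int:
    # Ceiling of a / b using only integer arithmetic.
    return (a + b - 1) // b

assert ceil_div(2880, 32) == 90  # exact multiple
assert ceil_div(2881, 32) == 91  # rounds up
```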
@@ -625,7 +627,14 @@ class TritonMXFP4FusedMoEQuantScales(NamedTuple):
def swizzle_weight_and_scale(w: torch.Tensor, w_scale: torch.Tensor):
-    w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
+    # (num_experts, in_dim//2, out_dim)
+    w_shape = w.shape
+    # (num_experts, in_dim//32, out_dim)
+    w_scale_shape = w_scale.shape
+    assert w_shape[0] == w_scale_shape[0]
+    assert w_shape[1] * 2 == w_scale_shape[1] * 32
+    assert w_shape[2] == w_scale_shape[2]
+    w = maybe_update_stride(w)
    #num_warps = 4 if batch <= 512 else 8
    num_warps = int(os.getenv("TRITON_MOE_MXFP4_NUM_WARPS", 4))
    assert num_warps in [4, 8], \
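The three new asserts pin down the mxfp4 layout contract: dim 1 of `w` counts packed bytes (two 4-bit values each) while dim 1 of `w_scale` counts 32-value scale blocks, so both recover the same logical `in_dim`. A small illustrative sketch (shapes and dtypes are hypothetical, not taken from the kernel):

```python
import torch

num_experts, in_dim, out_dim = 8, 2880, 5760
w = torch.empty(num_experts, in_dim // 2, out_dim, dtype=torch.uint8)         # two fp4 values per byte
w_scale = torch.empty(num_experts, in_dim // 32, out_dim, dtype=torch.uint8)  # one scale per 32 values

assert w.shape[0] == w_scale.shape[0]
assert w.shape[1] * 2 == w_scale.shape[1] * 32  # both sides equal in_dim
assert w.shape[2] == w_scale.shape[2]
```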
@@ -662,6 +671,8 @@ def __init__(self, activation_dtype):
        assert activation_dtype in [torch.float8_e4m3fn, torch.bfloat16], \
            f"TritonMXFP4FusedMoEMethod only supports float8_e4m3fn or bfloat16 activation, got {activation_dtype}"
        self.activation_dtype = activation_dtype
+        self.in_dim_padding_multiple = 128
+        self.out_dim_padding_multiple = 256

    def create_weights(self, module: torch.nn.Module):
        weight_dtype = torch.uint8
@@ -673,11 +684,11 @@ def create_weights(self, module: torch.nn.Module):
            module.intermediate_size_per_partition * 2,
        )

+        # Full scale is loaded at the beginning; we slice it properly for TP later
        w3_w1_scale_shape = (
            w3_w1_weight_shape[0],
-            w3_w1_weight_shape[1] //
-            16,  # block size of 32 for mxfp4; we already divided by 2 before, so only divide by 16
-            w3_w1_weight_shape[2],
+            ceil_div(module.hidden_size, 32),  # block size of 32 for mxfp4
+            module.intermediate_size * 2,
        )

        # The Triton kernel accepts the w2_weight in (num_experts, intermediate_dim, hidden_dim) format
@@ -688,27 +699,14 @@ def create_weights(self, module: torch.nn.Module):
            module.hidden_size,
        )

+        # Full scale is loaded at the beginning; we slice it properly for TP later
        w2_scale_shape = (
            w2_weight_shape[0],
-            w2_weight_shape[1] //
-            16,  # block size of 32 for mxfp4; we already divided by 2 before, so only divide by 16
+            ceil_div(module.intermediate_size,
+                     32),  # block size of 32 for mxfp4
            w2_weight_shape[2],
        )

-        def _check_shape_requirement(shape):
-            # Reject shapes that may cause kernels to fail.
-            # For hidden_size = 2880 and intermediate_size = 2880,
-            # we have w3_w1_weight_shape = (?, 1440, 5760)
-            # and w2_weight_shape = (?, 1440, 2880).
-            # Note that the div 2 here is because we are using mxfp4, which packs two values into one byte.
-            # This check allows the 2880 case with tp1 to pass while rejecting larger tp.
-            assert len(shape) == 3
-            assert shape[1] % 32 == 0 and shape[
-                2] % 32 == 0, "Shape not well-supported by Triton kernel, try EP instead"
-
-        _check_shape_requirement(w3_w1_weight_shape)
-        _check_shape_requirement(w2_weight_shape)
-
        FusedMoEMethodBase.create_weights(self,
                                          module,
                                          weight_dtype,
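For the hidden_size = intermediate_size = 2880 case called out in the removed `_check_shape_requirement` comment, the new `ceil_div`-based scale shapes work out as follows (a worked sketch; `num_experts` and the local `ceil_div` are stand-ins):

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

num_experts, hidden_size, intermediate_size = 32, 2880, 2880

w3_w1_scale_shape = (num_experts, ceil_div(hidden_size, 32),
                     intermediate_size * 2)  # (32, 90, 5760)
w2_scale_shape = (num_experts, ceil_div(intermediate_size, 32),
                  hidden_size)               # (32, 90, 2880)
```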
@@ -1011,25 +1009,63 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
            tmp_w3_w1_weight_scale)

        # For Hopper style swizzle, we need to pad the out dim to a multiple of 256, otherwise it sometimes produces NaN
-        def _maybe_pad_weight_and_scale(weight, scale=None):
+        def _maybe_pad_weight_and_scale(weight,
+                                        scale=None,
+                                        in_dim_padding_offset=0):
+            # Both weight and bias are handled here
+            assert weight.dim() in [2, 3], "Weight should be 2D or 3D tensor"
+            # out_dim padding is only required for Hopper
            if torch.cuda.get_device_capability()[0] == 9:
-                # Both weight and bias are handled here
-                assert weight.dim() in [2,
-                                        3], "Weight should be 2D or 3D tensor"
-
                out_dim = weight.shape[-1]
                assert scale is None or scale.shape[
                    -1] == out_dim, "Out dim of weight and scale should match"
-                pad_size = (256 - out_dim % 256) % 256
+                pad_size = (self.out_dim_padding_multiple -
+                            out_dim % self.out_dim_padding_multiple
+                            ) % self.out_dim_padding_multiple
                weight = F.pad(
                    weight,
                    (0, pad_size))  # Pad the last dimension on the right side
                if scale is not None:
                    scale = F.pad(scale, (0, pad_size))
+            # in_dim padding is always required when we have TP because of the mxfp4 scale block size
+            # We only do in_dim padding for weights but not for bias
+            if weight.dim() == 3:
+                in_dim = weight.shape[
+                    -2] * 2  # mxfp4 packs two values into one byte
+                assert scale is None or scale.shape[-2] == ceil_div(
+                    in_dim, 32), "In dim of weight and scale should match"
+                pad_size = (self.in_dim_padding_multiple -
+                            in_dim % self.in_dim_padding_multiple
+                            ) % self.in_dim_padding_multiple
+                assert pad_size % 2 == 0
+                pad_size //= 2  # pad_size is in mxfp4 units
+                assert in_dim_padding_offset % 2 == 0
+                in_dim_padding_offset //= 2
+                assert in_dim_padding_offset <= pad_size, "TP offset larger than pad size"
+                weight = F.pad(weight, (0, 0, in_dim_padding_offset,
+                                        pad_size - in_dim_padding_offset))
+                assert scale is not None  # Bias won't enter this branch
+                new_in_dim = weight.shape[-2] * 2
+                assert new_in_dim % 32 == 0
+                new_scale_in_dim = new_in_dim // 32
+                scale_pad_size = new_scale_in_dim - scale.shape[-2]
+                assert scale_pad_size >= 0
+                scale = F.pad(scale, (0, 0, 0, scale_pad_size))
+
            return (weight, scale) if scale is not None else weight

        # Handle w3_w1_weight

+        # Slice scales for TP
+        tp_slice_start = module.intermediate_size_per_partition * module.tp_rank
+        tp_slice_end = tp_slice_start + module.intermediate_size_per_partition
+        # (num_experts, in_dim / 32, out_dim)
+        assert tmp_w3_w1_weight_scale.dim() == 3
+        assert tmp_w3_w1_weight_scale.shape[-1] == module.intermediate_size * 2
+        # The scale is already shuffled
+        tmp_w3_w1_weight_scale = tmp_w3_w1_weight_scale[:, :, tp_slice_start *
+                                                        2:tp_slice_end * 2]
+
        tmp_w3_w1_weight, tmp_w3_w1_weight_scale = _maybe_pad_weight_and_scale(
            module.w3_w1_weight, tmp_w3_w1_weight_scale)
@@ -1045,8 +1081,21 @@ def _maybe_pad_weight_and_scale(weight, scale=None):
        # Handle w2_weight

+        # Slice scales for TP
+        # TP might make the weight start from the middle of a 32-wide mxfp4 block
+        # For example, if we start from index 20, there are 12 elements in the first block instead of 32
+        # We need to pad 20 elements to the first block
+        self.w2_tp_offset = tp_slice_start % 32
+        assert tmp_w2_weight_scale.dim() == 3
+        # assert tmp_w2_weight_scale.shape[-2] * 32 == module.intermediate_size
+        # We skip this assert to allow intermediate_size not divisible by 32; this is used in the unit test to test TP shapes on a single GPU
+        scale_slice_start = tp_slice_start // 32
+        scale_slice_end = (tp_slice_end - 1) // 32 + 1
+        tmp_w2_weight_scale = tmp_w2_weight_scale[:, scale_slice_start:
+                                                  scale_slice_end, :]
+
        tmp_w2_weight, tmp_w2_weight_scale = _maybe_pad_weight_and_scale(
-            module.w2_weight, tmp_w2_weight_scale)
+            module.w2_weight, tmp_w2_weight_scale, self.w2_tp_offset)

        module._parameters.pop('w2_weight', None)
        module._parameters.pop('fc2_dequant', None)
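The offset and slice arithmetic above can be checked in isolation. A sketch with a hypothetical per-rank slice that starts mid-block, matching the index-20 example in the comment (the slice bounds are illustrative):

```python
tp_slice_start, tp_slice_end = 20, 1460  # hypothetical per-rank range

w2_tp_offset = tp_slice_start % 32              # 20 zeros of left padding needed
scale_slice_start = tp_slice_start // 32        # block 0 still covers index 20
scale_slice_end = (tp_slice_end - 1) // 32 + 1  # one past the last touched block

assert (w2_tp_offset, scale_slice_start, scale_slice_end) == (20, 0, 46)
```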
@@ -1116,6 +1165,18 @@ def apply(self, module: torch.nn.Module, x: torch.Tensor,
        # Step 2: Gemm1
        # Setup quantization context
+        def _maybe_pad_activation(hidden_states, in_dim_padding_offset):
+            assert hidden_states.dim() == 2, "Hidden states should be 2D tensor"
+            in_dim = hidden_states.shape[-1]
+            pad_size_in = (self.in_dim_padding_multiple -
+                           in_dim % self.in_dim_padding_multiple
+                           ) % self.in_dim_padding_multiple
+            assert in_dim_padding_offset <= pad_size_in
+            padding = (in_dim_padding_offset,
+                       pad_size_in - in_dim_padding_offset)
+            hidden_states = F.pad(hidden_states, padding)
+            return hidden_states
+
        if self.activation_dtype == torch.float8_e4m3fn:
            flex_ctx_1 = FlexCtx(
                lhs_data=InFlexData(scale=hidden_states_scale), )
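`_maybe_pad_activation` mirrors the weight-side in-dim padding: the last dimension grows to a multiple of `in_dim_padding_multiple` (128), with `in_dim_padding_offset` zeros placed on the left. A standalone sketch of the same arithmetic (sizes are illustrative):

```python
import torch
import torch.nn.functional as F

def pad_activation(hidden_states, offset, multiple=128):
    in_dim = hidden_states.shape[-1]
    pad = (multiple - in_dim % multiple) % multiple  # total zeros to add
    assert offset <= pad
    return F.pad(hidden_states, (offset, pad - offset))  # left/right split

x = torch.randn(4, 2880)                         # 2880 % 128 == 64, so pad = 64
assert pad_activation(x, 0).shape == (4, 2944)
assert pad_activation(x, 20).shape == (4, 2944)  # 20 left + 44 right
```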
@@ -1129,6 +1190,7 @@ def apply(self, module: torch.nn.Module, x: torch.Tensor,
        # Call the Triton gemm kernel, which also does permutation and activation
        alpha = module.swiglu_alpha or 1.0
        beta = module.swiglu_beta or 0.0
+        hidden_states = _maybe_pad_activation(hidden_states, 0)
        if beta == 1.0:
            act = FusedActivation(
                FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn,
@@ -1183,6 +1245,7 @@ def _maybe_remove_padding(gemm_output, expected_size):
                                  out_dtype=module.dtype)

        # Call the Triton kernel, which also does finalization
+        act_out = _maybe_pad_activation(act_out, self.w2_tp_offset)
        gemm2_output = matmul_ogs(act_out,
                                  gemm2_weights,
                                  module.w2_bias if module.bias else None,
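Why `act_out` is padded with the same `self.w2_tp_offset` as the w2 weights: left-padding the activation columns and the weight rows by the same amount keeps the matmul result unchanged, since the zero columns only ever multiply the zero rows. A toy float32 check of that alignment property (a stand-in, not the mxfp4 path):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 60)
w = torch.randn(60, 16)
x_pad = F.pad(x, (20, 48))        # columns: 60 -> 128, offset 20 on the left
w_pad = F.pad(w, (0, 0, 20, 48))  # rows (dim -2) padded identically

assert torch.allclose(x @ w, x_pad @ w_pad, atol=1e-5)
```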