
Commit 18c447e

dongfengy authored and dominicshanshan committed
[None][fix] Make TP working for Triton MOE (in additional to EP we are using) (NVIDIA#6722)
Signed-off-by: Dongfeng Yu <[email protected]>
1 parent 89c41fc commit 18c447e

File tree

3 files changed (+93 / -32 lines)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py

Lines changed: 90 additions & 27 deletions
@@ -10,6 +10,8 @@
 import triton
 import triton.language as tl
 
+from tensorrt_llm.math_utils import ceil_div
+
 IS_TRITON_KERNELS_AVAILABLE = False
 # We expect to find triton_kernels under $TRITON_ROOT/python/triton_kernels
 # Triton upstream commit f3067cd3bd0c29065fa4ecdb724b6f29cbabea5f has been verified.
@@ -625,7 +627,14 @@ class TritonMXFP4FusedMoEQuantScales(NamedTuple):
 
 
 def swizzle_weight_and_scale(w: torch.Tensor, w_scale: torch.Tensor):
-    w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
+    # (num_experts, in_dim//2, out_dim)
+    w_shape = w.shape
+    # (num_experts, in_dim//32, out_dim)
+    w_scale_shape = w_scale.shape
+    assert w_shape[0] == w_scale_shape[0]
+    assert w_shape[1] * 2 == w_scale_shape[1] * 32
+    assert w_shape[2] == w_scale_shape[2]
+    w = maybe_update_stride(w)
     #num_warps = 4 if batch <= 512 else 8
     num_warps = int(os.getenv("TRITON_MOE_MXFP4_NUM_WARPS", 4))
     assert num_warps in [4, 8], \
@@ -662,6 +671,8 @@ def __init__(self, activation_dtype):
         assert activation_dtype in [torch.float8_e4m3fn, torch.bfloat16], \
             f"TritonMXFP4FusedMoEMethod only supports float8_e4m3fn or bfloat16 activation, got {activation_dtype}"
         self.activation_dtype = activation_dtype
+        self.in_dim_padding_multiple = 128
+        self.out_dim_padding_multiple = 256
 
     def create_weights(self, module: torch.nn.Module):
         weight_dtype = torch.uint8
@@ -673,11 +684,11 @@ def create_weights(self, module: torch.nn.Module):
             module.intermediate_size_per_partition * 2,
         )
 
+        # Full scale is loaded at the beginning, later we will slice properly for TP
         w3_w1_scale_shape = (
             w3_w1_weight_shape[0],
-            w3_w1_weight_shape[1] //
-            16,  # block size of 32 for mxfp4, we already divided by 2 before so only divide by 16
-            w3_w1_weight_shape[2],
+            ceil_div(module.hidden_size, 32),  # block size of 32 for mxfp4
+            module.intermediate_size * 2,
         )
 
         # The Triton kernel accepts the w2_weight in (num_experts, intermediate_dim, hidden_dim) format
@@ -688,27 +699,14 @@ def create_weights(self, module: torch.nn.Module):
             module.hidden_size,
         )
 
+        # Full scale is loaded at the beginning, later we will slice properly for TP
         w2_scale_shape = (
             w2_weight_shape[0],
-            w2_weight_shape[1] //
-            16,  # block size of 32 for mxfp4, we already divided by 2 before so only divide by 16
+            ceil_div(module.intermediate_size,
+                     32),  # block size of 32 for mxfp4
             w2_weight_shape[2],
         )
 
-        def _check_shape_requirement(shape):
-            # Reject shapes that may cause kernels to fail
-            # For hidden_size = 2880 and intermediate_size = 2880,
-            # we have w3_w1_weight_shape = (?, 1440, 5760)
-            # and w2_weight_shape = (?, 1440, 2880).
-            # Note that the div 2 here is because we are using mxfp4 which packs two values into one byte.
-            # This check allows the 2880 case with tp1 to pass while rejecting larger tp.
-            assert len(shape) == 3
-            assert shape[1] % 32 == 0 and shape[
-                2] % 32 == 0, "Shape not well-supported by Triton kernel, try EP instead"
-
-        _check_shape_requirement(w3_w1_weight_shape)
-        _check_shape_requirement(w2_weight_shape)
-
         FusedMoEMethodBase.create_weights(self,
                                           module,
                                           weight_dtype,
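
For orientation, a minimal sketch (not part of the diff) of how the new scale shapes relate to the mxfp4 layout: two 4-bit values are packed per uint8 along the input dimension, and one scale covers a block of 32 values, hence ceil_div(dim, 32). The full, unsliced scales are allocated here and only sliced per TP rank later in load_quant_scales. The weight shapes below are inferred from the surrounding context, and num_experts / hidden_size / intermediate_size / tp_size are made-up numbers.

def ceil_div(a: int, b: int) -> int:
    # Mirrors tensorrt_llm.math_utils.ceil_div
    return (a + b - 1) // b

# Illustrative numbers only
num_experts, hidden_size, intermediate_size, tp_size = 8, 2880, 2880, 2
intermediate_size_per_partition = intermediate_size // tp_size  # 1440

# Per-rank mxfp4 weights: two 4-bit values packed per uint8 along the input dim
w3_w1_weight_shape = (num_experts, hidden_size // 2,
                      intermediate_size_per_partition * 2)   # (8, 1440, 2880)
w2_weight_shape = (num_experts, intermediate_size_per_partition // 2,
                   hidden_size)                              # (8, 720, 2880)

# Full (unsliced) scales: one scale per 32-value block of the input dim
w3_w1_scale_shape = (num_experts, ceil_div(hidden_size, 32),
                     intermediate_size * 2)                  # (8, 90, 5760)
w2_scale_shape = (num_experts, ceil_div(intermediate_size, 32),
                  hidden_size)                               # (8, 90, 2880)
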
@@ -1011,25 +1009,63 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                                              tmp_w3_w1_weight_scale)
 
         # For Hopper style swizzle, we need to pad the out dim to multiple of 256 otherwise it sometimes produces nan
-        def _maybe_pad_weight_and_scale(weight, scale=None):
+        def _maybe_pad_weight_and_scale(weight,
+                                        scale=None,
+                                        in_dim_padding_offset=0):
+            # Both weight and bias are handled here
+            assert weight.dim() in [2, 3], "Weight should be 2D or 3D tensor"
+            # out_dim padding is only required for Hopper
             if torch.cuda.get_device_capability()[0] == 9:
-                # Both weight and bias are handled here
-                assert weight.dim() in [2,
-                                        3], "Weight should be 2D or 3D tensor"
-
                 out_dim = weight.shape[-1]
                 assert scale is None or scale.shape[
                     -1] == out_dim, "Out dim of weight and scale should match"
-                pad_size = (256 - out_dim % 256) % 256
+                pad_size = (self.out_dim_padding_multiple -
+                            out_dim % self.out_dim_padding_multiple
+                            ) % self.out_dim_padding_multiple
                 weight = F.pad(
                     weight,
                     (0, pad_size))  # Pad the last dimension on right side
                 if scale is not None:
                     scale = F.pad(scale, (0, pad_size))
+            # in_dim padding is always required when we have TP because of mxfp4 scale block size
+            # We only do in_dim padding for weights but not for bias
+            if weight.dim() == 3:
+                in_dim = weight.shape[
+                    -2] * 2  # mxfp4 packs two values into one byte
+                assert scale is None or scale.shape[-2] == ceil_div(
+                    in_dim, 32), "In dim of weight and scale should match"
+                pad_size = (self.in_dim_padding_multiple -
+                            in_dim % self.in_dim_padding_multiple
+                            ) % self.in_dim_padding_multiple
+                assert pad_size % 2 == 0
+                pad_size //= 2  # pad_size is in mxfp4 units
+                assert in_dim_padding_offset % 2 == 0
+                in_dim_padding_offset //= 2
+                assert in_dim_padding_offset <= pad_size, "TP offset larger than pad size"
+                weight = F.pad(weight, (0, 0, in_dim_padding_offset,
+                                        pad_size - in_dim_padding_offset))
+                assert scale is not None  # Bias won't enter this branch
+                new_in_dim = weight.shape[-2] * 2
+                assert new_in_dim % 32 == 0
+                new_scale_in_dim = new_in_dim // 32
+                scale_pad_size = new_scale_in_dim - scale.shape[-2]
+                assert scale_pad_size >= 0
+                scale = F.pad(scale, (0, 0, 0, scale_pad_size))
+
             return (weight, scale) if scale is not None else weight
 
         # Handle w3_w1_weight
 
+        # Slice scales for TP
+        tp_slice_start = module.intermediate_size_per_partition * module.tp_rank
+        tp_slice_end = tp_slice_start + module.intermediate_size_per_partition
+        #(num_experts, in_dim / 32, out_dim)
+        assert tmp_w3_w1_weight_scale.dim() == 3
+        assert tmp_w3_w1_weight_scale.shape[-1] == module.intermediate_size * 2
+        # The scale is already shuffled
+        tmp_w3_w1_weight_scale = tmp_w3_w1_weight_scale[:, :, tp_slice_start *
+                                                        2:tp_slice_end * 2]
+
         tmp_w3_w1_weight, tmp_w3_w1_weight_scale = _maybe_pad_weight_and_scale(
             module.w3_w1_weight, tmp_w3_w1_weight_scale)
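
As a standalone illustration of what the reworked _maybe_pad_weight_and_scale computes, the helper below reproduces just the padding arithmetic; it is a simplification, not the module code. The 256/128 multiples come from __init__ above, the TP numbers are made up, and the out-dim padding only matters on Hopper. The slice at the end mirrors the w3/w1 scale slicing above, where the factor of 2 reflects the fused gate/up pair sharing the output dimension (the diff notes the scale is already shuffled).

def pad_amounts(out_dim: int, in_dim: int, in_dim_padding_offset: int = 0,
                out_multiple: int = 256, in_multiple: int = 128):
    # Hopper-only: pad the output dim up to a multiple of 256
    out_pad = (out_multiple - out_dim % out_multiple) % out_multiple
    # Pad the (unpacked) input dim up to a multiple of 128; the offset lets a
    # TP shard start part-way through an mxfp4 32-value block
    in_pad = (in_multiple - in_dim % in_multiple) % in_multiple
    assert in_dim_padding_offset <= in_pad
    return out_pad, in_dim_padding_offset, in_pad - in_dim_padding_offset

# TP slice of the (already shuffled) w3/w1 scales, illustrative numbers
intermediate_size, tp_size, tp_rank = 2880, 4, 1
per_partition = intermediate_size // tp_size               # 720
tp_slice_start = per_partition * tp_rank                   # 720
tp_slice_end = tp_slice_start + per_partition              # 1440
w3_w1_cols = slice(tp_slice_start * 2, tp_slice_end * 2)   # 1440:2880

print(pad_amounts(out_dim=per_partition * 2, in_dim=2880))
# (96, 0, 64): out dim 1440 -> 1536, in dim 2880 -> 2944
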
@@ -1045,8 +1081,21 @@ def _maybe_pad_weight_and_scale(weight, scale=None):
 
         # Handle w2_weight
 
+        # Slice scales for TP
+        # TP might make the weight start from half of the mxfp4 32 block
+        # For example, if we start from index 20, there are 12 elements in the first block instead of 32
+        # We need to pad 20 elements to the first block
+        self.w2_tp_offset = tp_slice_start % 32
+        assert tmp_w2_weight_scale.dim() == 3
+        # assert tmp_w2_weight_scale.shape[-2] * 32 == module.intermediate_size
+        # We skip this assert to allow intermidiate_size not divisible by 32, this is used in the unit test to test TP shapes in a single gpu
+        scale_slice_start = tp_slice_start // 32
+        scale_slice_end = (tp_slice_end - 1) // 32 + 1
+        tmp_w2_weight_scale = tmp_w2_weight_scale[:, scale_slice_start:
+                                                  scale_slice_end, :]
+
         tmp_w2_weight, tmp_w2_weight_scale = _maybe_pad_weight_and_scale(
-            module.w2_weight, tmp_w2_weight_scale)
+            module.w2_weight, tmp_w2_weight_scale, self.w2_tp_offset)
 
         module._parameters.pop('w2_weight', None)
         module._parameters.pop('fc2_dequant', None)
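
A worked example of the w2 slicing above, with made-up TP numbers: a shard whose rows start part-way through a 32-value mxfp4 block keeps every scale block that covers its rows and records the misalignment as w2_tp_offset, which _maybe_pad_weight_and_scale (and later the activation padding) re-applies as left padding.

# Illustrative numbers: intermediate_size 2880 sharded 4 ways, rank 1
intermediate_size, tp_size, tp_rank = 2880, 4, 1
per_partition = intermediate_size // tp_size        # 720
tp_slice_start = per_partition * tp_rank            # 720
tp_slice_end = tp_slice_start + per_partition       # 1440

w2_tp_offset = tp_slice_start % 32                  # 16: shard starts mid-block
scale_slice_start = tp_slice_start // 32            # 22
scale_slice_end = (tp_slice_end - 1) // 32 + 1      # 45
blocks_kept = scale_slice_end - scale_slice_start   # 23 blocks cover 16 + 720 = 736 values
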
@@ -1116,6 +1165,18 @@ def apply(self, module: torch.nn.Module, x: torch.Tensor,
 
         # Step 2: Gemm1
         # Setup quantization context
+        def _maybe_pad_activation(hidden_states, in_dim_padding_offset):
+            assert hidden_states.dim() == 2, "Hidden states should be 2D tensor"
+            in_dim = hidden_states.shape[-1]
+            pad_size_in = (self.in_dim_padding_multiple -
+                           in_dim % self.in_dim_padding_multiple
+                           ) % self.in_dim_padding_multiple
+            assert in_dim_padding_offset <= pad_size_in
+            padding = (in_dim_padding_offset,
+                       pad_size_in - in_dim_padding_offset)
+            hidden_states = F.pad(hidden_states, padding)
+            return hidden_states
+
         if self.activation_dtype == torch.float8_e4m3fn:
             flex_ctx_1 = FlexCtx(
                 lhs_data=InFlexData(scale=hidden_states_scale), )
@@ -1129,6 +1190,7 @@ def apply(self, module: torch.nn.Module, x: torch.Tensor,
         # Call the Triton gemm kernel, which also does permutation and activation
         alpha = module.swiglu_alpha or 1.0
         beta = module.swiglu_beta or 0.0
+        hidden_states = _maybe_pad_activation(hidden_states, 0)
         if beta == 1.0:
             act = FusedActivation(
                 FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn,
@@ -1183,6 +1245,7 @@ def _maybe_remove_padding(gemm_output, expected_size):
                                   out_dtype=module.dtype)
 
         # Call the Triton kernel, which also does finalization
+        act_out = _maybe_pad_activation(act_out, self.w2_tp_offset)
         gemm2_output = matmul_ogs(act_out,
                                   gemm2_weights,
                                   module.w2_bias if module.bias else None,
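
The activation padding mirrors the weight padding so the two GEMM operands stay aligned. A quick sketch using standard PyTorch F.pad semantics (shapes are made up, and w2_tp_offset is taken from the slicing example above): the zero columns padded onto the left of act_out line up with the zero rows padded into the front of w2's first partial block, so the extra products contribute nothing to the result.

import torch
import torch.nn.functional as F

in_dim_padding_multiple = 128
w2_tp_offset = 16                       # from the w2 slicing example above
tokens, per_partition = 4, 720          # illustrative act_out shape

act_out = torch.randn(tokens, per_partition)
pad = (in_dim_padding_multiple - per_partition % in_dim_padding_multiple) \
      % in_dim_padding_multiple         # 48
assert w2_tp_offset <= pad
# Zero-pad w2_tp_offset columns on the left, the remainder on the right
padded = F.pad(act_out, (w2_tp_offset, pad - w2_tp_offset))
assert padded.shape == (tokens, 768)    # 768 is a multiple of 128
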

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 0 additions & 4 deletions
@@ -2565,10 +2565,6 @@ def test_w4_4gpus(self, moe_backend, tp_size, pp_size, ep_size,
         if moe_backend == "TRITON":
             if not IS_TRITON_KERNELS_AVAILABLE:
                 pytest.skip("Triton kernels are not available")
-            if tp_size != ep_size:
-                pytest.skip(
-                    "TRITON moe backend currently doesn't supported mxfp4 tp for this size"
-                )
 
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,

tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 3 additions & 1 deletion
@@ -1757,8 +1757,10 @@ def ref():
 @pytest.mark.parametrize(
     "hidden_size, intermediate_size",
     [
-        (256, 256),
         (2880, 2880),
+        (2880, 1440),
+        (2880, 720),
+        (2880, 360),
     ],
 )
 @pytest.mark.parametrize("fp8_activation", [True, False])
