
Commit e3b58fb

debug
Signed-off-by: xxi <[email protected]>

modified: tensorrt_llm/_torch/distributed/ops.py
modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
modified: tensorrt_llm/_torch/modules/fused_moe/moe_backend.py
modified: tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py
1 parent 373e11f commit e3b58fb

File tree: 4 files changed, +100 -37 lines

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 5 additions & 0 deletions
@@ -240,6 +240,11 @@ def reducescatter(
     if isinstance(input, torch.Tensor):
         assert input.shape[dim] == sum_split_size
     else:
+        for val in input:
+            if val is not None and val.shape[dim] != sum_split_size:
+                print(
+                    f"[reducescatter] val.shape={val.shape}, dim={dim}, val.shape[dim]={val.shape[dim]}, sum_split_size={sum_split_size}, sizes={sizes}"
+                )
         assert all([
             val.shape[dim] == sum_split_size for val in input
             if val is not None
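Note: the added lines only report the offending shapes before the existing assertion fires. Below is a minimal, standalone sketch of the same check, using a made-up check_split_sizes helper and toy tensors rather than the real reducescatter signature.

# Standalone sketch (assumed helper, not the real reducescatter API): report any
# entry whose size along `dim` differs from sum(sizes) before asserting.
import torch

def check_split_sizes(inputs, dim, sizes):
    sum_split_size = sum(sizes)
    for val in inputs:
        if val is not None and val.shape[dim] != sum_split_size:
            print(f"[reducescatter] val.shape={val.shape}, dim={dim}, "
                  f"expected {sum_split_size} from sizes={sizes}")
    assert all(val.shape[dim] == sum_split_size
               for val in inputs if val is not None)

# Matching sizes pass silently; a tensor of length 5 here would be printed first,
# and the assert would then fail.
check_split_sizes([torch.zeros(6, 4), None, torch.zeros(6, 4)], dim=0, sizes=[2, 2, 2])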

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 59 additions & 30 deletions
@@ -308,13 +308,20 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
     def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
         # Disable alltoall when chunking is used
         if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            print(
+                f"can not use alltoall due to chunking {self.calculate_num_chunks(all_rank_num_tokens)}"
+            )
             return False

         # For DeepEPLowLatency, check if tokens exceed the threshold
         if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
                 and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
+            print(
+                f"can not use alltoall due to deep_ep_max_num_tokens {all_rank_max_num_tokens} > {self.deep_ep_max_num_tokens}"
+            )
             return False

+        print(f"all to all type {self.alltoall_method_type}")
         return self.enable_alltoall

     def _get_quant_method(self):
@@ -323,9 +330,18 @@ def _get_quant_method(self):
             if self.quant_config.layer_quant_mode.has_fp8_qdq():
                 return FP8QDQFusedMoEMethod()
             elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
+                print(
+                    f"wide_ep _get_quant_method: get_sm_version()={get_sm_version()}"
+                )
                 if get_sm_version() == 100:
+                    print(
+                        f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
+                    )
                     return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
                 else:
+                    print(
+                        f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethod"
+                    )
                     return DeepSeekFP8BlockScalesFusedMoEMethod()
             elif self.quant_config.layer_quant_mode.has_nvfp4():
                 return NVFP4CutlassFusedMoEMethod()
@@ -399,6 +415,10 @@ def forward_chunk(

         is_first_call, is_last_call = repeating_info

+        print(
+            f"xxi shape 1: enter wide_ep forward_chunk: layer_load_balancer={self.layer_load_balancer}, is_first_call={is_first_call}, is_last_call={is_last_call}, x shape: {getattr(x, 'shape', None)}, router_logits shape: {getattr(router_logits, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, repeating_info: {repeating_info}"
+        )
+
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()

@@ -475,7 +495,7 @@ def forward_chunk(
                 self.dummy_allreduce()
             token_count = x.shape[0]
             alltoall_info = None
-            if is_last_call:
+            if self.layer_load_balancer and is_last_call:
                 loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
                 )
             else:
@@ -650,7 +670,35 @@ def forward_chunk(
         )

         # Original fused_moe call (preserved as reference)
-        final_hidden_states = torch.ops.trtllm.fused_moe(
+        # final_hidden_states = torch.ops.trtllm.fused_moe(
+        #     x,
+        #     token_selected_slots,
+        #     token_final_scales,
+        #     w3_w1_weight.view(weight_dtype),
+        #     None,  # w3_w1_bias
+        #     w2_weight.view(weight_dtype),
+        #     None,  # w2_bias
+        #     output_dtype,
+        #     quant_scales=quant_scales,
+        #     input_sf=x_sf,
+        #     swizzled_input_sf=False,
+        #     tp_size=self.tp_size,
+        #     tp_rank=self.tp_rank,
+        #     ep_size=ep_size,
+        #     ep_rank=ep_rank,
+        #     cluster_size=cluster_size,
+        #     cluster_rank=cluster_rank,
+        #     enable_alltoall=use_all_to_all,
+        #     use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
+        #     use_w4_group_scaling=use_w4_group_scaling,
+        #     min_latency_mode=False,
+        #     tune_max_num_tokens=self.tune_max_num_tokens,
+        #     tuner_num_tokens=tuner_num_tokens,
+        #     tuner_top_k=tuner_top_k,
+        # )
+
+        # Use the selected backend to compute MoE with the same parameters as fused_moe
+        final_hidden_states = self.moe_backend.run_moe(
             x,
             token_selected_slots,
             token_final_scales,
@@ -675,36 +723,13 @@ def forward_chunk(
             tune_max_num_tokens=self.tune_max_num_tokens,
             tuner_num_tokens=tuner_num_tokens,
             tuner_top_k=tuner_top_k,
+            module=
+            self,  # Additional parameter for backend to access module properties
         )

-        # Use the selected backend to compute MoE with the same parameters as fused_moe
-        # final_hidden_states = self.moe_backend.run_moe(
-        #     x,
-        #     token_selected_slots,
-        #     token_final_scales,
-        #     w3_w1_weight.view(weight_dtype),
-        #     None,  # w3_w1_bias
-        #     w2_weight.view(weight_dtype),
-        #     None,  # w2_bias
-        #     output_dtype,
-        #     quant_scales=quant_scales,
-        #     input_sf=x_sf,
-        #     swizzled_input_sf=False,
-        #     tp_size=self.tp_size,
-        #     tp_rank=self.tp_rank,
-        #     ep_size=ep_size,
-        #     ep_rank=ep_rank,
-        #     cluster_size=cluster_size,
-        #     cluster_rank=cluster_rank,
-        #     enable_alltoall=use_all_to_all,
-        #     use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-        #     use_w4_group_scaling=use_w4_group_scaling,
-        #     min_latency_mode=False,
-        #     tune_max_num_tokens=self.tune_max_num_tokens,
-        #     tuner_num_tokens=tuner_num_tokens,
-        #     tuner_top_k=tuner_top_k,
-        #     module=self,  # Additional parameter for backend to access module properties
-        # )
+        print(
+            f"xxi shape 4 after moe backend : {getattr(x, 'shape', None)}, final_hidden_states shape: {getattr(final_hidden_states, 'shape', None)}, token_selected_slots shape: {getattr(token_selected_slots, 'shape', None)}, token_final_scales shape: {getattr(token_final_scales, 'shape', None)}, w3_w1_weight shape: {getattr(w3_w1_weight, 'shape', None)}, w2_weight shape: {getattr(w2_weight, 'shape', None)}, quant_scales: {getattr(quant_scales, 'shape', None)}, input_sf: {getattr(x_sf, 'shape', None)}, swizzled_input_sf: False, tp_size: {self.tp_size}, tp_rank: {self.tp_rank}, ep_size: {ep_size}, ep_rank: {ep_rank}, cluster_size: {cluster_size}, cluster_rank: {cluster_rank}, enable_alltoall: {use_all_to_all}, use_deepseek_fp8_block_scale: {use_deepseek_fp8_block_scale}, use_w4_group_scaling: {use_w4_group_scaling}, min_latency_mode: False, tune_max_num_tokens: {self.tune_max_num_tokens}, tuner_num_tokens: {tuner_num_tokens}, tuner_top_k: {tuner_top_k}"
+        )

         if self.layer_load_balancer and is_last_call:
             self.layer_load_balancer.start_set_cpu_stage()
@@ -784,6 +809,10 @@ def forward(
                 all_rank_max_num_tokens=all_rank_max_num_tokens,
                 use_dp_padding=use_dp_padding,
                 repeating_info=(is_first_call, is_last_call))
+            # Print all of the information on one line
+            print(
+                f"xxi x.shape: {getattr(x, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_num_tokens_padded: {all_rank_num_tokens_padded}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, outputs.shape: {getattr(outputs, 'shape', None)}, use_dp_padding(again): {use_dp_padding}"
+            )
            outputs = self.reducescatter_or_allreduce(
                outputs,
                use_all_to_all,
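The functional change in this file is that forward_chunk now routes through self.moe_backend.run_moe(..., module=self) instead of calling torch.ops.trtllm.fused_moe directly. A hedged sketch of that dispatch shape follows; the stub classes and the trimmed-down run_moe signature below are illustrative only, not the real moe_backend API.

# Hedged sketch of the backend dispatch introduced above (stub names are made up).
from typing import Any, List
import torch

class StubBackend:
    def run_moe(self, x: torch.Tensor, *args: Any, module: Any = None,
                **kwargs: Any) -> List[torch.Tensor]:
        # A real backend would permute tokens, run grouped GEMMs and finalize;
        # this stub only shows that it receives the inputs plus the module handle.
        print(f"run_moe: x.shape={tuple(x.shape)}, module={type(module).__name__}")
        return [x]

class StubWideEPMoE:
    def __init__(self):
        self.moe_backend = StubBackend()

    def forward_chunk(self, x: torch.Tensor) -> torch.Tensor:
        # Same call shape as in the diff: positional tensors first (tuning and
        # parallelism kwargs omitted here), then module=self for backend access.
        final_hidden_states = self.moe_backend.run_moe(x, module=self)
        return final_hidden_states[0]

print(StubWideEPMoE().forward_chunk(torch.randn(4, 8)).shape)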

tensorrt_llm/_torch/modules/fused_moe/moe_backend.py

Lines changed: 33 additions & 7 deletions
@@ -98,7 +98,6 @@ def compute_moe(
             Computed MoE output tensor
         """

-    @abstractmethod
     def run_moe(
         self,
         # Positional arguments (same order as torch.ops.trtllm.fused_moe)
@@ -542,10 +541,11 @@ def __init__(self):
         super().__init__()
         # Import DeepGemm specific functions
         import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
-        from tensorrt_llm import deep_gemm
-        self.deep_gemm = deep_gemm
         self.fp8_utils = fp8_utils

+        from .fused_moe_deepgemm import deepgemm_fp8_group_blockwise_gemm
+        self.deepgemm_fp8_group_blockwise_gemm = deepgemm_fp8_group_blockwise_gemm
+
     def finalize_tactic(
         self,
         module: Any,
@@ -664,6 +664,7 @@ def compute_moe(
         Note: This assumes the data has already been gathered/alltoall'd
         by the WideEP forward_chunk method.
         """
+
         # Import necessary functions for DeepGemm
         from .fused_moe_deepgemm import (masked_index_copy_group_quant_fp8,
                                          preprocess_after_permute, set_strides,
@@ -711,6 +712,20 @@ def compute_moe(
             use_fp8_block_scaling=True,  # Always use block scaling for DeepGemm
         )

+        print(
+            "xxi shape 2: enter deepgemm backend compute_moe \n"
+            f"x.shape: {getattr(x, 'shape', None)}, \n"
+            f"input_sf.shape: {getattr(input_sf, 'shape', None)}, \n"
+            f"token_selected_slots.shape: {getattr(token_selected_slots, 'shape', None)}, \n"
+            f"token_final_scales.shape: {getattr(token_final_scales, 'shape', None)}, \n"
+            f"permuted_row_to_unpermuted_row_tensor.shape: {getattr(permuted_row_to_unpermuted_row_tensor, 'shape', None)}, \n"
+            f"permuted_token_selected_experts_tensor.shape: {getattr(permuted_token_selected_experts_tensor, 'shape', None)}, \n"
+            f"permuted_data_tensor.shape: {getattr(permuted_data_tensor, 'shape', None)}, \n"
+            f"expert_first_token_offset_tensor.shape: {getattr(expert_first_token_offset_tensor, 'shape', None)}, \n"
+            f"permuted_token_final_scales_tensor.shape: {getattr(permuted_token_final_scales_tensor, 'shape', None)}, \n"
+            f"unpermuted_row_to_permuted_row_tensor.shape: {getattr(unpermuted_row_to_permuted_row_tensor, 'shape', None)}\n"
+        )
+
         if permuted_data_tensor.numel() == 0:
             return torch.zeros_like(x)

@@ -750,7 +765,7 @@ def compute_moe(
         h1 = set_strides(workspace["workspace_1"], expert_size_per_partition,
                          m_max, intermediate_size * 2)

-        self.deep_gemm.deepgemm_fp8_group_blockwise_gemm(
+        self.deepgemm_fp8_group_blockwise_gemm(
             d=h1,
             a=act_input_fp8,
             b=w3_w1_weight,
@@ -783,7 +798,7 @@ def compute_moe(
         h3 = set_strides(workspace["workspace_1"], expert_size_per_partition,
                          m_max, hidden_size)

-        self.deep_gemm.deepgemm_fp8_group_blockwise_gemm(
+        self.deepgemm_fp8_group_blockwise_gemm(
             d=h3,
             a=act_input_fp8,
             b=w2_weight,
@@ -817,8 +832,19 @@ def compute_moe(
             ep_size,
             ep_rank,
         )
-
-        return final_hidden_states
+        print(
+            "xxi shape 3: exit deepgemm backend compute_moe \n"
+            f"final_hidden_states.shape: {getattr(final_hidden_states, 'shape', None)}\n"
+            f"permuted_data_tensor.shape: {getattr(permuted_data_tensor, 'shape', None)}, \n"
+            f"token_final_scales.shape: {getattr(token_final_scales, 'shape', None)}, \n"
+            f"unpermuted_row_to_permuted_row_tensor.shape: {getattr(unpermuted_row_to_permuted_row_tensor, 'shape', None)}, \n"
+            f"permuted_row_to_unpermuted_row_tensor.shape: {getattr(permuted_row_to_unpermuted_row_tensor, 'shape', None)}, \n"
+            f"expert_first_token_offset_tensor.shape: {getattr(expert_first_token_offset_tensor, 'shape', None)}, \n"
+            f"x.shape: {getattr(x, 'shape', None)}, \n")
+
+        return final_hidden_states if min_latency_mode else [
+            final_hidden_states
+        ]


 class MoEBackendSelection:
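Besides the rebinding of deepgemm_fp8_group_blockwise_gemm, this file changes the return convention of compute_moe: the result is now wrapped in a single-element list unless min_latency_mode is set. A minimal sketch of that convention, with a made-up helper name, assuming the caller indexes [0] in the normal path:

# Minimal sketch of the return convention added above (illustrative helper name).
from typing import List, Union
import torch

def finish_compute_moe(final_hidden_states: torch.Tensor,
                       min_latency_mode: bool) -> Union[torch.Tensor, List[torch.Tensor]]:
    # Bare tensor in min-latency mode, list wrapper otherwise.
    return final_hidden_states if min_latency_mode else [final_hidden_states]

out = finish_compute_moe(torch.zeros(2, 3), min_latency_mode=False)
print(type(out), out[0].shape)  # list wrapper in the normal path
out = finish_compute_moe(torch.zeros(2, 3), min_latency_mode=True)
print(type(out), out.shape)     # bare tensor in min-latency mode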

tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py

Lines changed: 3 additions & 0 deletions
@@ -960,6 +960,9 @@ def maybe_create_moe_load_balancer(
     in_supported_model_arch = model_arch in moe_model_arch_list
     using_smart_router = mapping and mapping.moe_cluster_size > 1
     moe_load_balancer = nullcontext()
+    print(
+        f"maybe_create_moe_load_balancer: in_supported_model_arch={in_supported_model_arch}, using_ep={using_ep}, using_smart_router={using_smart_router}, model_config.moe_load_balancer={model_config.moe_load_balancer}"
+    )
     if in_supported_model_arch and using_ep and not using_smart_router and model_config.moe_load_balancer is not None:
         model_config.moe_load_balancer.setup(ep_rank=ep_rank, ep_size=ep_size)
         if model_config.moe_load_balancer.layer_updates_per_iter > 0:
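The new print simply logs the four conditions that gate load-balancer creation. A hedged sketch of that gate, with a made-up helper (the real logic lives in maybe_create_moe_load_balancer) and nullcontext() as the fallback, as in the diff:

# Hedged sketch of the gating reported by the new print (helper name is made up).
from contextlib import nullcontext

def should_create_load_balancer(in_supported_model_arch: bool, using_ep: bool,
                                using_smart_router: bool, balancer_config) -> bool:
    print(f"maybe_create_moe_load_balancer: in_supported_model_arch={in_supported_model_arch}, "
          f"using_ep={using_ep}, using_smart_router={using_smart_router}, "
          f"moe_load_balancer={balancer_config}")
    return (in_supported_model_arch and using_ep and not using_smart_router
            and balancer_config is not None)

if should_create_load_balancer(True, True, True, None):
    ctx = object()  # stands in for the real MoeLoadBalancer context manager
else:
    ctx = nullcontext()
print(type(ctx))  # nullcontext here: smart router in use and no balancer config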
