@@ -308,13 +308,20 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
     def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
         # Disable alltoall when chunking is used
         if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            print(
+                f"can not use alltoall due to chunking {self.calculate_num_chunks(all_rank_num_tokens)}"
+            )
             return False
 
         # For DeepEPLowLatency, check if tokens exceed the threshold
         if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
                 and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
+            print(
+                f"can not use alltoall due to deep_ep_max_num_tokens {all_rank_max_num_tokens} > {self.deep_ep_max_num_tokens}"
+            )
             return False
 
+        print(f"all to all type {self.alltoall_method_type}")
         return self.enable_alltoall
 
     def _get_quant_method(self):
@@ -323,9 +330,18 @@ def _get_quant_method(self):
         if self.quant_config.layer_quant_mode.has_fp8_qdq():
             return FP8QDQFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
+            print(
+                f"wide_ep _get_quant_method: get_sm_version()={get_sm_version()}"
+            )
             if get_sm_version() == 100:
+                print(
+                    f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
+                )
                 return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
             else:
+                print(
+                    f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethod"
+                )
                 return DeepSeekFP8BlockScalesFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_nvfp4():
             return NVFP4CutlassFusedMoEMethod()
@@ -399,6 +415,10 @@ def forward_chunk(
 
         is_first_call, is_last_call = repeating_info
 
+        print(
+            f"xxi shape 1: enter wide_ep forward_chunk: layer_load_balancer={self.layer_load_balancer}, is_first_call={is_first_call}, is_last_call={is_last_call}, x shape: {getattr(x, 'shape', None)}, router_logits shape: {getattr(router_logits, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, repeating_info: {repeating_info}"
+        )
+
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()
 
@@ -475,7 +495,7 @@ def forward_chunk(
             self.dummy_allreduce()
         token_count = x.shape[0]
         alltoall_info = None
-        if is_last_call:
+        if self.layer_load_balancer and is_last_call:
             loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
             )
         else:
@@ -650,7 +670,35 @@ def forward_chunk(
             )
 
         # Original fused_moe call (preserved as reference)
-        final_hidden_states = torch.ops.trtllm.fused_moe(
+        # final_hidden_states = torch.ops.trtllm.fused_moe(
+        #     x,
+        #     token_selected_slots,
+        #     token_final_scales,
+        #     w3_w1_weight.view(weight_dtype),
+        #     None,  # w3_w1_bias
+        #     w2_weight.view(weight_dtype),
+        #     None,  # w2_bias
+        #     output_dtype,
+        #     quant_scales=quant_scales,
+        #     input_sf=x_sf,
+        #     swizzled_input_sf=False,
+        #     tp_size=self.tp_size,
+        #     tp_rank=self.tp_rank,
+        #     ep_size=ep_size,
+        #     ep_rank=ep_rank,
+        #     cluster_size=cluster_size,
+        #     cluster_rank=cluster_rank,
+        #     enable_alltoall=use_all_to_all,
+        #     use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
+        #     use_w4_group_scaling=use_w4_group_scaling,
+        #     min_latency_mode=False,
+        #     tune_max_num_tokens=self.tune_max_num_tokens,
+        #     tuner_num_tokens=tuner_num_tokens,
+        #     tuner_top_k=tuner_top_k,
+        # )
+
+        # Use the selected backend to compute MoE with the same parameters as fused_moe
+        final_hidden_states = self.moe_backend.run_moe(
             x,
             token_selected_slots,
             token_final_scales,
@@ -675,36 +723,13 @@ def forward_chunk(
             tune_max_num_tokens=self.tune_max_num_tokens,
             tuner_num_tokens=tuner_num_tokens,
             tuner_top_k=tuner_top_k,
+            module=
+            self,  # Additional parameter for backend to access module properties
         )
 
-        # Use the selected backend to compute MoE with the same parameters as fused_moe
-        # final_hidden_states = self.moe_backend.run_moe(
-        #     x,
-        #     token_selected_slots,
-        #     token_final_scales,
-        #     w3_w1_weight.view(weight_dtype),
-        #     None,  # w3_w1_bias
-        #     w2_weight.view(weight_dtype),
-        #     None,  # w2_bias
-        #     output_dtype,
-        #     quant_scales=quant_scales,
-        #     input_sf=x_sf,
-        #     swizzled_input_sf=False,
-        #     tp_size=self.tp_size,
-        #     tp_rank=self.tp_rank,
-        #     ep_size=ep_size,
-        #     ep_rank=ep_rank,
-        #     cluster_size=cluster_size,
-        #     cluster_rank=cluster_rank,
-        #     enable_alltoall=use_all_to_all,
-        #     use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-        #     use_w4_group_scaling=use_w4_group_scaling,
-        #     min_latency_mode=False,
-        #     tune_max_num_tokens=self.tune_max_num_tokens,
-        #     tuner_num_tokens=tuner_num_tokens,
-        #     tuner_top_k=tuner_top_k,
-        #     module=self,  # Additional parameter for backend to access module properties
-        # )
+        print(
+            f"xxi shape 4 after moe backend : {getattr(x, 'shape', None)}, final_hidden_states shape: {getattr(final_hidden_states, 'shape', None)}, token_selected_slots shape: {getattr(token_selected_slots, 'shape', None)}, token_final_scales shape: {getattr(token_final_scales, 'shape', None)}, w3_w1_weight shape: {getattr(w3_w1_weight, 'shape', None)}, w2_weight shape: {getattr(w2_weight, 'shape', None)}, quant_scales: {getattr(quant_scales, 'shape', None)}, input_sf: {getattr(x_sf, 'shape', None)}, swizzled_input_sf: False, tp_size: {self.tp_size}, tp_rank: {self.tp_rank}, ep_size: {ep_size}, ep_rank: {ep_rank}, cluster_size: {cluster_size}, cluster_rank: {cluster_rank}, enable_alltoall: {use_all_to_all}, use_deepseek_fp8_block_scale: {use_deepseek_fp8_block_scale}, use_w4_group_scaling: {use_w4_group_scaling}, min_latency_mode: False, tune_max_num_tokens: {self.tune_max_num_tokens}, tuner_num_tokens: {tuner_num_tokens}, tuner_top_k: {tuner_top_k}"
+        )
 
         if self.layer_load_balancer and is_last_call:
             self.layer_load_balancer.start_set_cpu_stage()
@@ -784,6 +809,10 @@ def forward(
             all_rank_max_num_tokens=all_rank_max_num_tokens,
             use_dp_padding=use_dp_padding,
             repeating_info=(is_first_call, is_last_call))
+        # Print all the information on one line
+        print(
+            f"xxi x.shape: {getattr(x, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_num_tokens_padded: {all_rank_num_tokens_padded}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, outputs.shape: {getattr(outputs, 'shape', None)}, use_dp_padding(again): {use_dp_padding}"
+        )
         outputs = self.reducescatter_or_allreduce(
             outputs,
             use_all_to_all,
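
Note on the backend switch in the forward_chunk hunk above: the new self.moe_backend.run_moe(...) call keeps the exact argument list of the original torch.ops.trtllm.fused_moe op and only appends module=self. As a rough illustration only, the sketch below shows one shape such a backend adapter could take. The class name FusedMoEBackendSketch, its method signature, and the pass-through body are assumptions for illustration, not the backend implementation from this PR, and it presupposes that TensorRT-LLM's custom trtllm ops are already registered in the process.

import torch


class FusedMoEBackendSketch:
    """Hypothetical adapter: accepts the fused_moe argument list plus `module`."""

    def run_moe(self, x, token_selected_slots, token_final_scales, *args,
                module=None, **kwargs):
        # `module` would let a backend read module-level properties (parallel
        # sizes, quantization config, ...); this default path ignores it and
        # forwards everything else to the original fused kernel unchanged.
        return torch.ops.trtllm.fused_moe(x, token_selected_slots,
                                          token_final_scales, *args, **kwargs)

With an adapter shaped like this, setting self.moe_backend to such an object would reproduce the original fused_moe behaviour, while alternative backends could use the module handle to specialize their kernel choice.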