
Commit 91eb8b8 (1 parent: 373e11f)

debug

Signed-off-by: xxi <[email protected]>

    modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
    modified: tensorrt_llm/_torch/modules/fused_moe/moe_backend.py

File tree: 2 files changed, 31 insertions(+), 31 deletions(-)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (31 additions & 30 deletions)

@@ -650,35 +650,7 @@ def forward_chunk(
         )

         # Original fused_moe call (preserved as reference)
-        final_hidden_states = torch.ops.trtllm.fused_moe(
-            x,
-            token_selected_slots,
-            token_final_scales,
-            w3_w1_weight.view(weight_dtype),
-            None,  # w3_w1_bias
-            w2_weight.view(weight_dtype),
-            None,  # w2_bias
-            output_dtype,
-            quant_scales=quant_scales,
-            input_sf=x_sf,
-            swizzled_input_sf=False,
-            tp_size=self.tp_size,
-            tp_rank=self.tp_rank,
-            ep_size=ep_size,
-            ep_rank=ep_rank,
-            cluster_size=cluster_size,
-            cluster_rank=cluster_rank,
-            enable_alltoall=use_all_to_all,
-            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-            use_w4_group_scaling=use_w4_group_scaling,
-            min_latency_mode=False,
-            tune_max_num_tokens=self.tune_max_num_tokens,
-            tuner_num_tokens=tuner_num_tokens,
-            tuner_top_k=tuner_top_k,
-        )
-
-        # Use the selected backend to compute MoE with the same parameters as fused_moe
-        # final_hidden_states = self.moe_backend.run_moe(
+        # final_hidden_states = torch.ops.trtllm.fused_moe(
         #     x,
         #     token_selected_slots,
         #     token_final_scales,
@@ -703,9 +675,38 @@ def forward_chunk(
         #     tune_max_num_tokens=self.tune_max_num_tokens,
         #     tuner_num_tokens=tuner_num_tokens,
         #     tuner_top_k=tuner_top_k,
-        #     module=self,  # Additional parameter for backend to access module properties
         # )

+        # Use the selected backend to compute MoE with the same parameters as fused_moe
+        final_hidden_states = self.moe_backend.run_moe(
+            x,
+            token_selected_slots,
+            token_final_scales,
+            w3_w1_weight.view(weight_dtype),
+            None,  # w3_w1_bias
+            w2_weight.view(weight_dtype),
+            None,  # w2_bias
+            output_dtype,
+            quant_scales=quant_scales,
+            input_sf=x_sf,
+            swizzled_input_sf=False,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
+            ep_size=ep_size,
+            ep_rank=ep_rank,
+            cluster_size=cluster_size,
+            cluster_rank=cluster_rank,
+            enable_alltoall=use_all_to_all,
+            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
+            use_w4_group_scaling=use_w4_group_scaling,
+            min_latency_mode=False,
+            tune_max_num_tokens=self.tune_max_num_tokens,
+            tuner_num_tokens=tuner_num_tokens,
+            tuner_top_k=tuner_top_k,
+            module=
+            self,  # Additional parameter for backend to access module properties
+        )
+
         if self.layer_load_balancer and is_last_call:
             self.layer_load_balancer.start_set_cpu_stage()

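For context on the new call site: the commit routes the MoE computation through a backend object (self.moe_backend.run_moe) instead of calling torch.ops.trtllm.fused_moe directly, passing the module itself via module=self. Below is a minimal, hypothetical sketch of a backend whose run_moe simply forwards the same arguments to the fused kernel. Only the run_moe name, the module= keyword, and the fused_moe argument list come from this diff; the class name and the pass-through body are assumptions, not the actual tensorrt_llm implementation.

# Hypothetical pass-through backend; not the actual moe_backend.py code.
import torch


class PassthroughMoeBackend:
    """Sketch: forwards run_moe arguments straight to the fused MoE op."""

    def run_moe(self, x, token_selected_slots, token_final_scales,
                w3_w1_weight, w3_w1_bias, w2_weight, w2_bias, output_dtype,
                *, quant_scales, input_sf, swizzled_input_sf, tp_size,
                tp_rank, ep_size, ep_rank, cluster_size, cluster_rank,
                enable_alltoall, use_deepseek_fp8_block_scale,
                use_w4_group_scaling, min_latency_mode, tune_max_num_tokens,
                tuner_num_tokens, tuner_top_k, module=None):
        # `module` exposes the calling MoE module so a backend can read its
        # attributes (e.g. parallel sizes) without changing the argument
        # order shared with torch.ops.trtllm.fused_moe.
        return torch.ops.trtllm.fused_moe(
            x, token_selected_slots, token_final_scales, w3_w1_weight,
            w3_w1_bias, w2_weight, w2_bias, output_dtype,
            quant_scales=quant_scales, input_sf=input_sf,
            swizzled_input_sf=swizzled_input_sf, tp_size=tp_size,
            tp_rank=tp_rank, ep_size=ep_size, ep_rank=ep_rank,
            cluster_size=cluster_size, cluster_rank=cluster_rank,
            enable_alltoall=enable_alltoall,
            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
            use_w4_group_scaling=use_w4_group_scaling,
            min_latency_mode=min_latency_mode,
            tune_max_num_tokens=tune_max_num_tokens,
            tuner_num_tokens=tuner_num_tokens, tuner_top_k=tuner_top_k)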
tensorrt_llm/_torch/modules/fused_moe/moe_backend.py (0 additions & 1 deletion)

@@ -98,7 +98,6 @@ def compute_moe(
             Computed MoE output tensor
         """

-    @abstractmethod
     def run_moe(
         self,
         # Positional arguments (same order as torch.ops.trtllm.fused_moe)

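The only change in moe_backend.py is dropping the @abstractmethod decorator from run_moe, so subclasses are no longer forced to override it and can rely on whatever default the base class provides (the method body is not shown in this diff). A minimal sketch of that pattern follows, assuming a base class whose concrete run_moe delegates to an abstract compute_moe; the class name and the delegation body are illustrative assumptions, not the actual file contents.

# Hypothetical sketch only; names other than run_moe/compute_moe are assumed.
from abc import ABC, abstractmethod


class MoEBackendBase(ABC):
    """Sketch of a backend base class after the decorator removal."""

    @abstractmethod
    def compute_moe(self, x, *args, **kwargs):
        """Backend-specific kernel invocation; returns the MoE output tensor."""

    def run_moe(self, x, *args, module=None, **kwargs):
        # Concrete default (no @abstractmethod): shared argument handling can
        # live here before delegating to the backend-specific compute_moe.
        return self.compute_moe(x, *args, module=module, **kwargs)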