@@ -650,35 +650,7 @@ def forward_chunk(
         )
 
         # Original fused_moe call (preserved as reference)
-        final_hidden_states = torch.ops.trtllm.fused_moe(
-            x,
-            token_selected_slots,
-            token_final_scales,
-            w3_w1_weight.view(weight_dtype),
-            None,  # w3_w1_bias
-            w2_weight.view(weight_dtype),
-            None,  # w2_bias
-            output_dtype,
-            quant_scales=quant_scales,
-            input_sf=x_sf,
-            swizzled_input_sf=False,
-            tp_size=self.tp_size,
-            tp_rank=self.tp_rank,
-            ep_size=ep_size,
-            ep_rank=ep_rank,
-            cluster_size=cluster_size,
-            cluster_rank=cluster_rank,
-            enable_alltoall=use_all_to_all,
-            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-            use_w4_group_scaling=use_w4_group_scaling,
-            min_latency_mode=False,
-            tune_max_num_tokens=self.tune_max_num_tokens,
-            tuner_num_tokens=tuner_num_tokens,
-            tuner_top_k=tuner_top_k,
-        )
-
-        # Use the selected backend to compute MoE with the same parameters as fused_moe
-        # final_hidden_states = self.moe_backend.run_moe(
+        # final_hidden_states = torch.ops.trtllm.fused_moe(
         # x,
         # token_selected_slots,
         # token_final_scales,
@@ -703,9 +675,38 @@ def forward_chunk(
         # tune_max_num_tokens=self.tune_max_num_tokens,
         # tuner_num_tokens=tuner_num_tokens,
         # tuner_top_k=tuner_top_k,
-        # module=self,  # Additional parameter for backend to access module properties
         # )
 
+        # Use the selected backend to compute MoE with the same parameters as fused_moe
+        final_hidden_states = self.moe_backend.run_moe(
+            x,
+            token_selected_slots,
+            token_final_scales,
+            w3_w1_weight.view(weight_dtype),
+            None,  # w3_w1_bias
+            w2_weight.view(weight_dtype),
+            None,  # w2_bias
+            output_dtype,
+            quant_scales=quant_scales,
+            input_sf=x_sf,
+            swizzled_input_sf=False,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
+            ep_size=ep_size,
+            ep_rank=ep_rank,
+            cluster_size=cluster_size,
+            cluster_rank=cluster_rank,
+            enable_alltoall=use_all_to_all,
+            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
+            use_w4_group_scaling=use_w4_group_scaling,
+            min_latency_mode=False,
+            tune_max_num_tokens=self.tune_max_num_tokens,
+            tuner_num_tokens=tuner_num_tokens,
+            tuner_top_k=tuner_top_k,
+            module=
+            self,  # Additional parameter for backend to access module properties
+        )
+
         if self.layer_load_balancer and is_last_call:
             self.layer_load_balancer.start_set_cpu_stage()
 
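Taken together, the two hunks comment out the direct `torch.ops.trtllm.fused_moe(...)` call (kept as a reference) and route the computation through `self.moe_backend.run_moe(...)` with the same arguments plus `module=self`, so the backend can read properties off the calling layer. Below is a minimal sketch of what such a backend could look like, assuming it simply delegates to the existing fused op; the class name `FusedMoeBackend` and the delegation itself are assumptions for illustration and are not shown in this diff.

```python
# Minimal sketch only -- the real backend interface is defined elsewhere in the PR.
# `FusedMoeBackend` is a hypothetical name; this hunk only shows `self.moe_backend`.
import torch


class FusedMoeBackend:
    """Backend whose run_moe() forwards to the existing trtllm fused_moe op."""

    def run_moe(self, *args, module=None, **kwargs):
        # `module` lets the backend inspect the calling layer (e.g. its parallel
        # configuration) without growing the argument list further; the fused op
        # itself does not accept it, so it is consumed here.
        return torch.ops.trtllm.fused_moe(*args, **kwargs)
```

Usage would then mirror the new call site: the module would assign `self.moe_backend = FusedMoeBackend()` (or another selected backend) during initialization, and `forward_chunk` would invoke `run_moe` exactly as in the hunk above.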