 from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
-from .moe_backend import MoEBackendSelection
+from .moe_backend import MoEBackend, MoEBackendSelection
 from .moe_load_balancer import get_moe_load_balancer
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
                            DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
@@ -234,8 +234,8 @@ def __init__(
         self.enable_dummy_allreduce = os.environ.get(
             "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"

-        # Select MoE backend based on configuration
-        self.moe_backend = None  # Will be initialized after weights are created
+        # MoE backend will be lazily initialized when first accessed (see moe_backend_impl property)
+        self._moe_backend_impl = None

     def _check_configs(self):
         assert self._weights_created
@@ -365,8 +365,18 @@ def create_weights(self):
         self._weights_created = True
         self._check_configs()

-        # Initialize MoE backend after weights are created
-        self.moe_backend = MoEBackendSelection.select_backend(self)
+    @property
+    def moe_backend_impl(self) -> MoEBackend:
+        """
+        Lazily initialize and return the MoE backend.
+
+        The backend is selected based on hardware capabilities and quantization
+        configuration, which are only available after weights are created.
+        """
+        if self._moe_backend_impl is None:
+            assert self._weights_created, "Weights must be created before accessing moe_backend"
+            self._moe_backend_impl = MoEBackendSelection.select_backend(self)
+        return self._moe_backend_impl

     def dummy_allreduce(self):
         """
@@ -422,8 +432,6 @@ def forward_chunk(
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()

-        use_deepseek_fp8_block_scale = False
-        use_w4_group_scaling = False
         weight_dtype = self.w3_w1_weight.dtype

         token_selected_experts, token_final_scales = self.routing_method.apply(
@@ -578,9 +586,8 @@ def forward_chunk(
             x_sf = x_sf.view((x_row, -1))

         elif self.has_deepseek_fp8_block_scales:
-            use_deepseek_fp8_block_scale = True
+            pass
         elif self.has_w4afp8:
-            use_w4_group_scaling = True
             weight_dtype = torch.quint4x2
         else:
             raise ValueError(
@@ -603,12 +610,12 @@ def forward_chunk(
                 sizes=None if use_dp_padding else all_rank_num_tokens)
             x_row = x.shape[0]

-        ep_size = self.ep_size
-        ep_rank = self.ep_rank
+        # ep_size = self.ep_size
+        # ep_rank = self.ep_rank
         w3_w1_weight = self.w3_w1_weight
         w2_weight = self.w2_weight
-        cluster_size = self.cluster_size
-        cluster_rank = self.cluster_rank
+        # cluster_size = self.cluster_size
+        # cluster_rank = self.cluster_rank
         quant_scales = self.quant_scales

         if use_postquant_alltoall:
@@ -697,8 +704,9 @@ def forward_chunk(
         #     tuner_top_k=tuner_top_k,
         # )

-        # Use the selected backend to compute MoE with the same parameters as fused_moe
-        final_hidden_states = self.moe_backend.run_moe(
+        # Use backend interface with module as first parameter for automatic configuration extraction
+        final_hidden_states = self.moe_backend_impl.run_moe(
+            self,  # Module as first parameter
             x,
             token_selected_slots,
             token_final_scales,
@@ -710,21 +718,11 @@ def forward_chunk(
             quant_scales=quant_scales,
             input_sf=x_sf,
             swizzled_input_sf=False,
-            tp_size=self.tp_size,
-            tp_rank=self.tp_rank,
-            ep_size=ep_size,
-            ep_rank=ep_rank,
-            cluster_size=cluster_size,
-            cluster_rank=cluster_rank,
-            enable_alltoall=use_all_to_all,
-            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-            use_w4_group_scaling=use_w4_group_scaling,
+            # Only need to pass runtime-variable parameters
             min_latency_mode=False,
-            tune_max_num_tokens=self.tune_max_num_tokens,
+            use_fused_finalize=True,
             tuner_num_tokens=tuner_num_tokens,
             tuner_top_k=tuner_top_k,
-            module=
-            self,  # Additional parameter for backend to access module properties
         )

         # print(
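For context, the calling-convention change in the last two hunks (static parallelism and quantization settings read off the module inside the backend, with only runtime-variable arguments passed explicitly) can be sketched roughly as below; the signature and attribute lookups are assumptions for illustration, not the actual `MoEBackend.run_moe` implementation.

```python
# Rough sketch (assumed API) of a backend that derives static configuration
# from the module instead of taking tp/ep/cluster sizes and quant flags as kwargs.


class SketchMoEBackend:

    def run_moe(self, module, x, token_selected_slots, token_final_scales, *,
                quant_scales=None, input_sf=None, swizzled_input_sf=False,
                min_latency_mode=False, use_fused_finalize=True,
                tuner_num_tokens=None, tuner_top_k=None):
        # Static settings come from the module rather than the call site.
        tp_size = getattr(module, "tp_size", 1)
        ep_size = getattr(module, "ep_size", 1)
        cluster_size = getattr(module, "cluster_size", 1)
        use_fp8_block_scales = getattr(module, "has_deepseek_fp8_block_scales", False)
        use_w4_group_scaling = getattr(module, "has_w4afp8", False)
        # ... dispatch to the appropriate fused-MoE kernel using these settings ...
        return x  # placeholder result
```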