
Commit bfc5496

XingFei Xi (xxi-nv) authored and committed
refactor
Signed-off-by: xxi <[email protected]>

modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
modified: tensorrt_llm/_torch/modules/fused_moe/moe_backend.py
modified: tests/unittest/_torch/modules/test_fused_moe.py
1 parent 15d9841 commit bfc5496

File tree

3 files changed: +350 -292 lines

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 25 additions & 27 deletions
@@ -15,7 +15,7 @@
 from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
-from .moe_backend import MoEBackendSelection
+from .moe_backend import MoEBackend, MoEBackendSelection
 from .moe_load_balancer import get_moe_load_balancer
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
                            DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
@@ -234,8 +234,8 @@ def __init__(
         self.enable_dummy_allreduce = os.environ.get(
             "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
 
-        # Select MoE backend based on configuration
-        self.moe_backend = None  # Will be initialized after weights are created
+        # MoE backend will be lazily initialized when first accessed (see moe_backend property)
+        self._moe_backend_impl = None
 
     def _check_configs(self):
         assert self._weights_created
@@ -365,8 +365,18 @@ def create_weights(self):
         self._weights_created = True
         self._check_configs()
 
-        # Initialize MoE backend after weights are created
-        self.moe_backend = MoEBackendSelection.select_backend(self)
+    @property
+    def moe_backend_impl(self) -> MoEBackend:
+        """
+        Lazily initialize and return the MoE backend.
+
+        The backend is selected based on hardware capabilities and quantization
+        configuration, which are only available after weights are created.
+        """
+        if self._moe_backend_impl is None:
+            assert self._weights_created, "Weights must be created before accessing moe_backend"
+            self._moe_backend_impl = MoEBackendSelection.select_backend(self)
+        return self._moe_backend_impl
 
     def dummy_allreduce(self):
         """
@@ -422,8 +432,6 @@ def forward_chunk(
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()
 
-        use_deepseek_fp8_block_scale = False
-        use_w4_group_scaling = False
         weight_dtype = self.w3_w1_weight.dtype
 
         token_selected_experts, token_final_scales = self.routing_method.apply(
@@ -578,9 +586,8 @@
             x_sf = x_sf.view((x_row, -1))
 
         elif self.has_deepseek_fp8_block_scales:
-            use_deepseek_fp8_block_scale = True
+            pass
         elif self.has_w4afp8:
-            use_w4_group_scaling = True
             weight_dtype = torch.quint4x2
         else:
             raise ValueError(
@@ -603,12 +610,12 @@
                 sizes=None if use_dp_padding else all_rank_num_tokens)
             x_row = x.shape[0]
 
-        ep_size = self.ep_size
-        ep_rank = self.ep_rank
+        # ep_size = self.ep_size
+        # ep_rank = self.ep_rank
         w3_w1_weight = self.w3_w1_weight
         w2_weight = self.w2_weight
-        cluster_size = self.cluster_size
-        cluster_rank = self.cluster_rank
+        # cluster_size = self.cluster_size
+        # cluster_rank = self.cluster_rank
         quant_scales = self.quant_scales
 
         if use_postquant_alltoall:
@@ -697,8 +704,9 @@
         # tuner_top_k=tuner_top_k,
         # )
 
-        # Use the selected backend to compute MoE with the same parameters as fused_moe
-        final_hidden_states = self.moe_backend.run_moe(
+        # Use backend interface with module as first parameter for automatic configuration extraction
+        final_hidden_states = self.moe_backend_impl.run_moe(
+            self,  # Module as first parameter
             x,
             token_selected_slots,
             token_final_scales,
@@ -710,21 +718,11 @@
             quant_scales=quant_scales,
             input_sf=x_sf,
             swizzled_input_sf=False,
-            tp_size=self.tp_size,
-            tp_rank=self.tp_rank,
-            ep_size=ep_size,
-            ep_rank=ep_rank,
-            cluster_size=cluster_size,
-            cluster_rank=cluster_rank,
-            enable_alltoall=use_all_to_all,
-            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-            use_w4_group_scaling=use_w4_group_scaling,
+            # Only need to pass runtime-variable parameters
             min_latency_mode=False,
-            tune_max_num_tokens=self.tune_max_num_tokens,
+            use_fused_finalize=True,
             tuner_num_tokens=tuner_num_tokens,
             tuner_top_k=tuner_top_k,
-            module=
-            self,  # Additional parameter for backend to access module properties
         )
 
         # print(
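The pattern this commit introduces is easier to see in isolation: the backend is no longer chosen eagerly when weights are created, but lazily the first time moe_backend_impl is read, and the module itself is passed to run_moe so the backend can pull static configuration (parallel sizes, quantization flags) from it instead of receiving them as per-call keyword arguments. Below is a minimal, self-contained sketch of that pattern under those assumptions; BackendBase, SimpleBackend, MoEModule, and the scale attribute are illustrative stand-ins, not the actual TensorRT-LLM classes or signatures.

# Minimal sketch of the lazy backend-selection pattern (hypothetical names,
# not the real TensorRT-LLM API).


class BackendBase:
    def run_moe(self, module, x):
        raise NotImplementedError


class SimpleBackend(BackendBase):
    def run_moe(self, module, x):
        # The module comes in as the first argument, so static configuration
        # (here just `scale`) is read from it rather than passed per call.
        return [v * module.scale for v in x]


class MoEModule:
    def __init__(self):
        self.scale = 2.0
        self._weights_created = False
        self._moe_backend_impl = None  # selected lazily, not in __init__

    def create_weights(self):
        self._weights_created = True

    @property
    def moe_backend_impl(self) -> BackendBase:
        # Backend choice depends on information only available once weights exist.
        if self._moe_backend_impl is None:
            assert self._weights_created, "create_weights() must run first"
            self._moe_backend_impl = SimpleBackend()
        return self._moe_backend_impl

    def forward_chunk(self, x):
        return self.moe_backend_impl.run_moe(self, x)


if __name__ == "__main__":
    m = MoEModule()
    m.create_weights()
    print(m.forward_chunk([1.0, 2.0, 3.0]))  # [2.0, 4.0, 6.0]

The diff above wires this same shape into fused_moe_wide_ep.py: __init__ only sets _moe_backend_impl = None, forward_chunk goes through the moe_backend_impl property, and backend selection therefore automatically waits until after create_weights has run.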
