@@ -588,43 +588,26 @@ def forward_chunk(
588
588
x_sf = swizzle_sf (x_sf , x .shape [0 ], x .shape [1 ] * 2 ,
589
589
self .scaling_vector_size )
590
590
elif self .alltoall_method_type == AlltoallMethodType .DeepEPLowLatency :
591
- assert x_sf is not None and self .has_nvfp4
592
591
token_num = x_row
593
592
hidden_size = x_col
593
+ assert x_sf is not None and self .has_nvfp4
594
594
assert hidden_size % 32 == 0
595
- x_sf_dtype = x_sf .dtype
596
- x_dtype = x .dtype
597
- assert x_sf_dtype == torch .uint8 and x_dtype == torch .uint8
598
- x_sf = x_sf .view (torch .bfloat16 )
595
+ assert x .dtype == torch .uint8 and x_sf .dtype == torch .uint8
599
596
assert x_sf .shape [0 ] == token_num and x_sf .shape [
600
- 1 ] == hidden_size // 16 // 2
601
- x = x .view (torch .bfloat16 )
602
- assert x .shape [0 ] == token_num and x .shape [1 ] == hidden_size // 4
603
- # DeepEP LL dispatch only supports bf16 tensors with a hidden size of 2560, 4096, 5120, or 7168 as input. A hidden size of 2560 is sufficient to accommodate packed FP4 data.
604
- packed_hidden_size = 2560
605
- assert x .shape [1 ] + x_sf .shape [1 ] <= packed_hidden_size
606
- fp4_packed_tensor = torch .empty ((token_num , packed_hidden_size ),
607
- dtype = torch .bfloat16 ,
608
- device = x .device )
609
- fp4_packed_tensor [:, :x .shape [1 ]] = x
610
- fp4_packed_tensor [:,
611
- x .shape [1 ]:x .shape [1 ] + x_sf .shape [1 ]] = x_sf
597
+ 1 ] == hidden_size // 16
598
+ assert x .shape [0 ] == token_num and x .shape [1 ] == hidden_size // 2
612
599
613
600
deep_ep_topk_idx = token_selected_slots
614
601
deep_ep_topk_weights = token_final_scales
615
602
616
603
assert all_rank_max_num_tokens <= self .deep_ep_max_num_tokens
617
- fp4_packed_tensor , recv_expert_count , deep_ep_handle = \
618
- self .deep_ep_buffer .low_latency_dispatch (fp4_packed_tensor , deep_ep_topk_idx , all_rank_max_num_tokens , self .num_slots )
619
- deep_ep_handle = list (deep_ep_handle )
620
- deep_ep_handle [3 ] = hidden_size
621
- deep_ep_handle = tuple (deep_ep_handle )
622
-
623
- assert fp4_packed_tensor .ndim == 3 and fp4_packed_tensor .shape [
624
- 2 ] == packed_hidden_size
625
- x_sf = fp4_packed_tensor [:, :, x .shape [1 ]:x .shape [1 ] +
626
- x_sf .shape [1 ]].contiguous ()
627
- x = fp4_packed_tensor [:, :, :x .shape [1 ]].contiguous ()
604
+ x , x_sf , recv_expert_count , deep_ep_handle = \
605
+ self .deep_ep_buffer .low_latency_dispatch_fp4 (x , x_sf , deep_ep_topk_idx , all_rank_max_num_tokens , self .num_slots )
606
+ assert x .dtype == torch .uint8 and x_sf .dtype == torch .uint8
607
+ assert x .dim () == 3 and x_sf .dim () == 3
608
+ assert x .shape [2 ] == hidden_size // 2 and x_sf .shape [
609
+ 2 ] == hidden_size // 16
610
+
628
611
mask = torch .arange (
629
612
x .shape [1 ], dtype = torch .int32 , device = x .device ).expand (
630
613
x .shape [0 ], x .shape [1 ]) < recv_expert_count .unsqueeze (1 )
@@ -634,9 +617,9 @@ def forward_chunk(
634
617
x .shape [0 ] * (self .mapping .moe_ep_rank + 1 ),
635
618
dtype = torch .int32 ,
636
619
device = x .device ).unsqueeze (1 ), self .num_slots )
637
- x = x .reshape (x .shape [0 ] * x .shape [1 ], x .shape [2 ]). view ( x_dtype )
620
+ x = x .reshape (x .shape [0 ] * x .shape [1 ], x .shape [2 ])
638
621
x_sf = x_sf .reshape (x_sf .shape [0 ] * x_sf .shape [1 ],
639
- x_sf .shape [2 ]). view ( x_sf_dtype )
622
+ x_sf .shape [2 ])
640
623
x_sf = swizzle_sf (x_sf , x .shape [0 ], x .shape [1 ] * 2 ,
641
624
self .scaling_vector_size )
642
625
token_selected_slots = token_selected_slots .view (x .shape [0 ], 1 )
0 commit comments