
import torch

-from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
-    silu_mul_fp8_quant_deep_gemm as gold,
-)
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm_cuda,
)
from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+
+
+@triton.jit
+def _silu_mul_fp8_quant_deep_gemm(
+    # Pointers ------------------------------------------------------------
+    input_ptr,  # 16-bit activations (E, T, 2*H)
+    y_q_ptr,  # fp8 quantized activations (E, T, H)
+    y_s_ptr,  # fp32 per-group scales (E, T, G)
+    counts_ptr,  # int32 num tokens per expert (E)
+    # Sizes ---------------------------------------------------------------
+    H: tl.constexpr,  # hidden dimension (per output)
+    GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)
+    # Strides for input (elements) ---------------------------------------
+    stride_i_e,
+    stride_i_t,
+    stride_i_h,
+    # Strides for y_q (elements) -----------------------------------------
+    stride_yq_e,
+    stride_yq_t,
+    stride_yq_h,
+    # Strides for y_s (elements) -----------------------------------------
+    stride_ys_e,
+    stride_ys_t,
+    stride_ys_g,
+    # Stride for counts (elements)
+    stride_counts_e,
+    # Numeric params ------------------------------------------------------
+    eps: tl.constexpr,
+    fp8_min: tl.constexpr,
+    fp8_max: tl.constexpr,
+    use_ue8m0: tl.constexpr,
+    # Meta ---------------------------------------------------------------
+    BLOCK: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
+):
+    G = H // GROUP_SIZE
+
+    # map program id -> (e, g)
+    pid = tl.program_id(0)
+    e = pid // G
+    g = pid % G
+
+    e = e.to(tl.int64)
+    g = g.to(tl.int64)
+
+    # number of valid tokens for this expert
+    n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
+
+    cols = tl.arange(0, BLOCK).to(tl.int64)
+    mask = cols < BLOCK
+
+    base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
+    base_gate_offset = base_input_offset + cols * stride_i_h
+    base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
+    base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h
+    base_ys_offset = e * stride_ys_e + g * stride_ys_g
+
+    for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
+        gate = tl.load(
+            input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0
+        ).to(tl.float32)
+        up = tl.load(
+            input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0
+        ).to(tl.float32)
+
+        gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
+        y = gate * up
+
+        y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
+        if use_ue8m0:
+            y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
+
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
+        tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
+
+
+def gold(
+    y: torch.Tensor,  # (E, T, 2*H)
+    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
+    num_parallel_tokens=16,
+    group_size: int = 128,
+    eps: float = 1e-10,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with per-token group scales
+
+    y has shape (E, T, 2*H). The first half of the last dimension is
+    silu-activated, multiplied by the second half, then quantized into FP8.
+
+    Returns `(y_q, y_s)` where
+    * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
+    * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
+    """
+    assert y.ndim == 3, "y must be (E, T, 2*H)"
+    E, T, H2 = y.shape
+    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
+    H = H2 // 2
+    G = H // group_size
+    assert H % group_size == 0, "H must be divisible by group_size"
+    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, (
+        "tokens_per_expert must be shape (E,)"
+    )
+    tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)
+
+    # allocate outputs
+    fp8_dtype = torch.float8_e4m3fn
+    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
+
+    # strides (elements)
+    stride_i_e, stride_i_t, stride_i_h = y.stride()
+    stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
+
+    # desired scale strides (elements): (T*G, 1, T)
+    stride_ys_e = T * G
+    stride_ys_t = 1
+    stride_ys_g = T
+    y_s = torch.empty_strided(
+        (E, T, G),
+        (stride_ys_e, stride_ys_t, stride_ys_g),
+        dtype=torch.float32,
+        device=y.device,
+    )
+
+    stride_cnt_e = tokens_per_expert.stride()[0]
+
+    # Static grid over experts and H-groups.
+    # A loop inside the kernel handles the token dim
+    grid = (E * G,)
+
+    f_info = torch.finfo(fp8_dtype)
+    fp8_max = f_info.max
+    fp8_min = f_info.min
+
+    _silu_mul_fp8_quant_deep_gemm[grid](
+        y,
+        y_q,
+        y_s,
+        tokens_per_expert,
+        H,
+        group_size,
+        stride_i_e,
+        stride_i_t,
+        stride_i_h,
+        stride_yq_e,
+        stride_yq_t,
+        stride_yq_h,
+        stride_ys_e,
+        stride_ys_t,
+        stride_ys_g,
+        stride_cnt_e,
+        eps,
+        fp8_min,
+        fp8_max,
+        is_deep_gemm_e8m0_used(),
+        BLOCK=group_size,
+        NUM_STAGES=8,
+    )
+
+    return y_q, y_s
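
For intuition, the quantization rule this reference applies to each (expert, token, group) slice reduces to the short PyTorch sketch below. `quant_group` is a hypothetical helper written only for illustration; it assumes `torch.float8_e4m3fn` (largest finite value 448.0) and mirrors the eps clamp and optional UE8M0 power-of-two rounding above.

import torch

def quant_group(y_group: torch.Tensor, eps: float = 1e-10, use_ue8m0: bool = False):
    # y_group holds silu(gate) * up for one group (e.g. 128 float32 values)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    scale = torch.clamp(y_group.abs().max(), min=eps) / fp8_max
    if use_ue8m0:
        # UE8M0 scales: round the scale up to the next power of two
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    q = torch.clamp(y_group / scale, -fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale

In the Triton kernel, one program handles a single (expert, group) pair and loops over that expert's valid tokens, which is why the launch grid is just (E * G,).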


-def benchmark(k, E, T, H, G=128, runs=100):
+def benchmark(k, E, T, H, num_parallel_tokens, G=128, runs=100):
    current_platform.seed_everything(42)
    y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()
    tokens_per_expert = torch.randint(
@@ -23,14 +182,14 @@ def benchmark(k, E, T, H, G=128, runs=100):

    # Warmup
    for _ in range(20):
-        k(y, tokens_per_expert, group_size=G)
+        k(y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G)
    torch.cuda.synchronize()

    # Benchmark
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(runs):
-        k(y, tokens_per_expert, group_size=G)
+        k(y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G)
    torch.cuda.synchronize()

    avg_time = (time.perf_counter() - start) / runs * 1000
@@ -54,38 +213,38 @@ def benchmark(k, E, T, H, G=128, runs=100):
    return avg_time, gflops, memory_bw


-configs = [
-    # DeepSeekV3 Configs
-    (8, 16, 7168),
-    (8, 32, 7168),
-    (8, 64, 7168),
-    (8, 128, 7168),
-    (8, 256, 7168),
-    (8, 512, 7168),
-    (8, 1024, 7168),
-    (9, 16, 7168),
-    (9, 32, 7168),
-    (9, 64, 7168),
-    (9, 128, 7168),
-    (9, 256, 7168),
-    (9, 512, 7168),
-    (9, 1024, 7168),
-]
-
-
-print(f"GPU: {torch.cuda.get_device_name()} CUDA Kernel")
-print(f"{'Config':<20}{'Time(ms)':<10}{'GFLOPS':<10}{'GB/s':<10}")
-print("-" * 50)
-
-for E, T, H in configs:
-    time_ms, gflops, gbps = benchmark(silu_mul_fp8_quant_deep_gemm_cuda, E, T, H)
-    print(f"E={E:3d},T={T:4d},H={H:4d}  {time_ms:8.3f}  {gflops:8.1f}  {gbps:8.1f}")
-
-
-print(f"GPU: {torch.cuda.get_device_name()} Baseline")
-print(f"{'Config':<20}{'Time(ms)':<10}{'GFLOPS':<10}{'GB/s':<10}")
-print("-" * 50)
-
-for E, T, H in configs:
-    time_ms, gflops, gbps = benchmark(gold, E, T, H)
-    print(f"E={E:3d},T={T:4d},H={H:4d}  {time_ms:8.3f}  {gflops:8.1f}  {gbps:8.1f}")
+def benchmark_full():
+    configs = [
+        (8, 32, 1024),
+        (16, 64, 2048),
+        (32, 128, 4096),
+        # DeepSeekV3 Configs
+        (256, 16, 7168),
+        (256, 32, 7168),
+        (256, 64, 7168),
+        (256, 128, 7168),
+        (256, 256, 7168),
+        (256, 512, 7168),
+        (256, 1024, 7168),
+    ]
+
+    print(f"GPU: {torch.cuda.get_device_name()} CUDA Kernel")
+    print(f"{'Config':<20}{'Time(ms)':<10}{'GFLOPS':<10}{'GB/s':<10}")
+    print("-" * 50)
+
+    for E, T, H in configs:
+        time_ms, gflops, gbps = benchmark(
+            silu_mul_fp8_quant_deep_gemm_cuda, E, T, H, 16
+        )
+        print(f"E={E:3d},T={T:4d},H={H:4d}  {time_ms:8.3f}  {gflops:8.1f}  {gbps:8.1f}")
+
+    print(f"GPU: {torch.cuda.get_device_name()} Baseline")
+    print(f"{'Config':<20}{'Time(ms)':<10}{'GFLOPS':<10}{'GB/s':<10}")
+    print("-" * 50)
+
+    for E, T, H in configs:
+        time_ms, gflops, gbps = benchmark(gold, E, T, H, 16)
+        print(f"E={E:3d},T={T:4d},H={H:4d}  {time_ms:8.3f}  {gflops:8.1f}  {gbps:8.1f}")
+
+
+benchmark_full()
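
Since the diff replaces the imported Triton baseline with the locally defined `gold` reference, a quick way to sanity-check the two paths against each other is a small dequantize-and-compare run like the sketch below. This is only an illustration: it assumes a CUDA device with FP8 support and that `silu_mul_fp8_quant_deep_gemm_cuda` returns `(y_q, y_s)` with the same shapes and group layout as `gold`.

E, T, H = 4, 32, 1024
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
counts = torch.full((E,), T, dtype=torch.int32, device="cuda")  # all tokens valid

q_ref, s_ref = gold(y, counts, group_size=128)
q_cuda, s_cuda = silu_mul_fp8_quant_deep_gemm_cuda(y, counts, group_size=128)

def dequant(q, s, group_size=128):
    # broadcast each per-group scale over its group and undo the quantization
    return q.to(torch.float32) * s.float().repeat_interleave(group_size, dim=-1)

print("max abs diff:", (dequant(q_ref, s_ref) - dequant(q_cuda, s_cuda)).abs().max().item())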