 import torch
 
 from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
-    silu_mul_fp8_quant_deep_gemm,
+    silu_mul_fp8_quant_deep_gemm_cuda,
 )
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used


-def benchmark(E, T, H, G=128, runs=50):
+@triton.jit
+def _silu_mul_fp8_quant_deep_gemm(
+    # Pointers ------------------------------------------------------------
+    input_ptr,  # 16-bit activations (E, T, 2*H)
+    y_q_ptr,  # fp8 quantized activations (E, T, H)
+    y_s_ptr,  # fp32 scales (E, T, G)
+    counts_ptr,  # int32 num tokens per expert (E)
+    # Sizes ---------------------------------------------------------------
+    H: tl.constexpr,  # hidden dimension (per output)
+    GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)
+    # Strides for input (elements) ----------------------------------------
+    stride_i_e,
+    stride_i_t,
+    stride_i_h,
+    # Strides for y_q (elements) ------------------------------------------
+    stride_yq_e,
+    stride_yq_t,
+    stride_yq_h,
+    # Strides for y_s (elements) ------------------------------------------
+    stride_ys_e,
+    stride_ys_t,
+    stride_ys_g,
+    # Stride for counts (elements)
+    stride_counts_e,
+    # Numeric params -------------------------------------------------------
+    eps: tl.constexpr,
+    fp8_min: tl.constexpr,
+    fp8_max: tl.constexpr,
+    use_ue8m0: tl.constexpr,
+    # Meta -----------------------------------------------------------------
+    BLOCK: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
+):
+    G = H // GROUP_SIZE
+
+    # map program id -> (e, g)
+    pid = tl.program_id(0)
+    e = pid // G
+    g = pid % G
+
+    e = e.to(tl.int64)
+    g = g.to(tl.int64)
+
+    # number of valid tokens for this expert
+    n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
+
+    cols = tl.arange(0, BLOCK).to(tl.int64)
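+    # BLOCK == GROUP_SIZE at the launch site, so this mask is all-true (kept for safety).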
+    mask = cols < BLOCK
+
+    base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
+    base_gate_offset = base_input_offset + cols * stride_i_h
+    base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
+    base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h
+    base_ys_offset = e * stride_ys_e + g * stride_ys_g
+
+    for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
+        gate = tl.load(
+            input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0
+        ).to(tl.float32)
+        up = tl.load(
+            input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0
+        ).to(tl.float32)
+
+        gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
+        y = gate * up
+
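+        # Per-group scale: max(|y|) over the group, floored at eps to avoid
+        # a zero scale, then mapped onto the fp8 dynamic range.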
+        y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
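+        # UE8M0 rounds the scale up to the next power of two so it is exactly
+        # representable in DeepGEMM's e8m0 scale format.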
+        if use_ue8m0:
+            y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
+
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
+        tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
+
+
+def gold(
+    y: torch.Tensor,  # (E, T, 2*H)
+    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
+    num_parallel_tokens=16,
+    group_size: int = 128,
+    eps: float = 1e-10,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with per-token group scales.
+
+    y has shape (E, T, 2*H). The first half of the last dimension is
+    silu-activated, multiplied by the second half, then quantized into FP8.
+
+    Returns `(y_q, y_s)` where
+    * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
+    * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
+    """
+    assert y.ndim == 3, "y must be (E, T, 2*H)"
+    E, T, H2 = y.shape
+    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
+    H = H2 // 2
+    assert H % group_size == 0, "H must be divisible by group_size"
+    G = H // group_size
+    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, (
+        "tokens_per_expert must be shape (E,)"
+    )
+    tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)
+
+    # allocate outputs
+    fp8_dtype = torch.float8_e4m3fn
+    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
+
+    # strides (elements)
+    stride_i_e, stride_i_t, stride_i_h = y.stride()
+    stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
+
+    # desired scale strides (elements): (T*G, 1, T)
+    stride_ys_e = T * G
+    stride_ys_t = 1
+    stride_ys_g = T
+    y_s = torch.empty_strided(
+        (E, T, G),
+        (stride_ys_e, stride_ys_t, stride_ys_g),
+        dtype=torch.float32,
+        device=y.device,
+    )
+
+    stride_cnt_e = tokens_per_expert.stride()[0]
+
+    # Static grid over experts and H-groups.
+    # A loop inside the kernel handles the token dim.
+    grid = (E * G,)
+
+    f_info = torch.finfo(fp8_dtype)
+    fp8_max = f_info.max
+    fp8_min = f_info.min
+
+    _silu_mul_fp8_quant_deep_gemm[grid](
+        y,
+        y_q,
+        y_s,
+        tokens_per_expert,
+        H,
+        group_size,
+        stride_i_e,
+        stride_i_t,
+        stride_i_h,
+        stride_yq_e,
+        stride_yq_t,
+        stride_yq_h,
+        stride_ys_e,
+        stride_ys_t,
+        stride_ys_g,
+        stride_cnt_e,
+        eps,
+        fp8_min,
+        fp8_max,
+        is_deep_gemm_e8m0_used(),
+        BLOCK=group_size,
+        NUM_STAGES=8,
+    )
+
+    return y_q, y_s
+
+
+def benchmark(k, E, T, H, num_parallel_tokens, G=128, runs=100):
     current_platform.seed_everything(42)
-    y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
+    y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()
     tokens_per_expert = torch.randint(
         T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
     )

     # Warmup
-    for _ in range(10):
-        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+    for _ in range(20):
+        k(y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G)
     torch.cuda.synchronize()

     # Benchmark
     torch.cuda.synchronize()
     start = time.perf_counter()
     for _ in range(runs):
-        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+        k(y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G)
     torch.cuda.synchronize()

     avg_time = (time.perf_counter() - start) / runs * 1000
@@ -51,27 +213,38 @@ def benchmark(E, T, H, G=128, runs=50):
     return avg_time, gflops, memory_bw


-configs = [
-    (8, 32, 1024),
-    (16, 64, 2048),
-    (32, 128, 4096),
-    # DeepSeekV3 Configs
-    (256, 16, 7168),
-    (256, 32, 7168),
-    (256, 64, 7168),
-    (256, 128, 7168),
-    (256, 256, 7168),
-    (256, 512, 7168),
-    (256, 1024, 7168),
-]
-
-print(f"GPU: {torch.cuda.get_device_name()}")
-print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
-print("-" * 50)
-
-for E, T, H in configs:
-    try:
-        time_ms, gflops, gbps = benchmark(E, T, H)
+def benchmark_full():
+    configs = [
+        (8, 32, 1024),
+        (16, 64, 2048),
+        (32, 128, 4096),
+        # DeepSeekV3 Configs
+        (256, 16, 7168),
+        (256, 32, 7168),
+        (256, 64, 7168),
+        (256, 128, 7168),
+        (256, 256, 7168),
+        (256, 512, 7168),
+        (256, 1024, 7168),
+    ]
+
+    print(f"GPU: {torch.cuda.get_device_name()} CUDA Kernel")
+    print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
+    print("-" * 50)
+
+    for E, T, H in configs:
+        time_ms, gflops, gbps = benchmark(
+            silu_mul_fp8_quant_deep_gemm_cuda, E, T, H, 16
+        )
+        print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
+
+    print(f"GPU: {torch.cuda.get_device_name()} Baseline")
+    print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
+    print("-" * 50)
+
+    for E, T, H in configs:
+        time_ms, gflops, gbps = benchmark(gold, E, T, H, 16)
         print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
-    except Exception:
-        print(f"E={E:3d},T={T:4d},H={H:4d} FAILED")
+
+
+benchmark_full()
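
A torch-native reference for the same silu-and-multiply plus group quantization is a useful sanity check for both kernels. The sketch below is not part of the commit: ref_silu_mul_fp8_quant is a hypothetical helper, it skips the UE8M0 power-of-two scale rounding, and it quantizes all T token slots (tokens_per_expert only determines which rows are meaningful):

    import torch

    def ref_silu_mul_fp8_quant(y: torch.Tensor, group_size: int = 128, eps: float = 1e-10):
        # y: (E, T, 2*H). Returns (y_q, y_s) with the same values as gold(),
        # though y_s is laid out contiguously rather than with strides (T*G, 1, T).
        E, T, H2 = y.shape
        H = H2 // 2
        G = H // group_size
        f = torch.finfo(torch.float8_e4m3fn)
        gate, up = y[..., :H].float(), y[..., H:].float()
        out = torch.nn.functional.silu(gate) * up  # silu(x) = x * sigmoid(x)
        groups = out.reshape(E, T, G, group_size)
        # Per-group scale: max(|y|) floored at eps, mapped onto the fp8 range.
        y_s = groups.abs().amax(dim=-1).clamp_min(eps) / f.max
        y_q = (groups / y_s.unsqueeze(-1)).clamp(f.min, f.max)
        return y_q.reshape(E, T, H).to(torch.float8_e4m3fn), y_s

Comparing y_q against either kernel's output over the first tokens_per_expert[e] rows of each expert should then agree to within fp8 rounding whenever UE8M0 is disabled.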