Skip to content

Commit 386a373

Browse files
committed
Move triton code to benchmark script
Signed-off-by: elvircrn <[email protected]>
1 parent 44b40d2 commit 386a373

File tree

3 files changed

+200
-373
lines changed

3 files changed

+200
-373
lines changed

benchmarks/kernels/benchmark_silu_mul_fp8_quant.py

Lines changed: 200 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,175 @@
55

66
import torch
77

8-
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
9-
silu_mul_fp8_quant_deep_gemm as gold,
10-
)
118
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
129
silu_mul_fp8_quant_deep_gemm_cuda,
1310
)
1411
from vllm.platforms import current_platform
12+
from vllm.triton_utils import tl, triton
13+
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
14+
15+
16+
@triton.jit
17+
def _silu_mul_fp8_quant_deep_gemm(
18+
# Pointers ------------------------------------------------------------
19+
input_ptr, # 16-bit activations (E, T, 2*H)
20+
y_q_ptr, # fp8 quantized activations (E, T, H)
21+
y_s_ptr, # 16-bit scales (E, T, G)
22+
counts_ptr, # int32 num tokens per expert (E)
23+
# Sizes ---------------------------------------------------------------
24+
H: tl.constexpr, # hidden dimension (per output)
25+
GROUP_SIZE: tl.constexpr, # elements per group (usually 128)
26+
# Strides for input (elements) ---------------------------------------
27+
stride_i_e,
28+
stride_i_t,
29+
stride_i_h,
30+
# Strides for y_q (elements) -----------------------------------------
31+
stride_yq_e,
32+
stride_yq_t,
33+
stride_yq_h,
34+
# Strides for y_s (elements) -----------------------------------------
35+
stride_ys_e,
36+
stride_ys_t,
37+
stride_ys_g,
38+
# Stride for counts (elements)
39+
stride_counts_e,
40+
# Numeric params ------------------------------------------------------
41+
eps: tl.constexpr,
42+
fp8_min: tl.constexpr,
43+
fp8_max: tl.constexpr,
44+
use_ue8m0: tl.constexpr,
45+
# Meta ---------------------------------------------------------------
46+
BLOCK: tl.constexpr,
47+
NUM_STAGES: tl.constexpr,
48+
):
49+
G = H // GROUP_SIZE
50+
51+
# map program id -> (e, g)
52+
pid = tl.program_id(0)
53+
e = pid // G
54+
g = pid % G
55+
56+
e = e.to(tl.int64)
57+
g = g.to(tl.int64)
58+
59+
# number of valid tokens for this expert
60+
n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
61+
62+
cols = tl.arange(0, BLOCK).to(tl.int64)
63+
mask = cols < BLOCK
64+
65+
base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
66+
base_gate_offset = base_input_offset + cols * stride_i_h
67+
base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
68+
base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h
69+
base_ys_offset = e * stride_ys_e + g * stride_ys_g
70+
71+
for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
72+
gate = tl.load(
73+
input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0
74+
).to(tl.float32)
75+
up = tl.load(
76+
input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0
77+
).to(tl.float32)
78+
79+
gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
80+
y = gate * up
81+
82+
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
83+
if use_ue8m0:
84+
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
85+
86+
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
87+
88+
tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
89+
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
90+
91+
92+
def gold(
93+
y: torch.Tensor, # (E, T, 2*H)
94+
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
95+
num_parallel_tokens=16,
96+
group_size: int = 128,
97+
eps: float = 1e-10,
98+
) -> tuple[torch.Tensor, torch.Tensor]:
99+
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
100+
101+
y has shape (E, T, 2*H). The first half of the last dimension is
102+
silu-activated, multiplied by the second half, then quantized into FP8.
103+
104+
Returns `(y_q, y_s)` where
105+
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
106+
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
107+
"""
108+
assert y.ndim == 3, "y must be (E, T, 2*H)"
109+
E, T, H2 = y.shape
110+
assert H2 % 2 == 0, "last dim of y must be even (2*H)"
111+
H = H2 // 2
112+
G = H // group_size
113+
assert H % group_size == 0, "H must be divisible by group_size"
114+
assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, (
115+
"tokens_per_expert must be shape (E,)"
116+
)
117+
tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)
118+
119+
# allocate outputs
120+
fp8_dtype = torch.float8_e4m3fn
121+
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
122+
123+
# strides (elements)
124+
stride_i_e, stride_i_t, stride_i_h = y.stride()
125+
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
126+
127+
# desired scale strides (elements): (T*G, 1, T)
128+
stride_ys_e = T * G
129+
stride_ys_t = 1
130+
stride_ys_g = T
131+
y_s = torch.empty_strided(
132+
(E, T, G),
133+
(stride_ys_e, stride_ys_t, stride_ys_g),
134+
dtype=torch.float32,
135+
device=y.device,
136+
)
137+
138+
stride_cnt_e = tokens_per_expert.stride()[0]
139+
140+
# Static grid over experts and H-groups.
141+
# A loop inside the kernel handles the token dim
142+
grid = (E * G,)
143+
144+
f_info = torch.finfo(fp8_dtype)
145+
fp8_max = f_info.max
146+
fp8_min = f_info.min
147+
148+
_silu_mul_fp8_quant_deep_gemm[grid](
149+
y,
150+
y_q,
151+
y_s,
152+
tokens_per_expert,
153+
H,
154+
group_size,
155+
stride_i_e,
156+
stride_i_t,
157+
stride_i_h,
158+
stride_yq_e,
159+
stride_yq_t,
160+
stride_yq_h,
161+
stride_ys_e,
162+
stride_ys_t,
163+
stride_ys_g,
164+
stride_cnt_e,
165+
eps,
166+
fp8_min,
167+
fp8_max,
168+
is_deep_gemm_e8m0_used(),
169+
BLOCK=group_size,
170+
NUM_STAGES=8,
171+
)
172+
173+
return y_q, y_s
15174

16175

17-
def benchmark(k, E, T, H, G=128, runs=100):
176+
def benchmark(k, E, T, H, num_parallel_tokens, G=128, runs=100):
18177
current_platform.seed_everything(42)
19178
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()
20179
tokens_per_expert = torch.randint(
@@ -23,14 +182,14 @@ def benchmark(k, E, T, H, G=128, runs=100):
23182

24183
# Warmup
25184
for _ in range(20):
26-
k(y, tokens_per_expert, group_size=G)
185+
k(y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G)
27186
torch.cuda.synchronize()
28187

29188
# Benchmark
30189
torch.cuda.synchronize()
31190
start = time.perf_counter()
32191
for _ in range(runs):
33-
k(y, tokens_per_expert, group_size=G)
192+
k(y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G)
34193
torch.cuda.synchronize()
35194

36195
avg_time = (time.perf_counter() - start) / runs * 1000
@@ -54,38 +213,38 @@ def benchmark(k, E, T, H, G=128, runs=100):
54213
return avg_time, gflops, memory_bw
55214

56215

57-
configs = [
58-
# DeepSeekV3 Configs
59-
(8, 16, 7168),
60-
(8, 32, 7168),
61-
(8, 64, 7168),
62-
(8, 128, 7168),
63-
(8, 256, 7168),
64-
(8, 512, 7168),
65-
(8, 1024, 7168),
66-
(9, 16, 7168),
67-
(9, 32, 7168),
68-
(9, 64, 7168),
69-
(9, 128, 7168),
70-
(9, 256, 7168),
71-
(9, 512, 7168),
72-
(9, 1024, 7168),
73-
]
74-
75-
76-
print(f"GPU: {torch.cuda.get_device_name()} CUDA Kernel")
77-
print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
78-
print("-" * 50)
79-
80-
for E, T, H in configs:
81-
time_ms, gflops, gbps = benchmark(silu_mul_fp8_quant_deep_gemm_cuda, E, T, H)
82-
print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
83-
84-
85-
print(f"GPU: {torch.cuda.get_device_name()} Baseline")
86-
print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
87-
print("-" * 50)
88-
89-
for E, T, H in configs:
90-
time_ms, gflops, gbps = benchmark(gold, E, T, H)
91-
print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
216+
def benchmark_full():
217+
configs = [
218+
(8, 32, 1024),
219+
(16, 64, 2048),
220+
(32, 128, 4096),
221+
# DeepSeekV3 Configs
222+
(256, 16, 7168),
223+
(256, 32, 7168),
224+
(256, 64, 7168),
225+
(256, 128, 7168),
226+
(256, 256, 7168),
227+
(256, 512, 7168),
228+
(256, 1024, 7168),
229+
]
230+
231+
print(f"GPU: {torch.cuda.get_device_name()} CUDA Kernel")
232+
print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
233+
print("-" * 50)
234+
235+
for E, T, H in configs:
236+
time_ms, gflops, gbps = benchmark(
237+
silu_mul_fp8_quant_deep_gemm_cuda, E, T, H, 16
238+
)
239+
print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
240+
241+
print(f"GPU: {torch.cuda.get_device_name()} Baseline")
242+
print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
243+
print("-" * 50)
244+
245+
for E, T, H in configs:
246+
time_ms, gflops, gbps = benchmark(gold, E, T, H, 16)
247+
print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
248+
249+
250+
benchmark_full()

0 commit comments

Comments
 (0)