
Commit e0f0a92

bnellnm authored and debroy-rh committed
[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (vllm-project#22537)
Signed-off-by: Bill Nell <[email protected]>
1 parent 00286c9 commit e0f0a92


68 files changed: 2675 additions, 2503 deletions
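
The diffs below all make the same migration: instead of passing loose quantization flags and scale tensors (use_fp8_w8a8=True, w1_scale=..., a1_scale=..., per_act_token, ...) to the fused-MoE kernels, callers now build a FusedMoEQuantConfig up front via helper constructors and hand the kernel a single quant_config argument. A minimal sketch of the new calling convention for the FP8 W8A8 path, using only names that appear in the hunks below; the wrapper function and its tensor arguments are hypothetical placeholders:

from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts


def run_fp8_moe(a, w1, w2, topk_weights, topk_ids, w1_scale, w2_scale, a1_scale):
    # Previously: fused_experts(..., use_fp8_w8a8=True, w1_scale=..., w2_scale=..., a1_scale=...)
    # Now the quantization parameters are bundled into one config object.
    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
    )
    return fused_experts(
        a,
        w1,
        w2,
        topk_weights,
        topk_ids,
        quant_config=quant_config,
    )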

benchmarks/kernels/benchmark_cutlass_fp4_moe.py

Lines changed: 36 additions & 22 deletions
@@ -13,6 +13,10 @@

 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.config import (
+    fp8_w8a8_moe_quant_config,
+    nvfp4_moe_quant_config,
+)
 from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
 from vllm.scalar_type import scalar_types
@@ -140,17 +144,20 @@ def run_triton_moe(
     a_fp8_scale: torch.Tensor,
     num_repeats: int,
 ):
+    quant_config = fp8_w8a8_moe_quant_config(
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a_fp8_scale,
+    )
+
     for _ in range(num_repeats):
         fused_experts(
             a,
             w1,
             w2,
             topk_weights,
             topk_ids,
-            use_fp8_w8a8=True,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            a1_scale=a_fp8_scale,
+            quant_config=quant_config,
         )

 def run_cutlass_moe_fp4(
@@ -172,25 +179,27 @@ def run_cutlass_moe_fp4(
     device: torch.device,
     num_repeats: int,
 ):
+    quant_config = nvfp4_moe_quant_config(
+        a1_gscale=a1_gs,
+        a2_gscale=a2_gs,
+        w1_scale=w1_blockscale,
+        w2_scale=w2_blockscale,
+        g1_alphas=w1_gs,
+        g2_alphas=w2_gs,
+    )
     for _ in range(num_repeats):
         with nvtx.annotate("cutlass_moe_fp4", color="green"):
             cutlass_moe_fp4(
                 a=a,
-                a1_gscale=a1_gs,
-                a2_gscale=a2_gs,
                 w1_fp4=w1_fp4,
-                w1_blockscale=w1_blockscale,
-                w1_alphas=w1_gs,
                 w2_fp4=w2_fp4,
-                w2_blockscale=w2_blockscale,
-                w2_alphas=w2_gs,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 m=m,
                 n=n,
                 k=k,
                 e=num_experts,
-                device=device,
+                quant_config=quant_config,
             )

 def run_cutlass_from_graph(
@@ -211,26 +220,29 @@ def run_cutlass_from_graph(
     e: int,
     device: torch.device,
 ):
+    quant_config = nvfp4_moe_quant_config(
+        a1_gscale=a1_gs,
+        a2_gscale=a2_gs,
+        w1_scale=w1_blockscale,
+        w2_scale=w2_blockscale,
+        g1_alphas=w1_gs,
+        g2_alphas=w2_gs,
+    )
+
     with set_current_vllm_config(
         VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
     ):
         return cutlass_moe_fp4(
             a=a,
-            a1_gscale=a1_gs,
             w1_fp4=w1_fp4,
-            w1_blockscale=w1_blockscale,
-            w1_alphas=w1_alphas,
-            a2_gscale=a2_gs,
             w2_fp4=w2_fp4,
-            w2_blockscale=w2_blockscale,
-            w2_alphas=w2_alphas,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             m=m,
             n=n,
             k=k,
             e=num_experts,
-            device=device,
+            quant_config=quant_config,
         )

 def run_triton_from_graph(
@@ -246,16 +258,18 @@
     with set_current_vllm_config(
         VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
    ):
+        quant_config = fp8_w8a8_moe_quant_config(
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a_fp8_scale,
+        )
         return fused_experts(
             a,
             w1,
             w2,
             topk_weights,
             topk_ids,
-            use_fp8_w8a8=True,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            a1_scale=a_fp8_scale,
+            quant_config=quant_config,
         )

 def replay_graph(graph, num_repeats):
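
The NVFP4 path in this benchmark follows the same pattern: the activation global scales, weight blockscales, and alphas move out of the cutlass_moe_fp4 call and into an nvfp4_moe_quant_config, and the device argument is dropped. A minimal sketch, assuming the caller has already prepared the quantized weights and scales exactly as the benchmark does; the wrapper function name is a hypothetical placeholder:

from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4


def run_fp4_moe(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale,
                w1_gs, w2_gs, a1_gs, a2_gs,
                topk_weights, topk_ids, m, n, k, num_experts):
    # Scales and alphas are carried by the config rather than by kernel kwargs.
    quant_config = nvfp4_moe_quant_config(
        a1_gscale=a1_gs,
        a2_gscale=a2_gs,
        w1_scale=w1_blockscale,
        w2_scale=w2_blockscale,
        g1_alphas=w1_gs,
        g2_alphas=w2_gs,
    )
    return cutlass_moe_fp4(
        a=a,
        w1_fp4=w1_fp4,
        w2_fp4=w2_fp4,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        m=m,
        n=n,
        k=k,
        e=num_experts,
        quant_config=quant_config,
    )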

benchmarks/kernels/benchmark_grouped_gemm_cutlass.py

Lines changed: 27 additions & 16 deletions
@@ -7,6 +7,7 @@

 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
 from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_experts,
@@ -96,17 +97,19 @@ def run_triton_moe(
     a_scale: torch.Tensor,
     num_repeats: int,
 ):
+    quant_config = fp8_w8a8_moe_quant_config(
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a_scale,
+    )
     for _ in range(num_repeats):
         fused_experts(
             a,
             w1,
             w2,
             topk_weights,
             topk_ids,
-            use_fp8_w8a8=True,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            a1_scale=a_scale,
+            quant_config=quant_config,
         )

 def run_cutlass_moe(
@@ -125,21 +128,24 @@ def run_cutlass_moe(
     per_act_token: bool,
     num_repeats: int,
 ):
+    quant_config = fp8_w8a8_moe_quant_config(
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        per_act_token_quant=per_act_token,
+    )
+
     for _ in range(num_repeats):
         cutlass_moe_fp8(
             a,
             w1,
             w2,
             topk_weights,
             topk_ids,
-            w1_scale,
-            w2_scale,
             ab_strides1,
             ab_strides2,
             c_strides1,
             c_strides2,
-            per_act_token,
-            a1_scale=None,
+            quant_config=quant_config,
         )

 def run_cutlass_from_graph(
@@ -156,6 +162,12 @@ def run_cutlass_from_graph(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
 ):
+    quant_config = fp8_w8a8_moe_quant_config(
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        per_act_token_quant=per_act_token,
+    )
+
     with set_current_vllm_config(
         VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
     ):
@@ -165,14 +177,11 @@
             w2_q,
             topk_weights,
             topk_ids,
-            w1_scale,
-            w2_scale,
             ab_strides1,
             ab_strides2,
             c_strides1,
             c_strides2,
-            per_act_token,
-            a1_scale=None,
+            quant_config=quant_config,
         )

 def run_triton_from_graph(
@@ -185,6 +194,11 @@
     w2_scale: torch.Tensor,
     a_scale: torch.Tensor,
 ):
+    quant_config = fp8_w8a8_moe_quant_config(
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a_scale,
+    )
     with set_current_vllm_config(
         VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
     ):
@@ -194,10 +208,7 @@
             w2,
             topk_weights,
             topk_ids,
-            use_fp8_w8a8=True,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            a1_scale=a_scale,
+            quant_config=quant_config,
         )

 def replay_graph(graph, num_repeats):
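
For the CUTLASS FP8 grouped-GEMM benchmark the change is the same, with one extra detail: the per_act_token flag and the a1_scale=None placeholder disappear from the kernel call and are expressed as per_act_token_quant on the config. A minimal sketch, assuming quantized weights, scales, and strides are prepared by the caller as in the benchmark; the wrapper name is a hypothetical placeholder:

from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8


def run_fp8_cutlass_moe(a, w1_q, w2_q, topk_weights, topk_ids,
                        w1_scale, w2_scale, per_act_token,
                        ab_strides1, ab_strides2, c_strides1, c_strides2):
    # per_act_token and the scale tensors now live on the config object.
    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        per_act_token_quant=per_act_token,
    )
    return cutlass_moe_fp8(
        a,
        w1_q,
        w2_q,
        topk_weights,
        topk_ids,
        ab_strides1,
        ab_strides2,
        c_strides1,
        c_strides2,
        quant_config=quant_config,
    )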

benchmarks/kernels/benchmark_moe.py

Lines changed: 35 additions & 38 deletions
@@ -14,6 +14,10 @@
 import torch
 from ray.experimental.tqdm_ray import tqdm

+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEQuantConfig,
+    _get_config_dtype_str,
+)
 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import get_config
@@ -134,43 +138,36 @@ def prepare(i: int):
     def run():
         from vllm.model_executor.layers.fused_moe import override_config

+        if use_fp8_w8a8:
+            quant_dtype = torch.float8_e4m3fn
+        elif use_int8_w8a16:
+            quant_dtype = torch.int8
+        else:
+            quant_dtype = None
+
+        quant_config = FusedMoEQuantConfig.make(
+            quant_dtype=quant_dtype,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            block_shape=block_quant_shape,
+        )
+
         with override_config(config):
-            if use_deep_gemm:
-                topk_weights, topk_ids, token_expert_indices = fused_topk(
-                    x, input_gating, topk, False
-                )
-                return fused_experts(
-                    x,
-                    w1,
-                    w2,
-                    topk_weights,
-                    topk_ids,
-                    inplace=True,
-                    use_fp8_w8a8=use_fp8_w8a8,
-                    w1_scale=w1_scale,
-                    w2_scale=w2_scale,
-                    a1_scale=a1_scale,
-                    a2_scale=a2_scale,
-                    block_shape=block_quant_shape,
-                    allow_deep_gemm=True,
-                )
-            else:
-                fused_moe(
-                    x,
-                    w1,
-                    w2,
-                    input_gating,
-                    topk,
-                    renormalize=True,
-                    inplace=True,
-                    use_fp8_w8a8=use_fp8_w8a8,
-                    use_int8_w8a16=use_int8_w8a16,
-                    w1_scale=w1_scale,
-                    w2_scale=w2_scale,
-                    a1_scale=a1_scale,
-                    a2_scale=a2_scale,
-                    block_shape=block_quant_shape,
-                )
+            topk_weights, topk_ids, token_expert_indices = fused_topk(
+                x, input_gating, topk, renormalize=not use_deep_gemm
+            )
+            return fused_experts(
+                x,
+                w1,
+                w2,
+                topk_weights,
+                topk_ids,
+                inplace=True,
+                quant_config=quant_config,
+                allow_deep_gemm=use_deep_gemm,
+            )

     # JIT compilation & warmup
     run()
@@ -414,7 +411,7 @@ def benchmark(
         use_deep_gemm: bool = False,
     ) -> tuple[dict[str, int], float]:
         current_platform.seed_everything(self.seed)
-        dtype_str = get_config_dtype_str(
+        dtype_str = _get_config_dtype_str(
            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
        )
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
@@ -547,7 +544,7 @@ def save_configs(
     block_quant_shape: list[int],
     save_dir: str,
 ) -> None:
-    dtype_str = get_config_dtype_str(
+    dtype_str = _get_config_dtype_str(
         dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
     )

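
benchmark_moe.py uses the generic FusedMoEQuantConfig.make constructor instead of a dtype-specific helper, mapping the old use_fp8_w8a8/use_int8_w8a16 flags to a quant_dtype; it also switches from get_config_dtype_str to the underscore-prefixed _get_config_dtype_str exported by the config module. A minimal sketch of that mapping, with scale tensors and block_quant_shape as caller-supplied placeholders:

import torch

from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig


def make_quant_config(use_fp8_w8a8, use_int8_w8a16,
                      w1_scale, w2_scale, a1_scale, a2_scale,
                      block_quant_shape):
    # Translate the benchmark flags into a quantization dtype.
    if use_fp8_w8a8:
        quant_dtype = torch.float8_e4m3fn
    elif use_int8_w8a16:
        quant_dtype = torch.int8
    else:
        quant_dtype = None

    # Bundle dtype, scales, and block shape into a single config.
    return FusedMoEQuantConfig.make(
        quant_dtype=quant_dtype,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_quant_shape,
    )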