Commit d14d5f4

WIP to map custom quant op to real implementation using pattern matcher

Signed-off-by: Frida Hou <[email protected]>

Update fusion to pass in args instead of kwargs

Signed-off-by: Frida Hou <[email protected]>

1 parent 97e8125 commit d14d5f4

File tree: 6 files changed, +401 −11 lines
tensorrt_llm/_torch/auto_deploy/config/default.yaml
Lines changed: 2 additions & 0 deletions

@@ -91,6 +91,8 @@ transforms:
   fuse_rmsnorm:
     stage: post_load_fusion
     backend: flashinfer
+  fuse_quant:
+    stage: post_load_fusion
   ############################################################################################
   # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
   ############################################################################################
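
Note that the "fuse_quant" key matches the name under which the new transform registers itself in the transform module added by this commit (shown below):

    @TransformRegistry.register("fuse_quant")   # registration name looked up via this config entry
    class FuseQuant(BaseTransform): ...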
Lines changed: 282 additions & 0 deletions

@@ -0,0 +1,282 @@ (new file; all lines added)

from typing import Tuple

import torch
from torch.fx import GraphModule

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry


def _fp8_ref_pattern_1(
    x: torch.Tensor,
    w_fp8: torch.Tensor,
    input_scale: torch.Tensor,
    weight_scale: torch.Tensor,
):
    # Matches: torch_fake_quant_fp8_linear(input, weight_fp8, bias, [in_s], [w_s], [], [])
    return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
        x,
        w_fp8,
        None,
        input_scale=[input_scale],
        weight_scale=[weight_scale],
        input_zp=[],
        weight_zp=[],
    )


def _fp8_ref_repl_1(
    x: torch.Tensor,
    w_fp8: torch.Tensor,
    input_scale: torch.Tensor,
    weight_scale: torch.Tensor,
):
    # Map lists -> scalars for the fused op:
    #   in_s = input_scale[0]
    #   w_s = weight_scale[0]
    return torch.ops.auto_deploy.torch_quant_fp8_linear(
        x,
        w_fp8,
        None,
        input_scale=input_scale,
        weight_scale=weight_scale,
    )


def _fp8_ref_pattern_2(
    x: torch.Tensor,
    w_fp8: torch.Tensor,
    bias: torch.Tensor,
    input_scale: torch.Tensor,
    weight_scale: torch.Tensor,
):
    # Matches: torch_fake_quant_fp8_linear(input, weight_fp8, bias, [in_s], [w_s], [], [])
    return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
        x,
        w_fp8,
        bias,
        input_scale=[input_scale],
        weight_scale=[weight_scale],
        input_zp=[],
        weight_zp=[],
    )


def _fp8_ref_repl_2(
    x: torch.Tensor,
    w_fp8: torch.Tensor,
    bias: torch.Tensor,
    input_scale: torch.Tensor,
    weight_scale: torch.Tensor,
):
    # Map lists -> scalars for the fused op:
    #   in_s = input_scale[0]
    #   w_s = weight_scale[0]
    return torch.ops.auto_deploy.torch_quant_fp8_linear(
        x,
        w_fp8,
        bias,
        input_scale=input_scale,
        weight_scale=weight_scale,
    )


# NVFP4: reference (search) and fused (replacement)
def _fp4_ref_pattern_1(
    x: torch.Tensor,
    w_fp4: torch.Tensor,
    input_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    alpha: torch.Tensor,
):
    # Matches: torch_fake_quant_fp4_linear(x, w_fp4, bias, [s_in2], [cutlass_scale, alpha], [], [])
    return torch.ops.auto_deploy.torch_fake_quant_fp4_linear(
        x,
        w_fp4,
        None,
        input_scale=[input_scale],
        weight_scale=[weight_scale, alpha],
        input_zp=[],
        weight_zp=[],
    )


def _fp4_ref_repl_1(
    x: torch.Tensor,
    w_fp4: torch.Tensor,
    input_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    alpha: torch.Tensor,
):
    return torch.ops.auto_deploy.torch_quant_fp4_linear(
        x,
        w_fp4,
        bias=None,
        input_scale=input_scale,
        weight_scale=weight_scale,
        alpha=alpha,
    )


def _fp4_ref_pattern_2(
    x: torch.Tensor,
    w_fp4: torch.Tensor,
    bias: torch.Tensor,
    input_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    alpha: torch.Tensor,
):
    # Matches: torch_fake_quant_fp4_linear(x, w_fp4, bias, [s_in2], [cutlass_scale, alpha], [], [])
    return torch.ops.auto_deploy.torch_fake_quant_fp4_linear(
        x,
        w_fp4,
        bias,
        input_scale=[input_scale],
        weight_scale=[weight_scale, alpha],
        input_zp=[],
        weight_zp=[],
    )


def _fp4_ref_repl_2(
    x: torch.Tensor,
    w_fp4: torch.Tensor,
    bias: torch.Tensor | None,
    input_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    alpha: torch.Tensor,
):
    return torch.ops.auto_deploy.torch_quant_fp4_linear(
        x,
        w_fp4,
        bias=bias,
        input_scale=input_scale,
        weight_scale=weight_scale,
        alpha=alpha,
    )


def _register_quant_linear_patterns(patterns: ADPatternMatcherPass) -> None:
    """Register the FP8 and FP4 patterns with robust dummy args and minimal ignores."""
    # Use harmless meta tensors; no dtype/device constraints during tracing.
    # Shapes mirror the unit tests but can be arbitrary as long as tracing succeeds.
    x_fp8 = torch.randn(3, 16, device="meta", dtype=torch.float16)
    w_fp8 = torch.randn(32, 16, device="meta", dtype=torch.float16)  # dtype not enforced in trace
    bias32 = torch.randn(32, device="meta", dtype=torch.float32)
    one = torch.tensor(1.0, device="meta", dtype=torch.float32)

    dummy_args_fp8 = [
        x_fp8,
        w_fp8,
        one,
        torch.tensor(0.5, device="meta", dtype=torch.float32),
    ]

    dummy_args_fp8_2 = [
        x_fp8,
        w_fp8,
        bias32,
        one,
        torch.tensor(0.5, device="meta", dtype=torch.float32),
    ]

    register_ad_pattern(
        search_fn=_fp8_ref_pattern_1,
        replace_fn=_fp8_ref_repl_1,
        patterns=patterns,
        dummy_args=dummy_args_fp8,
        # No special scalar_workaround or op_ignore_types needed here.
    )
    register_ad_pattern(
        search_fn=_fp8_ref_pattern_2,
        replace_fn=_fp8_ref_repl_2,
        patterns=patterns,
        dummy_args=dummy_args_fp8_2,
        # No special scalar_workaround or op_ignore_types needed here.
    )

    # FP4 dummy args
    N = 32
    K_packed = 32  # weight is packed with 2 FP4 values per byte
    K_eff = 2 * K_packed  # effective K after repeat(1, 2) in the fake impl

    x_fp4 = torch.randn(3, K_eff, device="meta", dtype=torch.float16)  # must be 3 x K_eff (64), not 3 x K_packed (32)
    w_fp4 = torch.randint(0, 255, (N, K_packed), device="meta", dtype=torch.uint8)

    s_in2 = torch.tensor(0.01, device="meta", dtype=torch.float32)
    alpha = torch.tensor(1.2345, device="meta", dtype=torch.float32)

    # Optional: give a realistic-length CUTLASS scale vector (one uint8 per 16-wide block):
    #   num_blocks = N * (K_eff // 16)
    cutlass_len = N * (K_eff // 16)  # 32 * (64 / 16) = 128
    cutlass_vec = torch.randint(0, 255, (cutlass_len,), device="meta", dtype=torch.uint8)

    dummy_args_fp4_1 = [
        x_fp4,
        w_fp4,
        s_in2,  # input_scale list
        cutlass_vec,
        alpha,  # weight_scale list: [per-block vec, alpha]
    ]

    dummy_args_fp4_2 = [
        x_fp4,
        w_fp4,
        torch.randn(N, device="meta", dtype=torch.float16),  # bias
        s_in2,  # input_scale list
        cutlass_vec,
        alpha,  # weight_scale list: [per-block vec, alpha]
    ]

    register_ad_pattern(
        search_fn=_fp4_ref_pattern_1,
        replace_fn=_fp4_ref_repl_1,
        patterns=patterns,
        dummy_args=dummy_args_fp4_1,
    )

    register_ad_pattern(
        search_fn=_fp4_ref_pattern_2,
        replace_fn=_fp4_ref_repl_2,
        patterns=patterns,
        dummy_args=dummy_args_fp4_2,
    )


@TransformRegistry.register("fuse_quant")
class FuseQuant(BaseTransform):
    """Use ADPatternMatcherPass to rewrite reference quantized linear ops into fused ones.

    FP8:
        torch_fake_quant_fp8_linear(x, w_fp8, bias, [in_s], [w_s], [], [])
        -> torch_quant_fp8_linear(x, w_fp8, bias=bias, input_scale=in_s, weight_scale=w_s)

    FP4 (NVFP4):
        torch_fake_quant_fp4_linear(x, w_fp4, bias, [s_in2], [cutlass_vec, alpha], [], [])
        -> torch_quant_fp4_linear(x, w_fp4, bias=bias, input_scale=s_in2,
                                  weight_scale=cutlass_vec, alpha=alpha)
    """

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        patterns = ADPatternMatcherPass()
        _register_quant_linear_patterns(patterns)
        num_matches = patterns.apply(gm.graph)

        info = TransformInfo(
            skipped=(num_matches == 0),
            num_matches=num_matches,
            is_clean=False,
            has_valid_shapes=False,
        )
        return gm, info
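
For orientation, the rewrite can also be exercised directly on an FX graph. The following is a minimal sketch (not part of the commit) that assumes the auto_deploy custom ops are registered and that `gm` is a torch.fx GraphModule still containing the reference torch_fake_quant_*_linear calls:

    # Sketch only: apply the registered rewrite patterns to an existing GraphModule `gm`.
    patterns = ADPatternMatcherPass()
    _register_quant_linear_patterns(patterns)

    num_matches = patterns.apply(gm.graph)  # rewrites matched call_function nodes in place
    gm.recompile()                          # regenerate forward() from the mutated graph
    print(f"rewrote {num_matches} reference quantized linear call(s) into fused ops")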

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
Lines changed: 28 additions & 5 deletions

@@ -125,6 +125,10 @@ def fuse_rule(
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         raise NotImplementedError
 
+    def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple[object, ...]:
+        """Return the *positional* tail after bias for the fused call."""
+        raise NotImplementedError
+
     def _insert_fused_quant_gemm(
         self, gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]
     ):

@@ -163,16 +167,17 @@ def split_output(tensor: torch.Tensor) -> Tuple[torch.Tensor, ...]:
 
         # For each kwarg group (e.g., input_scale, weight_scale[, alpha]),
         # create a list of get_attr nodes in the same structure the op expects.
-        for group, kwarg_name in zip(self.scale_groups, ["input_scale", "weight_scale"]):
-            fused_kwargs[kwarg_name] = [
-                gm.graph.create_node("get_attr", f"{key_fused}_{name}") for name in group
-            ]
+        scale_getattrs: Dict[str, Node] = {
+            name: gm.graph.create_node("get_attr", f"{key_fused}_{name}")
+            for name in flat_scale_names
+        }
+        custom_tail_args = self.build_custom_args_for_linear(scale_getattrs)
 
         # add new linear node + split node
         with gm.graph.inserting_before(linear_nodes[0]):
             fused_linear_node = gm.graph.call_function(
                 get_op_overload_packet(linear_nodes[0].target),
-                args=(parent_node, get_param_node, None),
+                args=(parent_node, get_param_node, None, *custom_tail_args),
                 kwargs=fused_kwargs,
             )
             split_node = gm.graph.call_function(split_output, args=(fused_linear_node,))

@@ -286,6 +291,15 @@ def fuse_rule(
             "input_scale": input_scale[0].clone(),
         }
 
+    def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple[object, ...]:
+        # (..., bias, input_scale(list), weight_scale(list), input_zp(list), weight_zp(list))
+        return (
+            [scale_getattrs["input_scale"]],
+            [scale_getattrs["weight_scale"]],
+            [],
+            [],
+        )
+
     def _apply(
         self,
         gm: GraphModule,

@@ -323,6 +337,15 @@ def fuse_rule(
             "input_scale": input_scale[0].clone(),
         }
 
+    def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple[object, ...]:
+        # (..., bias, input_scale(list), weight_scale(list-with-alpha), input_zp(list), weight_zp(list))
+        return (
+            [scale_getattrs["input_scale"]],
+            [scale_getattrs["weight_scale"], scale_getattrs["alpha"]],
+            [],
+            [],
+        )
+
     def _apply(
         self,
         gm: GraphModule,
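
In short, build_custom_args_for_linear supplies the quantization-specific positional tail that follows (input, weight, bias) in the fused call. A sketch of the two implementations' return values, with node names chosen purely for illustration:

    # FP8 tail: (input_scale list, weight_scale list, input_zp list, weight_zp list)
    fp8_tail = ([in_scale_node], [w_scale_node], [], [])

    # NVFP4 tail: the weight_scale list additionally carries alpha
    fp4_tail = ([in_scale_node], [w_scale_node, alpha_node], [], [])

    # The fused GEMM node is then created roughly as:
    #   gm.graph.call_function(op, args=(parent_node, fused_weight_node, None, *tail), kwargs=fused_kwargs)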

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
Lines changed: 2 additions & 4 deletions

@@ -77,12 +77,10 @@ def _insert_quantized_linear(
     for scale_name in quantization_impl.scale_names():
         scales[scale_name] = gm.graph.create_node("get_attr", modname + "." + scale_name)
 
-    custom_kwargs = quantization_impl.build_custom_kwargs_for_linear(
-        scales,
-    )
+    custom_args = quantization_impl.build_custom_args_for_linear(scales)
 
     node.target = quantization_impl.custom_op()
-    node.kwargs = {**node.kwargs, **custom_kwargs}
+    node.args = (*node.args, *custom_args)
 
 
 def _insert_quantized_bmm(
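
The practical effect on a single quantized linear node is that the scale lists now travel as positional arguments matching the fake-quant op schema rather than as merged keyword arguments; an illustrative sketch (tensor names are placeholders):

    # Before: scales merged into the node's keyword arguments
    #   node.kwargs = {..., "input_scale": [in_s], "weight_scale": [w_s], "input_zp": [], "weight_zp": []}
    # After: the same values appended to the positional args, e.g. for FP8
    #   node.args = (x, w_fp8, bias, [in_s], [w_s], [], [])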
