
Commit 0974df5

committed
minor updates: rabbit feedback, docstrings, code cleaning
Signed-off-by: Frida Hou <[email protected]>
1 parent f4463a5 commit 0974df5

File tree

11 files changed (+74, -158 lines)


tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py

Lines changed: 14 additions & 14 deletions
@@ -20,7 +20,7 @@ def _expect_single_scale(scales: List[Optional[torch.Tensor]], name: str) -> tor
     return scales[0]
 
 
-def _to_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+def _to_fp8_fake(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     return (x / scale).to(torch.float8_e4m3fn)
 
 
@@ -84,7 +84,7 @@ def _cast_fp4(weight: torch.Tensor):
 
     sign_bit = (weight < 0).to(torch.uint8)
 
-    weight_abs = weight.abs_()
+    weight_abs = weight.abs()  # avoid in-place modification to input
     # Calculate the ordinal value based on the bounds
     ord = torch.searchsorted(e2m1_bounds.to(device), weight_abs, out_int32=True).to(torch.uint8)
     # All values equal to e2m1_bounds at odd indices are rounded up and even indices are rounded down
@@ -105,7 +105,7 @@ def _quantize_nvfp4(
         block_size (int): The size of each block for quantization.
         weights_scaling_factor_2 (torch.Tensor): The per-tensor scaling factor for the weights.
     Returns:
-        tuple: Contains quantized data, quantized per block scaling factor, and per block scaling factor.
+        tuple: Contains quantized data and quantized per block scaling factor
     """
 
     weights_scaling_factor, weights_scaling_factor_2 = _nvfp4_get_weights_scaling_factor(
@@ -160,11 +160,11 @@ def _dequantize_nvfp4(
 def torch_fake_quant_fp8_linear(
     input: torch.Tensor,
     weight_quantized: torch.Tensor,
-    bias: torch.Tensor,  # Optional, no default
-    input_scale: List[torch.Tensor],  # Tensor?[] (REQUIRED: no default)
-    weight_scale: List[torch.Tensor],  # Tensor?[]
-    input_zp: List[torch.Tensor],  # Tensor?[]
-    weight_zp: List[torch.Tensor],  # Tensor?[]
+    bias: torch.Tensor,
+    input_scale: List[torch.Tensor],
+    weight_scale: List[torch.Tensor],
+    input_zp: List[torch.Tensor],
+    weight_zp: List[torch.Tensor],
 ) -> torch.Tensor:
     """
     Reference (eager) implementation for multiple quant formats via `format_type`.
@@ -180,7 +180,7 @@ def torch_fake_quant_fp8_linear(
     in_dtype = input.dtype
     out_features, in_features = weight_quantized.shape
 
-    input_fp8 = _to_fp8(input, s_in)
+    input_fp8 = _to_fp8_fake(input, s_in)
     input_deq = _from_fp8(input_fp8, s_in, in_dtype)
 
     weight_deq = _dequant_weight_fp8(weight_quantized, s_w, out_features, in_dtype)
@@ -209,11 +209,11 @@ def torch_fake_quant_fp8_linear(
 def torch_fake_quant_fp4_linear(
     input: torch.Tensor,
     weight_quantized: torch.Tensor,
-    bias: torch.Tensor,  # Optional, no default
-    input_scale: List[torch.Tensor],  # Tensor?[] (REQUIRED: no default)
-    weight_scale: List[torch.Tensor],  # Tensor?[]
-    input_zp: List[torch.Tensor],  # Tensor?[]
-    weight_zp: List[torch.Tensor],  # Tensor?[]
+    bias: torch.Tensor,
+    input_scale: List[torch.Tensor],
+    weight_scale: List[torch.Tensor],
+    input_zp: List[torch.Tensor],
+    weight_zp: List[torch.Tensor],
 ) -> torch.Tensor:
     """
     Reference (eager) implementation for multiple quant formats via `format_type`.
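
Editor's note: the renamed `_to_fp8_fake` implements only the quantize half of the fake-quant roundtrip used by the reference linear above. A minimal sketch of how it pairs with the dequantize step, assuming a `_from_fp8` helper whose body is not shown in this diff (only its call signature appears above):

    import torch

    def _to_fp8_fake(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # quantize: scale down, then cast to FP8 (e4m3), as in the diff above
        return (x / scale).to(torch.float8_e4m3fn)

    def _from_fp8(x_fp8: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # assumed inverse: cast back to the working dtype and rescale
        return x_fp8.to(dtype) * scale

    x = torch.randn(4, 8, dtype=torch.float16)
    s = x.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
    x_deq = _from_fp8(_to_fp8_fake(x, s), s, x.dtype)  # fake-quantized copy of x, same shape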

tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py

Lines changed: 11 additions & 24 deletions
@@ -9,13 +9,13 @@
 from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
 
 
+# with bias=None
 def _fp8_ref_pattern_1(
     x: torch.Tensor,
     w_fp8: torch.Tensor,
     input_scale: torch.Tensor,
     weight_scale: torch.Tensor,
 ):
-    # Matches: torch_fake_quant_fp8_linear(input, weight_fp8, bias, [in_s], [w_s], [], [])
     return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
         x,
         w_fp8,
@@ -33,9 +33,6 @@ def _fp8_ref_repl_1(
     input_scale: torch.Tensor,
     weight_scale: torch.Tensor,
 ):
-    # Map lists -> scalars for fused op
-    # in_s = input_scale[0]
-    # w_s = weight_scale[0]
     return torch.ops.auto_deploy.torch_quant_fp8_linear(
         x,
         w_fp8,
@@ -45,14 +42,14 @@ def _fp8_ref_repl_1(
     )
 
 
+# with bias!=None
 def _fp8_ref_pattern_2(
     x: torch.Tensor,
     w_fp8: torch.Tensor,
     bias: torch.Tensor,
     input_scale: torch.Tensor,
     weight_scale: torch.Tensor,
 ):
-    # Matches: torch_fake_quant_fp8_linear(input, weight_fp8, bias, [in_s], [w_s], [], [])
     return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
         x,
         w_fp8,
@@ -71,9 +68,6 @@ def _fp8_ref_repl_2(
     input_scale: torch.Tensor,
     weight_scale: torch.Tensor,
 ):
-    # Map lists -> scalars for fused op
-    # in_s = input_scale[0]
-    # w_s = weight_scale[0]
     return torch.ops.auto_deploy.torch_quant_fp8_linear(
         x,
         w_fp8,
@@ -83,15 +77,14 @@ def _fp8_ref_repl_2(
     )
 
 
-# NVFP4: reference (search) and fused (replacement)
+# NVFP4: with bias=None
 def _fp4_ref_pattern_1(
     x: torch.Tensor,
     w_fp4: torch.Tensor,
     input_scale: torch.Tensor,
     weight_scale: torch.Tensor,
     alpha: torch.Tensor,
 ):
-    # Matches: torch_fake_quant_fp4_linear(x, w_fp4, bias, [s_in2], [cutlass_scale, alpha], [], [])
     return torch.ops.auto_deploy.torch_fake_quant_fp4_linear(
         x,
         w_fp4,
@@ -120,6 +113,7 @@ def _fp4_ref_repl_1(
     )
 
 
+# with bias!=None
 def _fp4_ref_pattern_2(
     x: torch.Tensor,
     w_fp4: torch.Tensor,
@@ -128,7 +122,6 @@ def _fp4_ref_pattern_2(
     weight_scale: torch.Tensor,
     alpha: torch.Tensor,
 ):
-    # Matches: torch_fake_quant_fp4_linear(x, w_fp4, bias, [s_in2], [cutlass_scale, alpha], [], [])
     return torch.ops.auto_deploy.torch_fake_quant_fp4_linear(
         x,
         w_fp4,
@@ -162,10 +155,8 @@ def _register_quant_linear_patterns(patterns: ADPatternMatcherPass) -> None:
     """
     Register the FP8 and FP4 patterns with robust dummy args and minimal ignores.
     """
-    # Use harmless meta tensors; no dtype/device constraints during tracing.
-    # Shapes mirror your unit tests but can be arbitrary as long as tracing succeeds.
     x_fp8 = torch.randn(3, 16, device="meta", dtype=torch.float16)
-    w_fp8 = torch.randn(32, 16, device="meta", dtype=torch.float16)  # dtype not enforced in trace
+    w_fp8 = torch.randn(32, 16, device="meta", dtype=torch.float16)
    bias32 = torch.randn(32, device="meta", dtype=torch.float32)
     one = torch.tensor(1.0, device="meta", dtype=torch.float32)
 
@@ -189,47 +180,43 @@ def _register_quant_linear_patterns(patterns: ADPatternMatcherPass) -> None:
         replace_fn=_fp8_ref_repl_1,
         patterns=patterns,
         dummy_args=dummy_args_fp8,
-        # No special scalar_workaround or op_ignore_types needed here.
     )
     register_ad_pattern(
         search_fn=_fp8_ref_pattern_2,
         replace_fn=_fp8_ref_repl_2,
         patterns=patterns,
         dummy_args=dummy_args_fp8_2,
-        # No special scalar_workaround or op_ignore_types needed here.
     )
 
     # FP4 dummy args
     N = 32
     K_packed = 32  # weight is packed by 2 FP4 per byte
-    K_eff = 2 * K_packed  # <- effective K after repeat(1, 2) in the fake impl
+    K_eff = 2 * K_packed
 
-    x_fp4 = torch.randn(3, K_eff, device="meta", dtype=torch.float16)  # was 3 x 32, must be 3 x 64
+    x_fp4 = torch.randn(3, K_eff, device="meta", dtype=torch.float16)
     w_fp4 = torch.randint(0, 255, (N, K_packed), device="meta", dtype=torch.uint8)
 
     s_in2 = torch.tensor(0.01, device="meta", dtype=torch.float32)
     alpha = torch.tensor(1.2345, device="meta", dtype=torch.float32)
 
-    # Optional: give a realistic-length CUTLASS scale vector (one uint8 per 16-wide block)
-    # num_blocks = N * (K_eff // 16)
     cutlass_len = N * (K_eff // 16)  # 32 * (64/16) = 128
     cutlass_vec = torch.randint(0, 255, (cutlass_len,), device="meta", dtype=torch.uint8)
 
     dummy_args_fp4_1 = [
         x_fp4,
         w_fp4,
-        s_in2,  # input_scale list
+        s_in2,
         cutlass_vec,
-        alpha,  # weight_scale list: [per-block vec, alpha]
+        alpha,
     ]
 
     dummy_args_fp4_2 = [
         x_fp4,
         w_fp4,
         torch.randn(N, device="meta", dtype=torch.float16),  # bias
-        s_in2,  # input_scale list
+        s_in2,
         cutlass_vec,
-        alpha,  # weight_scale list: [per-block vec, alpha]
+        alpha,
     ]
 
     register_ad_pattern(
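
Editor's note: the FP4 dummy-argument shapes above follow directly from the uint8 packing. A short bookkeeping sketch using only values that appear in this diff:

    # Each uint8 weight byte holds two FP4 values, and the CUTLASS-style scale vector has
    # one entry per 16-element block, so the dummy shapes are mutually consistent:
    N = 32                        # output features
    K_packed = 32                 # packed uint8 weight columns
    K_eff = 2 * K_packed          # 64 effective input features (2 FP4 per byte)
    blocks_per_row = K_eff // 16  # 4 blocks of 16 FP4 values each
    cutlass_len = N * blocks_per_row
    assert cutlass_len == 128     # length of cutlass_vec in the dummy args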

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py

Lines changed: 1 addition & 3 deletions
@@ -115,11 +115,9 @@ class QuantizationFusionMixin:
         fused_buffers: Dict[str, Tensor] to register as buffers on the fused module
     """
 
-    # required class attributes in subclasses:
     target_op: Callable
     scale_groups: List[List[str]]
 
-    # required method in subclasses:
     def fuse_rule(
         self, weights: List[torch.Tensor], **scales
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
@@ -141,7 +139,7 @@ def split_output(tensor: torch.Tensor) -> Tuple[torch.Tensor, ...]:
             """Split the output tensor of the fused linear node to obtain the original outputs."""
             return tuple(t.contiguous() for t in torch.split(tensor, sizes_unfused, dim=-1))
 
-        # 2) Load scale buffers grouped by flattened scale names
+        # Load scale buffers grouped by flattened scale names
         flat_scale_names = list(chain.from_iterable(self.scale_groups))
         scales: Dict[str, List[torch.Tensor]] = {}
         for weight_key in keys_unfused:
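
Editor's note: with the explanatory comments removed, the subclass contract of QuantizationFusionMixin is carried only by the annotations above (`target_op`, `scale_groups`, `fuse_rule`). A hypothetical subclass might look roughly like this; the class name, scale-group names, and fusion rule below are illustrative assumptions, not code from this commit:

    from typing import Callable, Dict, List, Tuple
    import torch

    class FP8LinearFusionSketch(QuantizationFusionMixin):  # assumes the mixin from fusion.py is importable
        # fused op name taken from fuse_quant.py above; requires the auto_deploy ops to be registered
        target_op: Callable = torch.ops.auto_deploy.torch_quant_fp8_linear
        scale_groups: List[List[str]] = [["input_scale"], ["weight_scale"]]  # assumed buffer names

        def fuse_rule(
            self, weights: List[torch.Tensor], **scales
        ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
            # illustrative only: stack per-branch weights along the output dim and keep the
            # first scale of each group as the fused buffer
            fused_weight = torch.cat(weights, dim=0)
            fused_buffers = {name: group[0] for name, group in scales.items()}
            return fused_weight, fused_buffers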

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py

Lines changed: 1 addition & 3 deletions
@@ -79,7 +79,7 @@ def _insert_quantized_linear(
 
     custom_args = quantization_impl.build_custom_args_for_linear(scales)
 
-    node.target = quantization_impl.custom_op()
+    node.target = quantization_impl.target_op()
     node.args = (*node.args, *custom_args)
 
 
@@ -195,7 +195,6 @@ def _apply(
     impl = QuantizationImpl.create(quant_algo, is_bmm=False)
 
     for n in gm.graph.nodes:
-        # Only consider linear ops; skip if excluded
         if not is_linear_op(n, include_quantization=False):
             continue
         if should_skip_quantization(n, excluded):
@@ -236,7 +235,6 @@ def _apply(
     for n in gm.graph.nodes:
         if not is_bmm_op(n):
             continue
-        # Reuse common exclusion rule (supports Node or param-name string)
         if should_skip_quantization(n, excluded):
             continue
 
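
Editor's note: the `custom_op()` to `target_op()` rename does not change the rewrite itself; the FX node is simply pointed at the quantized custom op and the extra scale arguments are appended. A minimal standalone sketch of that step (the helper name and the lint/recompile calls are illustrative, not from this file):

    import torch.fx as fx

    def retarget_to_quantized(gm: fx.GraphModule, node: fx.Node, target_op, custom_args: tuple) -> None:
        # same two-line rewrite as _insert_quantized_linear in the diff above
        node.target = target_op
        node.args = (*node.args, *custom_args)
        gm.graph.lint()  # sanity-check the mutated graph
        gm.recompile()   # regenerate forward() so the new target takes effect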

tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py

Lines changed: 3 additions & 11 deletions
@@ -34,17 +34,10 @@
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 
 
+# Copied from torch._dynamo.utils.detect_fake_mode but skip the same FakeMode assertion
+# In our use case, FakeMode of the inserted replacement pattern is different from the original
+# FakeMode from graph, which breaks this assertion
 def ad_detect_fake_mode(inputs: Any = None):
-    """
-    Attempts to "detect" what the current fake mode is. If there is one ambiently
-    available from TracingContext, we preferentially use that. Otherwise, we
-    heuristically detect the fake mode via the following sources, in order of
-    priority:
-
-    - Currently active fake mode on stack
-    - Fake mode associated with passed in tensors (inputs does not
-      have to be flattened)
-    """
     from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
 
     fake_modes = []
@@ -72,7 +65,6 @@ def ad_detect_fake_mode(inputs: Any = None):
     return None
 
 
-# Replace the function used as a context manager
 torch._dynamo.utils.detect_fake_mode = ad_detect_fake_mode
 
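
Editor's note: per the new comments, the upstream torch._dynamo.utils.detect_fake_mode gathers FakeTensorMode instances from ambient tracing state and from the (flattened) inputs and asserts they all agree; ad_detect_fake_mode drops that assertion because the replacement pattern's FakeMode can differ from the graph's. A simplified sketch of the input-scanning part only, using the public FakeTensor API (the actual function body is not shown in this diff):

    from typing import Any, Optional

    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
    from torch.utils._pytree import tree_flatten

    def detect_fake_mode_sketch(inputs: Any = None) -> Optional[FakeTensorMode]:
        flat_inputs, _ = tree_flatten(inputs)
        fake_modes = [t.fake_mode for t in flat_inputs if isinstance(t, FakeTensor)]
        # unlike the upstream helper, do not assert that every collected mode is identical;
        # just return the first one found (or None when no FakeTensor is present)
        return fake_modes[0] if fake_modes else None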
