Commit 97e8125

rename custom quant ops

Signed-off-by: Frida Hou <[email protected]>

rename torch custom op

Signed-off-by: Frida Hou <[email protected]>

1 parent 797f717 commit 97e8125

File tree

5 files changed (+17 lines, -19 lines)

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py

Lines changed: 8 additions & 8 deletions
@@ -162,8 +162,8 @@ def _dequantize_nvfp4(
     return vals.view(N, K).to(orig_dtype)


-@torch.library.custom_op("auto_deploy::torch_quant_linear_fp8", mutates_args=())
-def torch_quant_linear_fp8(
+@torch.library.custom_op("auto_deploy::torch_fake_quant_fp8_linear", mutates_args=())
+def torch_fake_quant_fp8_linear(
     input: torch.Tensor,
     weight_quantized: torch.Tensor,
     bias: torch.Tensor,  # Optional, no default
@@ -198,8 +198,8 @@ def torch_quant_linear_fp8(
     return out.reshape(*input.shape[:-1], out_features)


-@torch_quant_linear_fp8.register_fake
-def torch_quant_linear_fp8(
+@torch_fake_quant_fp8_linear.register_fake
+def torch_fake_quant_fp8_linear(
     input: torch.Tensor,
     weight_quantized: torch.Tensor,
     bias: torch.Tensor,
@@ -212,8 +212,8 @@ def torch_quant_linear_fp8(
     return torch.ops.aten.linear(input, w, bias)


-@torch.library.custom_op("auto_deploy::torch_quant_linear_fp4", mutates_args=())
-def torch_quant_linear_fp4(
+@torch.library.custom_op("auto_deploy::torch_fake_quant_fp4_linear", mutates_args=())
+def torch_fake_quant_fp4_linear(
     input: torch.Tensor,
     weight_quantized: torch.Tensor,
     bias: torch.Tensor,  # Optional, no default
@@ -274,8 +274,8 @@ def torch_quant_linear_fp4(
     return out_2d.reshape(*input_shape[:-1], N)


-@torch_quant_linear_fp4.register_fake
-def torch_quant_linear_fp4(
+@torch_fake_quant_fp4_linear.register_fake
+def torch_fake_quant_fp4_linear(
     input: torch.Tensor,
     weight_quantized: torch.Tensor,
     bias: torch.Tensor,
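
For context on the renamed entry points, here is a minimal hedged sketch of calling the FP8 op after this change, mirroring the updated call sites in test_quant.py further down. The tensor shapes, dtypes, and unit scales are illustrative assumptions, and the snippet assumes that importing the auto_deploy custom_ops package registers the ops under torch.ops.auto_deploy.

import torch

# Assumption: importing the package registers the custom ops (see the file path above).
import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401

# Illustrative tensors only; real shapes, dtypes, and scales depend on the model.
inp = torch.randn(4, 64, dtype=torch.half, device="cuda")
weight_fp8 = torch.randn(128, 64, device="cuda").to(torch.float8_e4m3fn)
bias = torch.randn(128, dtype=torch.half, device="cuda")
input_scale = torch.tensor(1.0, device="cuda")
weight_scale = torch.tensor(1.0, device="cuda")

# Positional layout mirrors the test below:
# (input, weight_quantized, bias, input_scale list, weight_scale list, input_zp list, weight_zp list)
out = torch.ops.auto_deploy.torch_fake_quant_fp8_linear(
    inp, weight_fp8, bias, [input_scale], [weight_scale], [], []
)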

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py

Lines changed: 3 additions & 3 deletions
@@ -102,7 +102,7 @@ class QuantizationFusionMixin:
     that share the same input activation (parent node).

     Subclasses must define:
-    - target_op: the torch op identifying the quantized linear (e.g., torch.ops.auto_deploy.torch_quant_linear_fp8)
+    - target_op: the torch op identifying the quantized linear
     - scale_groups: List[List[str]] describing how kwargs should be grouped, e.g.
         FP8 -> [["input_scale"], ["weight_scale"]]
         FP4 -> [["input_scale"], ["weight_scale", "alpha"]]
@@ -260,7 +260,7 @@ def _apply(

 @TransformRegistry.register("fuse_fp8_gemms")
 class FuseFP8Gemms(QuantizationFusionMixin, BaseTransform):
-    target_op = torch.ops.auto_deploy.torch_quant_linear_fp8
+    target_op = torch.ops.auto_deploy.torch_fake_quant_fp8_linear
     scale_groups = [["input_scale"], ["weight_scale"]]

     def fuse_rule(
@@ -298,7 +298,7 @@ def _apply(

 @TransformRegistry.register("fuse_fp4_gemms")
 class FuseFP4Gemms(QuantizationFusionMixin, BaseTransform):
-    target_op = torch.ops.auto_deploy.torch_quant_linear_fp4
+    target_op = torch.ops.auto_deploy.torch_fake_quant_fp4_linear
     scale_groups = [["input_scale"], ["weight_scale", "alpha"]]

     def fuse_rule(
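
As a hedged illustration of the mixin contract described in the docstring above: a hypothetical subclass only needs to point target_op at one of the renamed ops and declare its scale_groups. The sketch assumes it lives in fusion.py alongside FuseFP8Gemms/FuseFP4Gemms (so torch, TransformRegistry, BaseTransform, and QuantizationFusionMixin are already in scope); the class name and registry key are made-up placeholders, and fuse_rule is stubbed because its full signature is not shown in this diff.

@TransformRegistry.register("fuse_my_quant_gemms")  # placeholder registry key
class FuseMyQuantGemms(QuantizationFusionMixin, BaseTransform):
    # The renamed fake-quant linear op identifies which quantized GEMMs to fuse.
    target_op = torch.ops.auto_deploy.torch_fake_quant_fp8_linear
    # One kwarg group per scale kind, following the docstring's convention.
    scale_groups = [["input_scale"], ["weight_scale"]]

    def fuse_rule(self, *args, **kwargs):
        # Placeholder: the real fuse_rule signature and body are elided in this diff.
        raise NotImplementedError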

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 2 additions & 2 deletions
@@ -249,8 +249,8 @@ def is_linear_op(node: Node, include_quantization: bool = False) -> bool:

     if include_quantization:
         lin_ops.update(QUANT_LINEAR_OPS)
-        lin_ops.update([torch.ops.auto_deploy.torch_quant_linear_fp8])
-        lin_ops.update([torch.ops.auto_deploy.torch_quant_linear_fp4])
+        lin_ops.update([torch.ops.auto_deploy.torch_fake_quant_fp8_linear])
+        lin_ops.update([torch.ops.auto_deploy.torch_fake_quant_fp4_linear])
     return is_op(node, lin_ops)

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py

Lines changed: 2 additions & 2 deletions
@@ -155,7 +155,7 @@ def target_op():
     @staticmethod
     def custom_op():
         """Unified custom kernel entry-point for quantized linear."""
-        return torch.ops.auto_deploy.torch_quant_linear_fp8
+        return torch.ops.auto_deploy.torch_fake_quant_fp8_linear

     @staticmethod
     def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor:
@@ -211,7 +211,7 @@ def target_op():
     @staticmethod
     def custom_op():
         """Unified custom kernel entry-point for quantized linear."""
-        return torch.ops.auto_deploy.torch_quant_linear_fp4
+        return torch.ops.auto_deploy.torch_fake_quant_fp4_linear

     @staticmethod
     def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor:
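
A small hedged sketch of the pattern these two hunks touch: each quantization implementation exposes the unified entry point through a custom_op() staticmethod, so callers resolve the kernel without hard-coding op names. MyFP8Impl is a hypothetical stand-in (the real class names are not visible in this diff), and the snippet assumes the auto_deploy ops are already registered.

import torch


class MyFP8Impl:  # hypothetical stand-in for the FP8 implementation class
    @staticmethod
    def custom_op():
        """Unified custom kernel entry-point for quantized linear."""
        return torch.ops.auto_deploy.torch_fake_quant_fp8_linear


# Callers dispatch through the staticmethod instead of naming the op directly:
op = MyFP8Impl.custom_op()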

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py

Lines changed: 2 additions & 4 deletions
@@ -128,15 +128,14 @@ def test_quant_linear_fp8_matches_fused_op(bias):
         weight_scale=weight_scale,
     )

-    out_unified = torch.ops.auto_deploy.custom_quant_linear(
+    out_unified = torch.ops.auto_deploy.torch_fake_quant_fp8_linear(
         input,
         weight_fp8,
         bias,
         [torch.tensor(1.0, device="cuda")],
         [weight_scale],
         [],
         [],
-        format_type=FORMAT_FP8,
     )

     assert out_unified.shape == out_fused.shape
@@ -184,7 +183,7 @@ def test_quant_linear_nvfp4_matches_fused_op(bias):
         alpha=alpha_fused,
     )

-    out_unified = torch.ops.auto_deploy.custom_quant_linear(
+    out_unified = torch.ops.auto_deploy.torch_fake_quant_fp4_linear(
         x,
         weight_fp4,
         bias,
@@ -195,7 +194,6 @@ def test_quant_linear_nvfp4_matches_fused_op(bias):
         ],  # weight_scale list: [per-block vector, combined alpha]
         [],  # input_zp
         [],  # weight_zp
-        format_type=FORMAT_NVFP4,
     )

     assert out_unified.shape == out_fused.shape
