
Commit cf592e3

update quantize_linear_from_config to point to the custom op

Signed-off-by: Frida Hou <[email protected]>

minor

Signed-off-by: Frida Hou <[email protected]>

1 parent: f2c609b

2 files changed (+61, -4 lines)

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py

Lines changed: 7 additions & 4 deletions

@@ -29,7 +29,7 @@ def _insert_quantized_linear(
     quantization_impl: QuantizationImpl,
     is_quantized_graph: bool = False,
 ):
-    """Replaces the matmul node with a new quantized matmul node.
+    """Replaces the matmul node with a new custom quantized linear node.
 
     The state_dict is also updated to contain the sharded weights.
     """
@@ -72,14 +72,17 @@ def _insert_quantized_linear(
         partial(quantization_impl.load_hook, weight_name=param_name)
     )
 
-    node.target = quantization_impl.target_op()
-
     with gm.graph.inserting_before(node):
         scales = {}
         for scale_name in quantization_impl.scale_names():
             scales[scale_name] = gm.graph.create_node("get_attr", modname + "." + scale_name)
 
-        node.kwargs = {**node.kwargs, **scales}
+        custom_kwargs = quantization_impl.build_custom_kwargs_for_linear(
+            scales,
+        )
+
+        node.target = quantization_impl.custom_op()
+        node.kwargs = {**node.kwargs, **custom_kwargs}
 
 
 def _insert_quantized_bmm(
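Net effect of this file's change: rather than retargeting the linear node to a per-format target op and splatting the raw scales into its kwargs, the transform now asks the quantization impl to build the kwargs that the unified custom op expects, then retargets the node to that op. As a rough illustration only (not code from this commit), the rewired node in the exported graph corresponds to a call of the following shape on the FP8 path; x, weight_fp8, bias, and the scale names are stand-ins for the actual graph values, and the op only resolves once the auto_deploy custom-op library is imported:

    # Illustrative shape of the rewired node on the FP8 path; not part of this diff.
    out = torch.ops.auto_deploy.custom_quant_linear(
        x,                            # activation input (original node arg)
        weight_fp8,                   # quantized weight parameter (get_attr)
        bias,                         # optional bias (original node arg)
        input_scale=[input_scale],    # single per-tensor activation scale
        weight_scale=[weight_scale],  # single per-tensor weight scale
        input_zp=[],                  # FP8 carries no zero points
        weight_zp=[],
        format_type=FORMAT_FP8,       # dispatch key consumed by the unified entry point
    )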

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py

Lines changed: 54 additions & 0 deletions

@@ -27,6 +27,10 @@
 except ImportError:
     float4_sf_dtype = None
 
+# TODO: put the ENUMs in the same place and import it
+FORMAT_FP8 = 0
+FORMAT_NVFP4 = 1
+
 
 def modelopt_fp4_scale_to_cutlass_fp4_scale(modelopt_scale: torch.Tensor) -> torch.Tensor:
     """Converts the modelopt FP4 per-block weight scale to the cutlass format (padded and swizzled)."""
@@ -81,6 +85,7 @@ def create(quant_type_or_node: Union[str, Node], is_bmm: bool = False):
            quantization_impl_map = {
                "": None,
                "FP8": FP8BMMQuantizationImpl,
+               "NVFP4": None,  # BMM NVFP4 is not supported yet
            }
        else:
            quantization_impl_map = {
@@ -160,6 +165,18 @@ def shard_load_hook(
    def fuse_linear_weights(weights, **kwargs) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        pass
 
+    @staticmethod
+    def custom_op():
+        """Unified custom kernel entry-point for quantized linear."""
+        return torch.ops.auto_deploy.custom_quant_linear
+
+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        """Default: no extra kwargs. Each impl overrides to pass the right inputs/scales/zps/format."""
+        return {}
+
 
 class FP8QuantizationImpl(QuantizationImpl):
     @staticmethod
@@ -180,6 +197,20 @@ def scale_names() -> List[str]:
     def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
         return {"input_scale": torch.tensor(1.0), "weight_scale": torch.tensor(1.0)}
 
+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        # FP8 custom op contract:
+        #   input_scale=[tensor], weight_scale=[tensor], input_zp=[], weight_zp=[], format_type=FORMAT_FP8
+        return dict(
+            input_scale=[scale_getattrs["input_scale"]],
+            weight_scale=[scale_getattrs["weight_scale"]],
+            input_zp=[],
+            weight_zp=[],
+            format_type=FORMAT_FP8,
+        )
+
     @staticmethod
     def load_hook(state_dict, prefix, *args, weight_name):
         if weight_name in state_dict:
@@ -264,6 +295,29 @@ def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
             "alpha": torch.tensor(1.0 / 6.0),
         }
 
+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        """
+        Contract:
+            custom_quant_linear(
+                x, Wq, bias,
+                input_scale=[s_in2],
+                weight_scale=[weight_scale_cutlass_uint8, alpha_fused],
+                input_zp=[],
+                weight_zp=[],
+                format_type=FORMAT_NVFP4
+            )
+        """
+        return dict(
+            input_scale=[scale_getattrs["input_scale"]],
+            weight_scale=[scale_getattrs["weight_scale"], scale_getattrs["alpha"]],
+            input_zp=[],
+            weight_zp=[],
+            format_type=FORMAT_NVFP4,
+        )
+
     @staticmethod
     def load_hook(state_dict, prefix, *args, weight_name):
         if weight_name in state_dict:
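The kernel behind torch.ops.auto_deploy.custom_quant_linear is not part of this commit; only the calling convention encoded by the two overrides above is. As a mental model of that convention (an assumption about the consumer side, not the actual kernel code), whatever implements the op can route on format_type and unpack the scale lists as sketched below; the helper name is hypothetical, and FORMAT_FP8 / FORMAT_NVFP4 refer to the module-level constants added above:

    # Hypothetical consumer-side sketch of the calling convention; the real kernel
    # is registered elsewhere in auto_deploy and is not shown in this commit.
    def _unpack_quant_linear_scales(input_scale, weight_scale, input_zp, weight_zp, format_type):
        if format_type == FORMAT_FP8:
            # FP8: one per-tensor activation scale, one per-tensor weight scale, no zero points.
            (s_in,), (s_w,) = input_scale, weight_scale
            return {"input_scale": s_in, "weight_scale": s_w}
        if format_type == FORMAT_NVFP4:
            # NVFP4: weight_scale carries the cutlass-formatted per-block scale followed by
            # the fused alpha; input_scale holds the second-level input scale.
            (s_in2,), (w_scale_cutlass, alpha) = input_scale, weight_scale
            return {"input_scale": s_in2, "weight_scale": w_scale_cutlass, "alpha": alpha}
        raise ValueError(f"Unsupported format_type: {format_type}")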
