Commit 2c0b4f0

update quantize_linear_from_config to point to the custom op

Signed-off-by: Frida Hou <[email protected]>
1 parent f8ca1fb commit 2c0b4f0

2 files changed: +60 −4 lines changed
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
Lines changed: 7 additions & 4 deletions

@@ -29,7 +29,7 @@ def _insert_quantized_linear(
     quantization_impl: QuantizationImpl,
     is_quantized_graph: bool = False,
 ):
-    """Replaces the matmul node with a new quantized matmul node.
+    """Replaces the matmul node with a new custom quantized linear node.
 
     The state_dict is also updated to contain the sharded weights.
     """
@@ -72,14 +72,17 @@ def _insert_quantized_linear(
         partial(quantization_impl.load_hook, weight_name=param_name)
     )
 
-    node.target = quantization_impl.target_op()
-
     with gm.graph.inserting_before(node):
         scales = {}
         for scale_name in quantization_impl.scale_names():
             scales[scale_name] = gm.graph.create_node("get_attr", modname + "." + scale_name)
 
-        node.kwargs = {**node.kwargs, **scales}
+        custom_kwargs = quantization_impl.build_custom_kwargs_for_linear(
+            scales,
+        )
+
+        node.target = quantization_impl.custom_op()
+        node.kwargs = {**node.kwargs, **custom_kwargs}
 
 
 def _insert_quantized_bmm(
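
The net effect of this hunk is that the linear node is retargeted to the unified custom op and the raw scale nodes are routed through the impl-specific kwargs builder. For orientation, here is a minimal sketch of the kwargs the rewired node picks up per format, mirroring the build_custom_kwargs_for_linear overrides added in quantization_utils.py below; the "scales" entries stand in for the get_attr nodes built in the loop above, and this is illustrative, not code from the commit:

# Sketch only: extra kwargs attached to the rewired custom_quant_linear node.
# "scales" stands in for the dict of get_attr nodes created per scale_name above.
scales = {"input_scale": ..., "weight_scale": ..., "alpha": ...}  # placeholders

fp8_kwargs = dict(          # FP8 impl: one input scale, one weight scale
    input_scale=[scales["input_scale"]],
    weight_scale=[scales["weight_scale"]],
    input_zp=[],
    weight_zp=[],
    format_type=0,          # FORMAT_FP8
)

nvfp4_kwargs = dict(        # NVFP4 impl: cutlass-format weight scale plus fused alpha
    input_scale=[scales["input_scale"]],
    weight_scale=[scales["weight_scale"], scales["alpha"]],
    input_zp=[],
    weight_zp=[],
    format_type=1,          # FORMAT_NVFP4
)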

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
Lines changed: 53 additions & 0 deletions

@@ -27,6 +27,10 @@
 except ImportError:
     float4_sf_dtype = None
 
+# TODO: put the ENUMs in the same place and import it
+FORMAT_FP8 = 0
+FORMAT_NVFP4 = 1
+
 
 def modelopt_fp4_scale_to_cutlass_fp4_scale(modelopt_scale: torch.Tensor) -> torch.Tensor:
     """Converts the modelopt FP4 per-block weight scale to the cutlass format (padded and swizzled)."""
@@ -160,6 +164,18 @@ def shard_load_hook(
     def fuse_linear_weights(weights, **kwargs) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         pass
 
+    @staticmethod
+    def custom_op():
+        """Unified custom kernel entry-point for quantized linear."""
+        return torch.ops.auto_deploy.custom_quant_linear
+
+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        """Default: no extra kwargs. Each impl overrides to pass the right inputs/scales/zps/format."""
+        return {}
+
 
 class FP8QuantizationImpl(QuantizationImpl):
     @staticmethod
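
These two base-class hooks are the extension point introduced by this commit: custom_op() names the shared kernel entry point and build_custom_kwargs_for_linear() maps an impl's scale get_attr nodes into that op's kwargs. A rough sketch of how another implementation could plug in; Int8QuantizationImpl, FORMAT_INT8, and the scale names here are hypothetical and not part of this commit:

from typing import Dict, List

from torch.fx import Node

FORMAT_INT8 = 2  # hypothetical format tag, not defined in this commit


class Int8QuantizationImpl(QuantizationImpl):  # assumes the base class defined in this file
    @staticmethod
    def scale_names() -> List[str]:
        return ["input_scale", "weight_scale"]

    @staticmethod
    def build_custom_kwargs_for_linear(
        scale_getattrs: Dict[str, Node],
    ) -> Dict[str, object]:
        # Same kwargs contract as the FP8/NVFP4 overrides in this file.
        return dict(
            input_scale=[scale_getattrs["input_scale"]],
            weight_scale=[scale_getattrs["weight_scale"]],
            input_zp=[],
            weight_zp=[],
            format_type=FORMAT_INT8,
        )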
@@ -180,6 +196,20 @@ def scale_names() -> List[str]:
     def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
         return {"input_scale": torch.tensor(1.0), "weight_scale": torch.tensor(1.0)}
 
+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        # FP8 custom op contract:
+        # input_scale=[tensor], weight_scale=[tensor], input_zp=[], weight_zp=[], format_type=FORMAT_FP8
+        return dict(
+            input_scale=[scale_getattrs["input_scale"]],
+            weight_scale=[scale_getattrs["weight_scale"]],
+            input_zp=[],
+            weight_zp=[],
+            format_type=FORMAT_FP8,
+        )
+
     @staticmethod
     def load_hook(state_dict, prefix, *args, weight_name):
         if weight_name in state_dict:
@@ -264,6 +294,29 @@ def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
             "alpha": torch.tensor(1.0 / 6.0),
         }
 
+    @staticmethod
+    def build_custom_kwargs_for_linear(
+        scale_getattrs: Dict[str, Node],
+    ) -> Dict[str, object]:
+        """
+        Contract:
+            custom_quant_linear(
+                x, Wq, bias,
+                input_scale=[s_in2],
+                weight_scale=[weight_scale_cutlass_uint8, alpha_fused],
+                input_zp=[],
+                weight_zp=[],
+                format_type=FORMAT_NVFP4
+            )
+        """
+        return dict(
+            input_scale=[scale_getattrs["input_scale"]],
+            weight_scale=[scale_getattrs["weight_scale"], scale_getattrs["alpha"]],
+            input_zp=[],
+            weight_zp=[],
+            format_type=FORMAT_NVFP4,
+        )
+
     @staticmethod
     def load_hook(state_dict, prefix, *args, weight_name):
         if weight_name in state_dict:
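
To sanity-check what the transform produced, one can scan the resulting graph for nodes retargeted to the unified op and read the format tag off their kwargs. A small sketch, assuming TensorRT-LLM's auto_deploy custom ops are registered and that gm is the GraphModule coming out of the transform pipeline:

import torch
from torch.fx import GraphModule


def summarize_quant_linears(gm: GraphModule) -> None:
    """Print every rewired quantized linear node and the format it was tagged with."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.auto_deploy.custom_quant_linear:
            fmt = node.kwargs.get("format_type")
            n_weight_scales = len(node.kwargs.get("weight_scale", []))
            # Per this commit: FORMAT_FP8 (0) carries one weight scale,
            # FORMAT_NVFP4 (1) carries two (cutlass block scale + fused alpha).
            print(f"{node.name}: format_type={fmt}, weight_scales={n_weight_scales}")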
