Commit f2c609b

split linear and bmm quantization
Signed-off-by: Frida Hou <[email protected]>
1 parent 6dc658c commit f2c609b

File tree

3 files changed (+62, -29 lines)

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 3 additions & 1 deletion
@@ -39,7 +39,9 @@ transforms:
   # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
   optimize_rope:
     stage: pattern_matcher
-  quantize_from_config:
+  quantize_linear_from_config:
+    stage: pattern_matcher
+  quantize_bmm_from_config:
     stage: pattern_matcher
   quantize_from_graph:
     stage: pattern_matcher
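
With the split, each op kind gets its own pipeline entry, so one pass can be enabled without the other. A minimal sketch of the equivalent Python-side transform config (a hypothetical dict form, mirroring how the unit tests below pass configs to InferenceOptimizer):

transform_config = {
    "quantize_linear_from_config": {"stage": "pattern_matcher"},
    "quantize_bmm_from_config": {"stage": "pattern_matcher"},
}
# Drop either key to disable just that pass, e.g. omit
# "quantize_bmm_from_config" to leave BMM ops unquantized.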

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py

Lines changed: 57 additions & 26 deletions
@@ -87,8 +87,12 @@ def _insert_quantized_bmm(
     node: Node,
     quantization_impl: QuantizationImpl,
     is_quantized_graph: bool = False,
-):
-    """Replaces the bmm node with a new quantized bmm node."""
+) -> bool:
+    """Replace a bmm op with its quantized equivalent and wire scales/state_dict hooks.
+
+    Returns:
+        True if quantization was applied; False if skipped (e.g., unknown shape).
+    """
     weight_node = node.args[1]
 
     # Weight is a parameter
@@ -140,7 +144,7 @@ def get_scale_name(scale_name):
 
     else:
         # If we can't determine the shape, skip quantization
-        return
+        return False
 
     # Common logic for both parameter and dynamic tensor cases
     # Register scales in the target module
@@ -163,16 +167,12 @@ def get_scale_name(scale_name):
     # Update node arguments and kwargs
     scale_values = [scales[scale_name] for scale_name in quantization_impl.scale_names()]
     node.args = (*node.args, *scale_values)
+    return True
 
 
-@TransformRegistry.register("quantize_from_config")
-class QuantizationFromConfig(BaseTransform):
-    """
-    Quantize linear and BMM ops using a quantization config.
-
-    Replaces eligible ops with quantized equivalents based on the quantization algorithm
-    and exclude patterns defined in the config.
-    """
+@TransformRegistry.register("quantize_linear_from_config")
+class LinearQuantizationFromConfig(BaseTransform):
+    """Quantize eligible linear ops per quant config (algo + exclude patterns)."""
 
     def _apply(
         self,
@@ -182,38 +182,69 @@ def _apply(
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
         quant_config = factory.get_quant_config()
-        if not quant_config:
+        if not quant_config or not quant_config.get("quant_algo"):
             return gm, TransformInfo(
                 skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
             )
-        quant_algo = quant_config.get("quant_algo", None)
-        excluded_patterns = quant_config.get("exclude_modules", [])
-        if not quant_algo:
+
+        quant_algo = quant_config["quant_algo"]
+        excluded = quant_config.get("exclude_modules", [])
+
+        num_matches = 0
+        impl = QuantizationImpl.create(quant_algo, is_bmm=False)
+
+        for n in gm.graph.nodes:
+            # Only consider linear ops; skip if excluded
+            if not is_linear_op(n, include_quantization=False):
+                continue
+            if should_skip_quantization(n, excluded):
+                continue
+
+            _insert_quantized_linear(gm, n, impl, is_quantized_graph=False)
+            num_matches += 1
+
+        info = TransformInfo(
+            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
+        )
+        return gm, info
+
+
+@TransformRegistry.register("quantize_bmm_from_config")
+class BMMQuantizationFromConfig(BaseTransform):
+    """Quantize eligible BMM ops per quant config (algo + exclude patterns)."""
+
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        quant_config = factory.get_quant_config()
+        if not quant_config or not quant_config.get("quant_algo"):
             return gm, TransformInfo(
                 skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
             )
 
+        quant_algo = quant_config["quant_algo"]
+        excluded = quant_config.get("exclude_modules", [])
+
         num_matches = 0
+        impl = QuantizationImpl.create(quant_algo, is_bmm=True)
 
         for n in gm.graph.nodes:
-            if should_skip_quantization(n, excluded_patterns):
+            if not is_bmm_op(n):
+                continue
+            # Reuse common exclusion rule (supports Node or param-name string)
+            if should_skip_quantization(n, excluded):
                 continue
 
-            if is_linear_op(n, include_quantization=False):
-                impl = QuantizationImpl.create(quant_algo, is_bmm=False)
-                _insert_quantized_linear(gm, n, impl, False)
-                num_matches += 1
-
-            # TODO: Make _insert_quantized_bmm return a bool and increment only on success
-            elif is_bmm_op(n):
-                impl = QuantizationImpl.create(quant_algo, is_bmm=True)
-                _insert_quantized_bmm(gm, n, impl, False)
+            if _insert_quantized_bmm(gm, n, impl, is_quantized_graph=False):
                 num_matches += 1
 
         info = TransformInfo(
             skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
        )
-
         return gm, info
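Both transforms route exclusions through the shared should_skip_quantization helper. For intuition only, here is a hedged sketch of what pattern-based exclusion typically looks like (_matches_exclude is a hypothetical stand-in, assuming fnmatch-style globs in exclude_modules; the real helper also accepts graph Nodes and its exact logic is not part of this diff):

import fnmatch
from typing import List

def _matches_exclude(target_name: str, excluded: List[str]) -> bool:
    # Hypothetical sketch: True if a module/param name matches any
    # exclude pattern, e.g. "*lm_head*" or "model.layers.0.*".
    return any(fnmatch.fnmatch(target_name, pattern) for pattern in excluded)

# _matches_exclude("model.lm_head.weight", ["*lm_head*"])        -> True
# _matches_exclude("model.layers.1.mlp.up_proj", ["*lm_head*"])  -> False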

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py

Lines changed: 2 additions & 2 deletions
@@ -73,7 +73,7 @@ def test_quantization(quant_config, atol, rtol, num_p_og):
     gm_transformed = InferenceOptimizer(
         DummyFactory(quant_config),
         {
-            "quantize_from_config": {
+            "quantize_linear_from_config": {
                 "stage": "pattern_matcher",
             },
         },
@@ -155,7 +155,7 @@ def test_bmm_quantization(quant_config, atol, rtol, num_p_og, model_class):
     gm_transformed = InferenceOptimizer(
         DummyFactory(quant_config),
         {
-            "quantize_from_config": {
+            "quantize_bmm_from_config": {
                 "stage": "pattern_matcher",
             },
         },
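
The two tests above exercise each pass in isolation. A sketch of running both passes together, as default.yaml now wires them (same InferenceOptimizer/DummyFactory harness as the tests; the combined dict is an assumption, not part of this commit):

gm_transformed = InferenceOptimizer(
    DummyFactory(quant_config),
    {
        "quantize_linear_from_config": {"stage": "pattern_matcher"},
        "quantize_bmm_from_config": {"stage": "pattern_matcher"},
    },
)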
