[#5861][autodeploy] Quantization Transforms with Inheritance #7227
Conversation
Commits (signed off by Fridah-nv / Frida Hou): minor fixes; …ationImpl; rename torch custom op; update fusion to pass in args instead of kwargs.
📝 Walkthrough: Splits config-driven quantization into separate linear and BMM transforms; adds FP8 and NVFP4 fake-quant custom ops and pattern-matcher fusion; introduces a quantization-aware fusion framework and sharding mixins (TPShardingInfo.from_node); and adjusts fake-mode detection, quantization contracts, configs, and tests.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
participant Config as default.yaml
participant Optim as InferenceOptimizer
participant PM as PatternMatcher
participant QT as QuantizeTransforms
participant FU as FusionTransforms
participant FQ as FuseQuant
Config->>Optim: load transforms (pattern_matcher, post_load_fusion)
Optim->>PM: run quantize_linear_from_config
PM->>QT: insert custom FP8/FP4 linear ops (custom_op tails)
Optim->>PM: run quantize_bmm_from_config
PM->>QT: insert BMM quant (if supported)
Optim->>FU: post_load_fusion (fuse_gemms,fuse_fp8_gemms,fuse_fp4_gemms)
FU->>FU: group by shared activation and insert fused quant GEMMs
Optim->>FQ: post_load_fusion fuse_quant
FQ->>FQ: pattern-match reference fake-quant -> fused quant ops
```

```mermaid
sequenceDiagram
autonumber
participant GM as GraphModule
participant Mix as QuantizationFusionMixin
participant Impl as FuseFP8/FuseFP4
participant Attr as get_attr buffers
GM->>Mix: locate linear nodes with same parent
Mix->>Mix: check_same_children(all target_op?)
alt all match
Mix->>Impl: fuse_rule(weights, scales)
Impl-->>Mix: fused_weight + fused_buffers
Mix->>GM: register fused param and buffers
GM->>Attr: create get_attr nodes for scales
Mix->>GM: insert fused quant op with tail args
else mismatch
Mix-->>GM: skip group
end
```

```mermaid
sequenceDiagram
autonumber
participant TP as TPShardingInfo.from_node
participant Core as _insert_sharded_matmul
participant QMixin as QuantizationShardingMixin.quantization_cb
TP->>Core: apply(node) with quantization_cb
Core->>QMixin: delegate quantization handling
QMixin->>Core: gather buffers, compute sharded scales, register, add load hook
Core-->>TP: sharded matmul wired with quantization-aware buffers
```

Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (6)
tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (1)
1-1
: Add NVIDIA Apache-2.0 copyright header (required by repo guidelines). All Python sources must carry the NVIDIA Apache-2.0 header. Please prepend it here.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (1)
1-1
: Add NVIDIA Apache-2.0 header (tests are Python files too). Please add the standard header at the top.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (1)
265-283
: Recompile the GraphModule after pattern rewrites.
ADPatternMatcherPass.apply mutates gm.graph, but the Python code of the GraphModule isn't regenerated automatically. Recompiling avoids stale code and ensures subsequent passes (and any eager execution) see the updated graph.
```diff
- num_matches = patterns.apply(gm.graph)
+ num_matches = patterns.apply(gm.graph)
+ if num_matches:
+     gm.graph.eliminate_dead_code()
+     gm.recompile()
```
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (3)
1-1
: Add NVIDIA Apache-2.0 header.+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
53-55
: Fix potential IndexError when reading node.users.
user = list(node.users.keys())[0] will throw if there are zero users. Guard the access.
```diff
- user = list(node.users.keys())[0]
- if len(node.users) == 1 and is_quantized_op(user):
-     user.replace_all_uses_with(node)
+ if len(node.users) == 1:
+     sole_user = next(iter(node.users))
+     if is_quantized_op(sole_user):
+         sole_user.replace_all_uses_with(node)
```
174-211
: Recompile and set the skipped flag accurately in LinearQuantizationFromConfig.
- Set skipped based on num_matches == 0.
- Recompile after mutations to materialize changes in GraphModule code.
```diff
- info = TransformInfo(
-     skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
- )
- return gm, info
+ if num_matches:
+     gm.graph.eliminate_dead_code()
+     gm.recompile()
+ return gm, TransformInfo(
+     skipped=(num_matches == 0),
+     num_matches=num_matches,
+     is_clean=False,
+     has_valid_shapes=True,
+ )
```
🧹 Nitpick comments (21)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)
252-254
: Make linear-op detection robust when custom ops are unavailable. Directly accessing torch.ops.auto_deploy.torch_fake_quant_fp{8,4}_linear can raise AttributeError on builds where these ops aren't registered. Guard the additions so import-time doesn't fail on older runtimes. Apply this diff:
```diff
     if include_quantization:
         lin_ops.update(QUANT_LINEAR_OPS)
-        lin_ops.update([torch.ops.auto_deploy.torch_fake_quant_fp8_linear])
-        lin_ops.update([torch.ops.auto_deploy.torch_fake_quant_fp4_linear])
+        for _name in ("torch_fake_quant_fp8_linear", "torch_fake_quant_fp4_linear"):
+            _op = getattr(torch.ops.auto_deploy, _name, None)
+            if _op is not None:
+                lin_ops.add(_op)
```
1-1
: Missing NVIDIA Apache-2.0 headerPer coding guidelines, prepend the NVIDIA Apache-2.0 copyright header (year 2025) to this file.
Do you want me to apply a repo-standard header template across all changed Python files?
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (3)
37-73
: Fake-mode detection: good coverage; minor resilience nits. The priority (TracingContext → active dispatch stack → inputs) looks correct and fixes fragility in upstream detection. Two small nits:
- If inputs is None, pytree.tree_leaves(None) returns [None]; harmless, but you can avoid a tiny branch cost.
- If multiple FakeTensorModes are stacked, you're already iterating the reversed stack (most recent first) — nice.
Optional micro-tidy:
```diff
- flat_inputs = pytree.tree_leaves(inputs)
+ flat_inputs = () if inputs is None else pytree.tree_leaves(inputs)
```
75-77
: Monkeypatch scope: limit to the pattern-matcher pass. Replacing torch._dynamo.utils.detect_fake_mode globally can have side effects in unrelated compilation paths. Patch only during ADPatternMatcherPass.apply() and restore afterward. Apply this diff:
```diff
-# Replace the function used as a context manager
-torch._dynamo.utils.detect_fake_mode = ad_detect_fake_mode
+_ORIG_DETECT_FAKE_MODE = getattr(torch._dynamo.utils, "detect_fake_mode", None)
```
And update the pass:
```diff
 class ADPatternMatcherPass(PatternMatcherPass):
 @@
     def apply(self, graph: Union[torch.fx.Graph, GraphModule]) -> int:
         """Apply pattern matcher with unsupported_input_tensor patch to bypass meta tensor check."""
-        with _patch_unsupported_input_tensor():
-            return super().apply(graph)
+        with _patch_unsupported_input_tensor():
+            # temporarily patch fake-mode detection
+            orig = getattr(torch._dynamo.utils, "detect_fake_mode", None)
+            torch._dynamo.utils.detect_fake_mode = ad_detect_fake_mode
+            try:
+                return super().apply(graph)
+            finally:
+                if orig is not None:
+                    torch._dynamo.utils.detect_fake_mode = orig
```
1-1
: Missing NVIDIA Apache-2.0 headerPlease prepend the NVIDIA Apache-2.0 header (2025) at the top of this file.
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)
1-1
: Missing NVIDIA Apache-2.0 headerAdd the standard NVIDIA Apache-2.0 header (2025).
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4)
112-114
: Set sharded params to requires_grad=False: ensure consistency with BMM pathHere we set
requires_grad=False
for linear shards (good for inference). InBMMShardingInfo.apply
, sharded parameters are created withrequires_grad=True
. For consistency across inference transforms, prefer disabling grads there too.If you agree, I can provide a follow-up diff in the BMM path.
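A minimal sketch of what that follow-up could look like, assuming the BMM path registers its shards as nn.Parameter (the helper name and signature here are hypothetical, not the PR's actual code):

```python
# Illustrative only: mirrors the linear path by registering BMM weight shards frozen.
import torch
import torch.nn as nn


def register_frozen_shard(module: nn.Module, name: str, shard: torch.Tensor) -> None:
    """Register a weight shard as a frozen (inference-only) parameter."""
    module.register_parameter(name, nn.Parameter(shard.detach(), requires_grad=False))
```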
391-444
: FP8/FP4 TP sharding classes: interface looks sound
- FP8: pass-through scales and no-op load hook make sense.
- FP4: sharding both
weight_scale
and handling the*_scale
state_dict suffix in the pre-hook is correct.Two small nits:
- Consider documenting required buffers explicitly in class docstrings (e.g., FP4 expects
input_scale
,weight_scale
,alpha
on the owning submodule).- Minor naming:
sharded_uint8_weight_shape
could just beweight_shape
for consistency (docstring already clarifies it’s the sharded packed shape).
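For the first nit, a hedged sketch of such a Google-style docstring; the class name and buffer names follow this review's description rather than the actual code:

```python
class FP4TPShardingDoc:
    """Docstring sketch for an NVFP4 tensor-parallel sharding helper.

    Required buffers on the owning submodule (names per the review discussion):
        input_scale: Per-tensor activation scale.
        weight_scale: Per-block weight scale (CUTLASS layout), sharded with the packed weight.
        alpha: Combined input/weight global scale consumed by the fused kernel.
    """
```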
446-460
: Resolver fallthrough handling. The except Exception: pass in _resolve_tp_cls_from_node hides real issues (e.g., NameError on missing ops). Log at debug at least, so we can diagnose resolver paths. Minimal change:
```diff
-        except Exception:
-            pass
+        except Exception as e:  # pragma: no cover
+            ad_logger.debug(f"TP resolver predicate failed: {e}")
```
1-1
: Missing NVIDIA Apache-2.0 headerPlease add the required header (2025).
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (1)
1-1
: Missing NVIDIA Apache-2.0 header in test fileTests are also subject to the header requirement per guidelines.
I can sweep-add the header to all modified tests if desired.
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)
94-95
: Order of post-load fusions: consider movingfuse_quant
earlier.
fuse_quant
rewrites fake-quant reference ops into fused kernels. Running it earlier in post-load-fusion can:
- Increase the chance for downstream kernel-level fusions to act on the fused ops,
- Reduce pattern mismatch after other structural fusions.
Recommendation: place
fuse_quant
ahead of GEMM/RMSNorm fusions unless there is a known dependency the other way around.tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (2)
12-16
: Remove duplicate/unused constants; use a single SCALING_VECTOR_SIZE.
FORMAT_FP8
andFORMAT_NVFP4
are unused here.- Both
scaling_vector_size
andSCALING_VECTOR_SIZE
represent the same value (16). Keep one for consistency.- scaling_vector_size = 16 -FORMAT_FP8 = 0 -FORMAT_NVFP4 = 1 - -SCALING_VECTOR_SIZE = 16 # NVFP4 block size along K +SCALING_VECTOR_SIZE = 16 # NVFP4 block size along KAnd replace uses of
scaling_vector_size
below:- weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize( - weight, weight_scale_2, scaling_vector_size, False - ) + weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize( + weight, weight_scale_2, SCALING_VECTOR_SIZE, False + )
156-201
: LGTM with a minor suggestion for negative-case coverage.This validates CUTLASS-scale + alpha wiring across dtypes with appropriate tolerances. Consider adding a negative test asserting failure when K is not a multiple of
SCALING_VECTOR_SIZE
to lock the contract.tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (1)
30-46
: Be explicit with op overloads for consistency.Elsewhere you used
.default
for FP8 pattern calls; mirror that for replacements to avoid ambiguity.-return torch.ops.auto_deploy.torch_quant_fp8_linear( +return torch.ops.auto_deploy.torch_quant_fp8_linear.default( x, w_fp8, None, input_scale=input_scale, weight_scale=weight_scale) ... -return torch.ops.auto_deploy.torch_quant_fp8_linear( +return torch.ops.auto_deploy.torch_quant_fp8_linear.default( x, w_fp8, bias, input_scale=input_scale, weight_scale=weight_scale)And similarly for FP4:
-return torch.ops.auto_deploy.torch_quant_fp4_linear( +return torch.ops.auto_deploy.torch_quant_fp4_linear.default( x, w_fp4, bias=None, input_scale=input_scale, weight_scale=weight_scale, alpha=alpha) ... -return torch.ops.auto_deploy.torch_quant_fp4_linear( +return torch.ops.auto_deploy.torch_quant_fp4_linear.default( x, w_fp4, bias=bias, input_scale=input_scale, weight_scale=weight_scale, alpha=alpha)Also applies to: 67-84
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)
86-171
: Consider deferring graph recompilation to the transform level (avoid per-node overhead).
_insert_quantized_bmm
and_insert_quantized_linear
appropriately avoid recompiling per node. The transform-level recompile suggested above keeps performance in check while ensuring correctness. No changes needed here beyond the transform finalize steps.tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
30-33
: Consider consolidating FORMAT_ constants with existing imports*The TODO comment indicates these format constants should be imported from a common location. Since
torch_quant.py
also defines the same constants (FORMAT_FP8 = 0, FORMAT_NVFP4 = 1), consider importing them from there to maintain a single source of truth.-# TODO: put the ENUMs in the same place and import it -FORMAT_FP8 = 0 -FORMAT_NVFP4 = 1 +from ..custom_ops.torch_quant import FORMAT_FP8, FORMAT_NVFP4
143-154
: Consider adding type hints for better clarityThe new methods would benefit from more specific type hints for the
object
return type to improve type safety and IDE support.-def build_custom_kwargs_for_linear( - scale_getattrs: Dict[str, Node], -) -> Dict[str, object]: +def build_custom_kwargs_for_linear( + scale_getattrs: Dict[str, Node], +) -> Dict[str, Union[List[Node], List]]:
194-210
: Inconsistent comment: torch_fake_quant_fp8_linear vs FP8The docstring mentions "torch_fake_quant_fp8_linear" but this is the FP8QuantizationImpl class. Also, the example pattern in the comment doesn't match the actual implementation which uses Node objects, not raw arguments.
def build_custom_args_for_linear( # renamed to reflect args scale_getattrs: Dict[str, Node], ) -> Tuple[object, ...]: """ - Build the *positional* tail for torch_fake_quant_fp8_linear: + Build the *positional* tail for FP8 quantized linear: (..., bias, input_scale(list), weight_scale(list), input_zp(list), weight_zp(list)) - - We pass bias=None to match the exported pattern: - torch_fake_quant_fp8_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], []) """tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1)
255-255
: Consider using modulo check with error messageThe assertion could provide more context about the actual value of K when it fails.
-assert K % 16 == 0, "NVFP4 requires K to be a multiple of 16" +assert K % 16 == 0, f"NVFP4 requires K to be a multiple of 16, got K={K}"tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (1)
214-216
: Ensure consistent node filteringThe condition checks for
partial(is_op, ops=self.target_op)
but the initial collection usesis_op(node, self.target_op)
. While functionally equivalent, consider using the same pattern for clarity.
📜 Review details
📒 Files selected for processing (13)
- tensorrt_llm/_torch/auto_deploy/config/default.yaml (3 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (7 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (5 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2 hunks)
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (2 hunks)
- tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7 hunks)
- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (6 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Code must target Python 3.8+
Indent with 4 spaces; do not use tabs
Preserve module namespace when importing: from package.subpackage import foo; then use foo.SomeClass()
Python filenames use snake_case (e.g., some_file.py)
Class names use PascalCase
Function and method names use snake_case
Local variables use snake_case; prefix k for names starting with a number (e.g., k_99th_percentile)
Global variables are UPPER_SNAKE_CASE prefixed with G (e.g., G_MY_GLOBAL)
Constants are UPPER_SNAKE_CASE
Avoid shadowing variables from an outer scope
Initialize all externally visible members of a class in init
For interfaces used outside a file, prefer docstrings over comments; comments for internal code or local interfaces
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Attributes and variables can be documented inline with trailing docstrings under the class or module
Avoid using reflection when easily avoidable; prefer explicit parameters/constructs over dict(**locals())
In try/except, catch the narrowest exception types possible
For duck-typing try/except, keep try body minimal and place logic in else after attribute existence checks
Files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
**/*.{h,hpp,hxx,hh,c,cc,cpp,cxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA Apache-2.0 copyright header with current year to all source files
Files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
🧬 Code graph analysis (9)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
torch_fake_quant_fp8_linear
(166-198)torch_fake_quant_fp8_linear
(202-212)torch_fake_quant_fp4_linear
(216-274)torch_fake_quant_fp4_linear
(278-287)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (3)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
CachedSequenceInterface
(12-70)tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (3)
ADPatternMatcherPass
(104-110)register_ad_pattern
(142-225)apply
(107-110)tensorrt_llm/_torch/auto_deploy/transform/interface.py (3)
BaseTransform
(139-376)SharedConfig
(51-57)TransformRegistry
(379-407)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (3)
tests/unittest/_torch/auto_deploy/_utils_test/_torch_test_utils.py (3)
fp8_compatible
(29-30)fp4_compatible
(33-34)trtllm_ops_available
(37-38)tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
torch_fake_quant_fp8_linear
(166-198)torch_fake_quant_fp8_linear
(202-212)torch_fake_quant_fp4_linear
(216-274)torch_fake_quant_fp4_linear
(278-287)tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
fp4_global_scale
(62-64)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
TPShardingInfo
(221-264)from_node
(229-234)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
cutlass_fp4_scale_to_modelopt_fp4_scale
(47-59)custom_op
(162-164)custom_op
(236-238)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (5)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (4)
is_op
(183-206)extract_param_names_from_lin_node
(149-170)get_op_overload_packet
(173-180)is_linear_op
(240-254)tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
build_custom_args_for_linear
(150-153)build_custom_args_for_linear
(195-210)build_custom_args_for_linear
(291-306)tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
CachedSequenceInterface
(12-70)tensorrt_llm/_torch/auto_deploy/transform/interface.py (2)
register
(385-392)BaseTransform
(139-376)tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py (1)
cuda_memory_tracker
(10-26)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (4)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (8)
build_custom_args_for_linear
(150-153)build_custom_args_for_linear
(195-210)build_custom_args_for_linear
(291-306)custom_op
(162-164)custom_op
(236-238)QuantizationImpl
(72-153)create
(76-105)should_skip_quantization
(471-484)tensorrt_llm/_torch/auto_deploy/transform/interface.py (4)
TransformRegistry
(379-407)register
(385-392)BaseTransform
(139-376)get
(395-397)tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)
is_linear_op
(240-254)is_bmm_op
(257-264)tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
ModelFactory
(15-207)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
is_op
(183-206)tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
build_custom_args_for_linear
(128-130)build_custom_args_for_linear
(294-301)build_custom_args_for_linear
(340-347)tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
torch_fake_quant_fp8_linear
(166-198)torch_fake_quant_fp8_linear
(202-212)torch_fake_quant_fp4_linear
(216-274)torch_fake_quant_fp4_linear
(278-287)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7)
QuantizationImpl
(72-153)cutlass_fp4_scale_to_modelopt_fp4_scale
(47-59)modelopt_fp4_scale_to_cutlass_fp4_scale
(35-44)scale_names
(118-120)scale_names
(173-174)scale_names
(246-247)scale_names
(425-426)tensorrt_llm/_torch/modules/linear.py (1)
split_dim
(48-49)tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
is_op
(183-206)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (21)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (1)
18-18
: PyTorch version guard for TracingContext import
from torch._guards import TracingContext
is internal and version-sensitive. If CI still runs on older 2.x, this import will fail at import-time.Consider guarding the import:
-from torch._guards import TracingContext +try: + from torch._guards import TracingContext +except Exception: # pragma: no cover + TracingContext = None # Fallback; ad_detect_fake_mode should branch if NoneAnd in
ad_detect_fake_mode
, checkif TracingContext and (context := TracingContext.try_get()): ...
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2)
106-114
: Good migration to node-aware TPShardingInfo factorySwitching to
TPShardingInfo.from_node(n, …)
enables quantization-aware subclassing. This looks correct and makes future extensions easier.
308-316
: Quantization-aware two-way sharding: ensure fake-quant linears resolve to the right subclassThis path relies on
TPShardingInfo.from_node(n, …)
correctly detecting FP8/FP4 nodes. Currently, the resolver insharding_utils.py
only checks fused ops (torch_quant_linear_fp{8,4}
). If the graph contains fake-quant ops (torch_fake_quant_fp{8,4}_linear
), it will fall back to the base TP class and skip scale sharding.I’ve proposed a fix in
sharding_utils.py
to include fake-quant ops inTP_SHARDING_RULES
. See that comment for a concrete diff.tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)
65-68
: Extension point for quantization-aware sharding is well designedAdding
quantization_cb
to_insert_sharded_matmul
is a clean way to decouple quantization specifics from the core sharding flow.
148-158
: Quantization callback invocation: good integration pointThe call to
quantization_cb
after weight/bias sharding ensures scales and load hooks are set while shapes are known. This ordering is correct.
378-389
: _shard_fp4_weight_scale: shape reconstruction math—request a quick validationThe reconstruction of the original weight shape assumes:
weight_shape_original[dim] *= world_size
- last dim doubled (
* 2
) for unpacking FP4If the shard is along columns (dim=1), doubling the last dimension aligns with your packed uint8 scheme. Please verify this with a shape example in tests (e.g., N×K_packed with K_packed=K/2) to avoid off-by-one on block edges.
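For instance, a quick worked example of that reconstruction under made-up shapes (this only illustrates the arithmetic described above; it does not call the real helper):

```python
# Hypothetical shapes: original weight [N, K] = [128, 256], column-sharded across 2 ranks,
# stored as uint8 with two FP4 values packed per byte.
world_size, dim = 2, 1
n, k = 128, 256
packed_shard_shape = [n, (k // 2) // world_size]  # [128, 64]

reconstructed = list(packed_shard_shape)
reconstructed[dim] *= world_size   # undo the column shard -> [128, 128]
reconstructed[-1] *= 2             # undo FP4 packing (2 values per byte) -> [128, 256]
assert reconstructed == [n, k]
```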
I can add a focused unit test exercising
dim=0
vsdim=1
with uneven multiples of 128/16 to ensure correct slicing of scales.tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (2)
76-79
: Config key rename LGTMRenaming to
quantize_linear_from_config
aligns with the split linear/BMM transforms. Looks good.
158-161
: BMM quantization config key rename LGTM
quantize_bmm_from_config
key matches the new transform; no issues spotted.tensorrt_llm/_torch/auto_deploy/config/default.yaml (2)
80-85
: Re-enabling GEMM fusions can increase memory pressure; consider guardrails.
fuse_gemms
,fuse_fp4_gemms
, andfuse_fp8_gemms
were previously disabled due to OOM risk. If we re-enable them by default, consider:
- A config toggle to disable at runtime,
- Conditioning on batch/hidden sizes,
- Or staging behind a perf profile flag.
If recent CI runs haven’t covered large models, please schedule one to confirm memory headroom with these fusions enabled.
48-51
: Transforms split verified – registry and config updated
- Registered in
quantization.py
:
•@TransformRegistry.register("quantize_linear_from_config")
at line 174
•@TransformRegistry.register("quantize_bmm_from_config")
at line 213- Config example in
config/default.yaml
:
•quantize_linear_from_config
at line 48
•quantize_bmm_from_config
at line 50
• No remaining references to the oldquantize_from_config
keytests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (1)
114-143
: LGTM: FP8 fused vs. unified parity test is tight and representative.Good coverage with/without bias, explicit scale wiring ([in_s], [w_s]), and strict tolerances.
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)
80-84
: LGTM: switch to unified custom op tail args is correct.Appending positional tail produced by
build_custom_args_for_linear
aligns with the new fake-quant op contracts.tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
103-104
: Good improvement: Using custom_op() for unified quantization detectionThe change from
target_op()
tocustom_op()
standardizes the node matching logic for quantized operations, providing a cleaner separation between target operations and custom kernel entry points.tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2)
165-198
: Well-structured FP8 quantization implementationThe reference implementation is clear and well-documented, with proper error handling for dtype mismatches and missing scales.
215-274
: Complex but well-documented FP4 quantization logicThe NVFP4 implementation correctly handles the complex per-block quantization scheme with proper scale vector handling and shape transformations. The comments effectively explain the multi-step process.
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (6)
86-97
: Clean utility function for verifying homogeneous childrenThe
check_same_children
function provides a robust way to ensure all child nodes match the expected type before fusion, preventing mixed-precision issues.
99-131
: Excellent abstraction with the QuantizationFusionMixinThe mixin pattern effectively captures the common fusion logic while allowing subclasses to customize their specific quantization formats. The clear documentation of required attributes and methods makes this easy to extend.
153-157
: Good error handling with informative loggingThe try-except block properly catches NotImplementedError and logs a warning with context about which operations couldn't be fused.
282-283
: Add dtype validation for weight fusionGood assertion to ensure only FP8 quantized weights are being fused. This prevents unexpected behavior with mixed precision.
266-311
: Well-implemented FP8 fusion with proper scalingThe FP8 fusion correctly handles weight scale recalculation by finding the maximum scale and re-quantizing accordingly. The implementation properly preserves numerical accuracy while enabling fusion.
313-356
: FP4 fusion correctly handles per-block scalesThe FP4 implementation properly concatenates the per-block scale vectors along with the weights, maintaining the quantization structure needed for the fused operation.
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py: outdated review thread (resolved)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (4)
1-1
: Add NVIDIA SPDX+Apache-2.0 header (license compliance).This source file lacks the required header.
Apply this diff at the very top:
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
475-489
: Replace built-in generics list[str] with typing.List[str] for Python 3.8 support.The codebase targets Python 3.8+, which doesn’t support PEP 585 generics on built-ins.
Apply this diff:
```diff
-def should_skip_quantization(
-    node_or_name: Union[Node, str],
-    excluded_patterns: list[str],
-) -> bool:
+def should_skip_quantization(
+    node_or_name: Union[Node, str],
+    excluded_patterns: List[str],
+) -> bool:
```
491-518
: Fix type annotations for Python 3.8; also widen return type to Optional.Use List[str] instead of list[str]; get_scales_and_type_from_node can return None for scales.
Apply this diff:
```diff
-def extract_scales_from_node(node: Node, scale_names: list[str]) -> Dict[str, Optional[Node]]:
+def extract_scales_from_node(node: Node, scale_names: List[str]) -> Dict[str, Optional[Node]]:
 @@
-def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
+def get_scales_and_type_from_node(node: Node) -> Tuple[Optional[Dict[str, Node]], str]:
```
312-359
: Ensure consistentinput_scale
andalpha
semantics across both load pathsModelOpt’s branch currently computes
alpha = 1 / (s_w2 * s_in)and leaves
input_scale
ass_in
, whereas the HF branch usesalpha = s_w2 * s_in input_scale = 1 / s_inThe custom op expects
input_scale
to always represent1/x
andalpha
to bes_in * s_w2
. Please update the ModelOpt path accordingly:• File:
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
Static methodload_hook
, around lines 320–335Proposed diff:
```diff
 # ModelOpt quantized graph path
 if weight.dtype != torch.uint8:
     …
-    state_dict[alpha_name] = 1 / (weight_scale_2 * state_dict[input_scale_name])
+    state_dict[alpha_name] = weight_scale_2 * state_dict[input_scale_name]
+    state_dict[input_scale_name] = 1 / state_dict[input_scale_name]
```
If there are consumers relying on the old convention, consider gating this change under a version/compatibility flag. Let me know if you'd like assistance wiring that up.
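To make the mismatch concrete, a tiny numeric illustration of the two conventions (the values are made up; the formulas are the ones quoted in this comment):

```python
s_in, s_w2 = 0.5, 0.25  # made-up input scale and second-level weight scale

# Convention the custom op expects (per this review): input_scale = 1/s_in, alpha = s_in * s_w2
expected_input_scale = 1 / s_in      # 2.0
expected_alpha = s_in * s_w2         # 0.125

# What the current ModelOpt branch produces: input_scale = s_in, alpha = 1 / (s_w2 * s_in)
current_input_scale = s_in           # 0.5
current_alpha = 1 / (s_w2 * s_in)    # 8.0

assert expected_alpha != current_alpha  # the two paths disagree until the diff above is applied
```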
♻️ Duplicate comments (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2)
1-1
: Add NVIDIA SPDX+Apache-2.0 header (license compliance).All Python source files must start with the NVIDIA copyright header. Please add it before any imports.
Apply this diff at the top of the file:
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
52-54
: Use typing.Optional instead of PEP 604 unions for Python 3.8 compatibility.The “A | B” syntax requires Python 3.10+. Project guideline targets Python 3.8+.
Apply this diff:
```diff
 def _nvfp4_get_weights_scaling_factor(
     input: torch.Tensor,
     block_size: int,
-    weights_scaling_factor_2: torch.Tensor | None = None,
+    weights_scaling_factor_2: Optional[torch.Tensor] = None,
     keep_high_precision: bool = False,
 ):
 ...
 def _quantize_nvfp4(
     input: torch.Tensor,
     block_size: int,
-    weights_scaling_factor_2: torch.Tensor | None = None,
+    weights_scaling_factor_2: Optional[torch.Tensor] = None,
 ):
```
Also applies to: 107-108
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
294-304
: Docstring mentions fp8 operator name in FP4 section; correct to fp4.Copy/paste error; use torch_fake_quant_fp4_linear in the docstring and example.
Apply this diff:
```diff
- Build the *positional* tail for torch_fake_quant_fp8_linear:
+ Build the *positional* tail for torch_fake_quant_fp4_linear:
 @@
- torch_fake_quant_fp8_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
+ torch_fake_quant_fp4_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
```
🧹 Nitpick comments (7)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
39-46
: Remove unused parameter from _dequant_weight_fp8 and its call site.out_features is not used; simplify signature and call.
Apply this diff:
```diff
 def _dequant_weight_fp8(
     weight_fp8: torch.Tensor,
     weight_scale: torch.Tensor,
-    out_features: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
     return weight_fp8.to(dtype) * weight_scale
```
```diff
- weight_deq = _dequant_weight_fp8(weight_quantized, s_w, out_features, in_dtype)
+ weight_deq = _dequant_weight_fp8(weight_quantized, s_w, in_dtype)
```
Also applies to: 195-195
109-117
: Docstring return mismatch (function returns 2 values, docstring says 3).Update the docstring to reflect the actual 2-tuple return: (packed_weight, q_per_block_scale).
Apply this diff:
```diff
-    Returns:
-        tuple: Contains quantized data, quantized per block scaling factor, and per block scaling factor.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: (packed_weight, q_per_block_scale)
```
88-101
: Clarify tie-breaking mask logic for FP4 rounding.Multiplying a boolean equality by a uint8 mask works by accident. Prefer explicit boolean logic for readability and safety.
Apply this diff for clarity:
```diff
- # Define mask to perform rounding
- mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8).to(device)
+ # Define boolean tie mask: True on odd indices to round up ties
+ mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8, device=device).bool()
 ...
- round = torch.any((weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) * mask, dim=-1)
+ round = torch.any((weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) & mask, dim=-1)
```
217-276
: NVFP4 scale semantics: input_scale/alpha naming vs. usage is inconsistent with helper contract.The code treats input_scale as the inverse scale (inv_x) and expects alpha = s_in2 * s_w2. Ensure upstream load_hooks and defaults produce these exact semantics, or normalize inside this op to avoid subtle bugs across weight-load paths.
Would you like me to provide a normalization shim at the start of this op that accepts either representation and converts to (inv_x, alpha = s_in2*s_w2) consistently?
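As a rough illustration, such a shim might look like the following; the function name is hypothetical, and it uses an explicit flag rather than trying to auto-detect which representation the caller passed:

```python
from typing import Tuple

import torch


def normalize_fp4_scales(
    input_scale: torch.Tensor, alpha: torch.Tensor, scales_are_inverse: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Normalize either scale representation to (inv_x, s_in2 * s_w2).

    scales_are_inverse=True means the caller already passes (1/s_in2, s_in2*s_w2);
    otherwise the pair is assumed to be (s_in2, 1/(s_in2*s_w2)) and is inverted here.
    """
    if scales_are_inverse:
        return input_scale, alpha
    return 1.0 / input_scale, 1.0 / alpha
```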
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
30-33
: Avoid duplicating FORMAT_* enums; centralize and import. These enums also exist in custom_ops.torch_quant. Duplicates risk divergence.
Consider moving FORMAT_FP8/NVFP4 to a single shared module (e.g., tensorrt_llm/_torch/auto_deploy/common/quant_formats.py) and import from there.
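A possible shape for that shared module, sketched under the assumption that an IntEnum plus backward-compatible aliases is acceptable (the module path is only the suggestion above):

```python
# e.g. a new quant_formats.py module; the exact location is a suggestion, not part of this PR.
from enum import IntEnum


class QuantFormat(IntEnum):
    """Single source of truth for quantization format flags."""

    FP8 = 0
    NVFP4 = 1


# Backward-compatible module-level aliases for existing call sites.
FORMAT_FP8 = QuantFormat.FP8
FORMAT_NVFP4 = QuantFormat.NVFP4
```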
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (1)
39-56
: Avoid tracking gradients when reassigning quantized weights/bias in tests.Wrap parameter/buffer reassignments in torch.no_grad() to be explicit and avoid autograd hooks in some envs.
Apply this diff:
def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - device = self.weight.device - weight_scale = torch.max(torch.abs(self.weight)).to(torch.float).to(device) / FP8_MAX - self.weight = nn.Parameter((self.weight / weight_scale).to(torch.float8_e4m3fn)) - self.register_buffer( - "input_scale", torch.tensor(1.0, device=self.weight.device, dtype=torch.float) - ) - self.register_buffer("weight_scale", weight_scale) - if self.bias is not None: - self.bias = nn.Parameter(self.bias.to(torch.half)) + device = self.weight.device + with torch.no_grad(): + weight_scale = torch.max(torch.abs(self.weight)).to(torch.float).to(device) / FP8_MAX + self.weight = nn.Parameter((self.weight / weight_scale).to(torch.float8_e4m3fn)) + if self.bias is not None: + self.bias = nn.Parameter(self.bias.to(torch.half)) + self.register_buffer( + "input_scale", torch.tensor(1.0, device=self.weight.device, dtype=torch.float) + ) + self.register_buffer("weight_scale", weight_scale)tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)
80-86
: New post-load fusion switches: fuse_gemms, fuse_fp4_gemms, fuse_fp8_gemms, and fuse_quant.Good separation of generic and quantized GEMM fusions. Ensure docs mention these toggles for users.
Would you like me to draft a short README snippet describing these flags and the expected effects?
📜 Review details
📒 Files selected for processing (5)
- tensorrt_llm/_torch/auto_deploy/config/default.yaml (2 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (7 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures
Files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)
Files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧠 Learnings (2)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Applied to files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
🧬 Code graph analysis (3)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2)
torch_fake_quant_fp8_linear
(168-200)torch_fake_quant_fp8_linear
(204-214)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
cutlass_fp4_scale_to_modelopt_fp4_scale
(47-59)custom_op
(166-168)custom_op
(240-242)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
is_op
(183-206)tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
build_custom_args_for_linear
(128-130)build_custom_args_for_linear
(294-301)build_custom_args_for_linear
(340-347)tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
torch_fake_quant_fp8_linear
(168-200)torch_fake_quant_fp8_linear
(204-214)torch_fake_quant_fp4_linear
(218-276)torch_fake_quant_fp4_linear
(280-289)
🔇 Additional comments (4)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (3)
36-36
: LGTM: stable FP8 constant for tests.Using torch.finfo for FP8 max is appropriate and future-proof.
96-99
: LGTM: switching target linear layers to FakeFP8Linear for FP8 scenarios.This isolates the fusion test from HW-dependent kernels and exercises the FP8 fuse path.
Also applies to: 111-114, 135-137, 159-162, 191-193
286-288
: LGTM: enabling fuse_fp8_gemms in test optimizer config.Matches the intent of exercising the post_load_fusion path for FP8.
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)
48-51
: Rename quantize_from_config → quantize_linear_from_config and add quantize_bmm_from_config: config aligns with code.Configuration looks consistent with the split transforms; no action needed.
add unit test for quant fusion and quant sharding (Signed-off-by: Frida Hou)
Force-pushed 7d1f402 to 40ef068
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)
1-1
: Add NVIDIA 2025 copyright header (per repo policy).All source files must start with the NVIDIA header for the current year. Please prepend the SPDX header to comply.
Apply this diff at the top of the file:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0
475-479
: Python 3.8 compatibility: replace PEP 585 builtins with typing generics.The repo targets Python 3.8+. Using list[str] requires 3.9+. Switch to List[str] (already imported).
-def should_skip_quantization( - node_or_name: Union[Node, str], - excluded_patterns: list[str], -) -> bool: +def should_skip_quantization( + node_or_name: Union[Node, str], + excluded_patterns: List[str], +) -> bool:tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
110-121
: FP8MLP path casts FP8 weights to float16 during pattern detection.In
_run_pattern_detection_job
the genericelse
branch uses.to(device="cuda", dtype=torch.float16)
for all non-GQA models. Whenmodel_cls == FP8MLP
, this will upcast the FP8-quantized weights to fp16 and break the FP8 reference op (it expectsfloat8_e4m3fn
). This can cause export-time failures or incorrect pattern detection.Fix by adding a dedicated FP8MLP branch that does not cast dtype:
@@ def _run_pattern_detection_job(...): - else: - model = model_cls(num_features, num_features, bias=bias).to( - device="cuda", dtype=torch.float16 - ) + elif model_cls == FP8MLP: + # Keep FP8 quantized params; don't cast dtype. + model = model_cls(num_features, num_features, bias=bias).to("cuda") + else: + model = model_cls(num_features, num_features, bias=bias).to( + device="cuda", dtype=torch.float16 + )tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1)
1-5
: Add mandatory NVIDIA copyright header (2025).Per coding guidelines, every Python source must begin with the NVIDIA copyright header for the current year.
Apply this diff at the top of the file:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
♻️ Duplicate comments (2)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)
510-517
: Update detector to also match custom_op() (align with create()).get_scales_and_type_from_node only matches target_op(); with the new custom_op()-based path, scales won’t be found from nodes calling the custom kernels.
-def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]: +def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]: """Returns a dict of scale args and quantization type string ('fp4', 'fp8', etc).""" - for qtype in [FP4QuantizationImpl, FP8QuantizationImpl]: - if is_op(node, qtype.target_op()): + for qtype in [FP4QuantizationImpl, FP8QuantizationImpl]: + if is_op(node, [qtype.custom_op(), qtype.target_op()]): return extract_scales_from_node( node, qtype.scale_names() ), qtype.__name__.lower().replace("quantizationimpl", "") return None, "simple"
294-311
: Fix copy-paste error in FP4 docstring (mentions FP8 instead of FP4).The docstring for FP4’s build_custom_args_for_linear references torch_fake_quant_fp8_linear. It should reference the FP4 variant.
- Build the *positional* tail for torch_fake_quant_fp8_linear: + Build the *positional* tail for torch_fake_quant_fp4_linear: @@ - torch_fake_quant_fp8_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], []) + torch_fake_quant_fp4_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
🧹 Nitpick comments (17)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (6)
30-33
: Stabilize and document format constants; mark as Final.These int flags are fine, but make the intent explicit to type checkers and future readers. Also, the TODO suggests centralizing—good follow-up, but marking them Final is a quick win now.
Apply this diff here:
-# TODO: put the ENUMs in the same place and import it -FORMAT_FP8 = 0 -FORMAT_NVFP4 = 1 +# TODO: centralize format constants in a shared module and import from there +from typing import Final # type: ignore[unused-ignore] (ok in local scope if not hoisted) +FORMAT_FP8: Final[int] = 0 +FORMAT_NVFP4: Final[int] = 1Additionally (outside this hunk), add Final to the typing import on Line 2:
from typing import Dict, List, Optional, Tuple, Union, Final
88-89
: Return a clear error for unsupported NVFP4 BMM rather than None.Returning None here can lead to confusing downstream AttributeErrors. Raise a targeted NotImplementedError to fail fast with context.
- "NVFP4": None, # BMM NVFP4 is not supported yet + "NVFP4": None, # BMM NVFP4 is not supported yet } - return quantization_impl_map[quant_type_or_node] + impl = quantization_impl_map[quant_type_or_node] + if quant_type_or_node == "NVFP4" and impl is None: + raise NotImplementedError("BMM NVFP4 is not supported yet") + return impl
101-104
: Broaden node detection to support both new custom_op() and legacy target_op().Only checking custom_op() risks missing older graphs that still use target_op(). Use an OR to be backward compatible without sacrificing the new path.
- ]: - if is_op(quant_type_or_node, q.custom_op()): - return q + ]: + if is_op(quant_type_or_node, [q.custom_op(), q.target_op()]): + return q
184-197
: Remove commented-out format_type to avoid confusion.Since FP8/FP4 have distinct custom kernels now, the commented format_type is misleading.
return dict( input_scale=[scale_getattrs["input_scale"]], weight_scale=[scale_getattrs["weight_scale"]], input_zp=[], weight_zp=[], - # format_type=FORMAT_FP8, )
198-215
: Docstring example: clarify bias handling or drop the example.Doc says “We pass bias=None,” but the example shows positional args without explicitly indicating None. Either annotate which arg is None or omit the example to prevent misreads.
271-293
: Remove commented-out format_type here as well (mirror FP8 change).Keeps the contract focused on the actual inputs the kernels consume now.
return dict( input_scale=[scale_getattrs["input_scale"]], weight_scale=[scale_getattrs["weight_scale"], scale_getattrs["alpha"]], input_zp=[], weight_zp=[], - # format_type=FORMAT_NVFP4, )
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4)
12-13
: Ensure custom quant ops are registered for FP8 tests.
FP8MLP
relies onFakeFP8Linear
which callstorch.ops.auto_deploy.torch_fake_quant_fp8_linear
. Importing the custom ops in this test file avoids order-dependent failures when running tests selectively.Add near the top (close to other imports):
+import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
81-92
: Minor: add a short docstring and keep dtype consistent.
FP8MLP
is a test-only module. Adding a one-liner docstring clarifies intent. Also consider keeping the module in half precision where appropriate for inputs, but avoid mutating FP8 params (you already do in_run_job
). No code change required beyond docstring.Example:
class FP8MLP(nn.Module): + """Two-layer MLP using FakeFP8Linear modules to exercise FP8 sharding paths.""" def __init__(self, in_features, out_features, bias=False): super().__init__()
304-315
: Param space coverage note (optional).Pattern-detection is only checked for
world_size=[8]
. If CI time permits, adding a smaller value (e.g., 2 or 4) can catch shape corner cases without significant overhead.
1-1
: Missing NVIDIA copyright header.Per the coding guidelines, prepend the NVIDIA copyright header (2025) to this file.
Add at the very top:
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (4)
1-1
: Add NVIDIA copyright header.All source files (.py included) must carry the NVIDIA copyright header.
Insert above line 1:
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
16-24
: Initialize ModelFactory base class and return quant config from the stub.
DummyFactory inherits ModelFactory but doesn't call super().__init__, leaving base fields uninitialized. Some optimizer paths read attributes like skip_loading_weights or call get_quant_config(). Make the stub explicit and safe.
Apply this diff:
 class DummyFactory(ModelFactory):
     def __init__(self, quant_config=None):
-        self._quant_config = quant_config or {}
+        super().__init__(model="", skip_loading_weights=True)
+        self._quant_config = dict(quant_config or {})
+
+    def get_quant_config(self) -> dict:
+        return dict(self._quant_config)
+
     def _build_model(self, device: str):
-        return
+        return None

     def _load_checkpoint(self, model, device):
-        return
+        return None
88-130
: Clarify FP4 scale semantics ('s_in2' vs 'inv'). The custom op docstring indicates input_scale[0] = s_in2 and weight_scale[1] = alpha = s_in2 * s_w2. Here you pass input_scale_2 = s_in2 and alpha = 1/(s_in2 * s_w2). That's the reciprocal of the docstring's alpha, but it matches the internal variable names (inv_x, etc.) in the reference implementation.
Consider adding a short comment to disambiguate "inv" vs "s2" conventions, or rename buffers to reflect "inv_" semantics if that's the intended contract. This reduces future confusion should the op's docs or code be refactored.
Would you like me to propose a consistent naming/comment patch once you confirm which convention we standardize on?
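To make the two conventions concrete, here is a tiny hedged sketch (values invented; s_in2/s_w2 are the per-tensor second-level scales discussed above):
import torch

# Invented example values for the per-tensor (second-level) scales.
s_in2 = torch.tensor(0.05)
s_w2 = torch.tensor(0.02)

alpha_doc = s_in2 * s_w2            # convention stated in the op docstring
alpha_test = 1.0 / (s_in2 * s_w2)   # convention used by the test buffers

# The two are reciprocals, so a kernel expecting one convention will be
# off by (s_in2 * s_w2) ** 2 if handed the other.
assert torch.allclose(alpha_doc * alpha_test, torch.tensor(1.0))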
27-35
: Assertion helpers correctly detect rewrite; minor consistency nit.
_has_fused_linear_fp8 checks the .default overload for the reference op, while _has_fused_linear_fp4 checks the packet (no .default). Both work because is_op handles OpOverloadPacket, but using .default in both places improves uniformity.
Optional tweak:
-    found_ref = any(
-        is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp4_linear) for n in gm.graph.nodes
-    )
+    found_ref = any(
+        is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp4_linear.default) for n in gm.graph.nodes
+    )
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (3)
259-261
: Add a short docstring for FakeFP8Linear (public test helper). This class is imported by multiple tests; a concise docstring clarifies its contract and limits (e.g., test-only, expects FP8 path).
Apply this diff:
 class FakeFP8Linear(nn.Linear):
     def __init__(self, *args, **kwargs):
+        """
+        Test-only Linear that stores FP8-quantized weights and invokes the FP8 fake-quant op.
+        Not intended for training/autograd; biases are cast to the input dtype at runtime.
+        """
         super().__init__(*args, **kwargs)
263-265
: Consider keeping float32 weights and storing FP8 weights in a buffer to reduce surprise for code expecting nn.Linear.weight to be float. Overwriting self.weight with FP8 may confuse utilities or checkpoints that assume a float weight tensor. A low-impact alternative: keep self.weight as-is and store weight_fp8 as a buffer.
Optional refactor:
-        self.weight = nn.Parameter((self.weight / weight_scale).to(torch.float8_e4m3fn))
+        # Keep original float weight; store FP8 copy for the fake-quant op.
+        self.register_buffer("weight_fp8", (self.weight.detach() / weight_scale).to(torch.float8_e4m3fn))
and in forward:
-        return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
-            x, self.weight, bias, [self.input_scale], [self.weight_scale], [], []
-        )
+        return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
+            x, self.weight_fp8, bias, [self.input_scale], [self.weight_scale], [], []
+        )
272-275
: Use .default for invoking torch_fake_quant_fp8_linear consistently
The call sites for torch.ops.auto_deploy.torch_fake_quant_fp8_linear are currently inconsistent: some invoke the op directly, while others use the .default attribute. Our convention is to always use .default when calling custom ops. Please update the following locations:
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py:273
Change
    return torch.ops.auto_deploy.torch_fake_quant_fp8_linear(
        x, self.weight, self.bias, [self.input_scale], [self.weight_scale], [], []
    )
to
    return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
        x, self.weight, self.bias, [self.input_scale], [self.weight_scale], [], []
    )
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py:131
Similarly prepend .default to the op invocation there. This will ensure all custom-op calls follow the same pattern.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (8)
tensorrt_llm/_torch/auto_deploy/config/default.yaml (2 hunks)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7 hunks)
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1 hunks)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (6 hunks)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (7 hunks)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (1 hunks)
tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py
- tensorrt_llm/_torch/auto_deploy/config/default.yaml
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures
Files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)
Files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧬 Code graph analysis (4)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
  tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (11): FakeFP8Linear (259-275), forward (74-94), forward (106-108), forward (130-135), forward (154-159), forward (184-203), forward (220-222), forward (232-234), forward (247-253), forward (272-275), MLP (97-111)
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (2)
  tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (2): forward (75-85), forward (119-129)
  tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2): torch_fake_quant_fp8_linear (168-200), torch_fake_quant_fp8_linear (204-214)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (7)
  tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1): run_test_transformed_gm (68-138)
  tests/unittest/_torch/auto_deploy/_utils_test/_torch_test_utils.py (3): fp4_compatible (33-34), fp8_compatible (29-30), trtllm_ops_available (37-38)
  tensorrt_llm/_torch/auto_deploy/export/export.py (1): torch_export_to_gm (198-284)
  tensorrt_llm/_torch/auto_deploy/models/factory.py (1): ModelFactory (15-207)
  tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1): is_op (183-206)
  tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2): fp4_global_scale (62-64), fp8_scale (67-69)
  tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4): torch_fake_quant_fp8_linear (168-200), torch_fake_quant_fp8_linear (204-214), torch_fake_quant_fp4_linear (218-276), torch_fake_quant_fp4_linear (280-289)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
  tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1): is_op (183-206)
  tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3): build_custom_args_for_linear (128-130), build_custom_args_for_linear (294-301), build_custom_args_for_linear (340-347)
  tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4): torch_fake_quant_fp8_linear (168-200), torch_fake_quant_fp8_linear (204-214), torch_fake_quant_fp4_linear (218-276), torch_fake_quant_fp4_linear (280-289)
🔇 Additional comments (4)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
147-157
: Interface extension looks good; default no-ops are appropriate. The default kwargs/args hooks are clear and safe. Good separation of concerns for per-format overrides.
165-169
: Unified entry point via custom_op() is a solid choice. Binding to the .default overload makes matching in FX robust and consistent across sites.
239-243
: Custom op entry point for FP4 mirrors FP8; good consistency. Keeps transform plumbing uniform across formats.
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (1)
47-86
: TinyFP8Ref looks good; ensure device/dtype handling remains consistent. Precomputing weight_fp8/scales and using the reference op aligns with the fusion expectations. Buffers will migrate on .to("cuda"). No functional issues spotted.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
501-509
: Ensure consistent operator detection across all methods. This is the location where the detection inconsistency manifests: the function only checks target_op() but should also check custom_op() to align with the create() method's logic.
See the fix suggested in the earlier comment for lines 94-105.
♻️ Duplicate comments (2)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)
94-105
: Operator detection should be consistent between methods. The create() method now uses custom_op() for FP4/FP8 detection (lines 98-99) but get_scales_and_type_from_node() (lines 503-504) still only checks target_op(). This inconsistency could cause issues.
Apply this diff to get_scales_and_type_from_node (around lines 503-504):
 def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
     """Returns a dict of scale args and quantization type string ('fp4', 'fp8', etc)."""
     for qtype in [FP4QuantizationImpl, FP8QuantizationImpl]:
-        if is_op(node, qtype.target_op()):
+        if is_op(node, qtype.target_op()) or is_op(node, qtype.custom_op()):
             return extract_scales_from_node(
                 node, qtype.scale_names()
             ), qtype.__name__.lower().replace("quantizationimpl", "")
     return None, "simple"
285-302
: Fix copy-paste error in docstring. The docstring incorrectly mentions "torch_fake_quant_fp8_linear" when it should be "torch_fake_quant_fp4_linear".
Apply this diff:
 @staticmethod
 def build_custom_args_for_linear(  # renamed to reflect args
     scale_getattrs: Dict[str, Node],
 ) -> Tuple[object, ...]:
     """
-    Build the *positional* tail for torch_fake_quant_fp8_linear:
+    Build the *positional* tail for torch_fake_quant_fp4_linear:
         (..., bias, input_scale(list), weight_scale(list), input_zp(list), weight_zp(list))
     We pass bias=None to match the exported pattern:
-        torch_fake_quant_fp8_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
+        torch_fake_quant_fp4_linear(args_0, args_1, args_2, [args_2_0], [args_3_0, args_3_1], [], [])
     """
🧹 Nitpick comments (3)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (2)
108-139
: Consider extracting common test setup pattern. The test has similar setup logic to test_fp8_linear, including input generation, weight quantization, and comparison assertions. Consider extracting common test utilities to reduce duplication and improve maintainability.
Additionally, add error handling for when the CUDA device is not available:
def _get_test_device():
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    return "cuda"
140-197
: Add explicit K alignment validation and improve error message. The assertion at line 156 correctly enforces the NVFP4 block size requirement, but could provide more helpful debugging information.
Apply this diff to improve the assertion:
-    assert K % SCALING_VECTOR_SIZE == 0
+    assert K % SCALING_VECTOR_SIZE == 0, (
+        f"K dimension ({K}) must be a multiple of SCALING_VECTOR_SIZE ({SCALING_VECTOR_SIZE}) "
+        f"for NVFP4 quantization"
+    )
Also consider adding a comment explaining why alpha is computed as 1/(s_in2*s_w2), for future maintainability:
+    # Alpha combines the input and weight per-tensor scales for the fused kernel
     alpha_fused = (1.0 / (s_in2 * s_w2)).to(torch.float32)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
264-284
: Good documentation, consider extracting common pattern. The contract documentation clearly explains the expected argument structure. Both FP8 and FP4 implementations share a similar pattern for building kwargs; consider extracting common logic to reduce duplication.
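One possible shape for that shared logic, sketched here with hypothetical helper and key names (not the PR's actual code):
from typing import Dict, List

def _build_quant_kwargs(scale_getattrs: Dict[str, object],
                        weight_scale_keys: List[str]) -> dict:
    # Shared skeleton: FP8 would pass ["weight_scale"], NVFP4 would pass
    # ["weight_scale", "alpha"]; zero-point lists stay empty for both formats.
    return dict(
        input_scale=[scale_getattrs["input_scale"]],
        weight_scale=[scale_getattrs[k] for k in weight_scale_keys],
        input_zp=[],
        weight_zp=[],
    )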
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (6 hunks)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures
Files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)
Files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧬 Code graph analysis (2)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (3)
  tests/unittest/_torch/auto_deploy/_utils_test/_torch_test_utils.py (3): fp8_compatible (29-30), fp4_compatible (33-34), trtllm_ops_available (37-38)
  tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4): torch_fake_quant_fp8_linear (160-191), torch_fake_quant_fp8_linear (195-205), torch_fake_quant_fp4_linear (209-267), torch_fake_quant_fp4_linear (271-280)
  tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1): fp4_global_scale (58-60)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
  tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1): is_op (183-206)
  tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3): build_custom_args_for_linear (128-130), build_custom_args_for_linear (294-301), build_custom_args_for_linear (340-347)
  tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4): torch_fake_quant_fp8_linear (160-191), torch_fake_quant_fp8_linear (195-205), torch_fake_quant_fp4_linear (209-267), torch_fake_quant_fp4_linear (271-280)
🔇 Additional comments (7)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (2)
11-11
: LGTM! Clear constant naming improvement. The constant name change from scaling_vector_size to SCALING_VECTOR_SIZE follows Python's uppercase convention for global constants.
50-50
: LGTM! Consistent use of the renamed constant. The update correctly reflects the renamed constant SCALING_VECTOR_SIZE.
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (5)
84-84
: LGTM! Clear documentation for unsupported feature. The comment clearly indicates that BMM NVFP4 is not yet supported.
143-154
: LGTM! Well-structured base class methods. The new methods provide a clean interface for custom quantization implementations to supply kernel-specific arguments.
161-165
: LGTM! Clear custom operation interface. The custom_op() method properly returns the FP8 fake quantization operator for unified kernel entry.
180-208
: Methods provide consistent FP8 argument handling. The implementation correctly builds both kwargs and args for the FP8 linear operation, maintaining consistency with the expected signature of torch_fake_quant_fp8_linear.
232-236
: LGTM! Consistent FP4 custom operation interface. The custom_op() method properly returns the FP4 fake quantization operator.
Signed-off-by: Frida Hou <[email protected]>
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)
68-84
: Scale buffers are never populated from load_hook output. _default_scales registers submodule buffers input_scale/weight_scale, but the FP8/FP4 load_hook currently writes to weight_name + "_scale" in the state_dict, leaving module buffers at their defaults (1.0/zeros). The fused/fake ops then read wrong scales.
I recommend updating the load hooks in quantization_utils.py to write into module buffer keys (e.g., f"{modprefix}.weight_scale" and, when derived, f"{modprefix}.input_scale"/f"{modprefix}.alpha"), or add a post-load hook here to transfer state_dict entries into those buffers before execution. Happy to provide a patch in quantization_utils.py (see below comment).
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
180-187
: FP8 load_hook should populate module buffers, not only aux keys
Currently writes weight_name + "_scale" and leaves the submodule ".weight_scale" buffer at its default. Populate buffer keys so the op reads correct scales.
 def load_hook(state_dict, prefix, *args, weight_name):
     if weight_name in state_dict:
         weight = state_dict[weight_name]
         if weight.dtype != torch.float8_e4m3fn:
             scale = fp8_scale(state_dict[weight_name])
-            state_dict[weight_name] = (state_dict[weight_name] / scale).to(torch.float8_e4m3fn)
-            state_dict[weight_name + "_scale"] = scale
+            state_dict[weight_name] = (state_dict[weight_name] / scale).to(torch.float8_e4m3fn)
+            mod_prefix = weight_name.rsplit(".", 1)[0]
+            # Ensure module buffers get correct values
+            state_dict[f"{mod_prefix}.weight_scale"] = scale
+            # input_scale often defaults to 1.0 here; set if missing
+            state_dict.setdefault(f"{mod_prefix}.input_scale", torch.tensor(1.0))
241-287
: FP4 load_hook should also set module buffers
Similar issue: ensure submodule buffers (input_scale/weight_scale/alpha) are written so execution uses the right scales.
 def load_hook(state_dict, prefix, *args, weight_name):
     if weight_name in state_dict:
         input_scale_name = weight_name.rsplit(".", 1)[0] + ".input_scale"
         alpha_name = weight_name.rsplit(".", 1)[0] + ".alpha"
+        weight_scale_buf_name = weight_name.rsplit(".", 1)[0] + ".weight_scale"
         weight = state_dict[weight_name]
@@
-            state_dict[weight_name + "_scale"] = weight_scale
+            # Also populate module buffer key for execution
+            state_dict[weight_scale_buf_name] = weight_scale
@@
-        if (
+        if (
             weight_name + "_scale_2" in state_dict
             and weight_name + "_scale" in state_dict
             and input_scale_name in state_dict
             and float4_sf_dtype
         ):
@@
-            state_dict[weight_name + "_scale"] = (
+            converted = (
                 torch.ops.trtllm.block_scale_interleave(
                     weight_scale.view(torch.uint8).cpu().contiguous()
                 )
                 .reshape(ori_shape)
                 .view(float4_sf_dtype)
                 .reshape(-1)
             )
+            state_dict[weight_name + "_scale"] = converted
+            state_dict[weight_scale_buf_name] = converted
403-417
: Exclude-pattern helper uses linear-specific extractor for BMM
For BMM nodes, extract_param_names_from_lin_node is invalid. Guard by op type and derive the module name from the get_attr weight when available.
 def should_skip_quantization(
     node_or_name: Union[Node, str],
     excluded_patterns: list[str],
 ) -> bool:
@@
-    else:
-        if not (is_linear_op(node_or_name, include_quantization=False) or is_bmm_op(node_or_name)):
-            return True
-        param_name, _ = extract_param_names_from_lin_node(node_or_name)
-        modname, _, _ = param_name.rpartition(".")
+    else:
+        if is_linear_op(node_or_name, include_quantization=False):
+            param_name, _ = extract_param_names_from_lin_node(node_or_name)
+            modname, _, _ = param_name.rpartition(".")
+        elif is_bmm_op(node_or_name):
+            wt = node_or_name.args[1]
+            if getattr(wt, "op", None) == "get_attr":
+                modname, _, _ = wt.target.rpartition(".")
+            else:
+                # dynamic weight; no owning module, so don't exclude
+                return False
+        else:
+            return True
♻️ Duplicate comments (5)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
300-318
: quantization_cb re-registers existing buffers => runtime error
Calling register_buffer on existing names raises; update in place when present and validate buffer existence. This mirrors prior feedback.
-        for scale_name in self.scale_names():
-            scales[scale_name] = submod.get_buffer(scale_name)
+        for scale_name in self.scale_names():
+            try:
+                buf = submod.get_buffer(scale_name)
+            except Exception:
+                buf = None
+            if buf is None:
+                raise RuntimeError(
+                    f"Expected buffer '{scale_name}' on module '{type(submod).__name__}'"
+                )
+            scales[scale_name] = buf
@@
-        for k, v in sharded_scales.items():
-            submod.register_buffer(k, v)
+        for k, v in sharded_scales.items():
+            if k in submod._buffers:
+                setattr(submod, k, v)
+            else:
+                submod.register_buffer(k, v)
212-247
: Mirror finalize/flag logic (prior feedback)
This is the same suggestion previously raised for the BMM path; it applies here as well.
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (3)
44-46
: Replace Python 3.10 union syntax with Optional for Py3.8 compatibility. Keep repo-wide Python 3.8 support.
-    weights_scaling_factor_2: torch.Tensor | None = None,
+    weights_scaling_factor_2: Optional[torch.Tensor] = None,
(Apply in both function signatures.)
Also applies to: 99-100
81-93
: Avoid boolean*uint8 mix, don't shadow built-ins, keep inputs unmodified. Use a logical AND with a boolean mask; rename ord/round.
-    mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8).to(device)
+    mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8, device=device)
@@
-    ord = torch.searchsorted(e2m1_bounds.to(device), weight_abs, out_int32=True).to(torch.uint8)
+    ordinal = torch.searchsorted(e2m1_bounds.to(device), weight_abs, out_int32=True).to(torch.uint8)
@@
-    round = torch.any((weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) * mask, dim=-1)
-    fp4_val = (sign_bit * 0b1000 + ord + round).to(torch.uint8)
+    round_up = torch.any(
+        (weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) & mask.bool(),
+        dim=-1,
+    )
+    fp4_val = (sign_bit * 0b1000 + ordinal + round_up).to(torch.uint8)
1-1
: Add NVIDIA SPDX/Apache-2.0 header (2025) at file top. Required by project guidelines for all source files.
Apply this diff at the very beginning of the file:
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
🧹 Nitpick comments (18)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
1-1
: Add NVIDIA Apache-2.0 header
Source files require the header.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
446-449
: Broaden resolver to also match fused ops and guard missing symbols
Make TP_SHARDING_RULES robust across fake and fused paths and import-time safe.
-TP_SHARDING_RULES = [
-    (lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp8_linear), FP8TPShardingInfo),
-    (lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp4_linear), FP4TPShardingInfo),
-]
+TP_SHARDING_RULES = [
+    (
+        lambda n: is_op(
+            n,
+            (
+                getattr(torch.ops.auto_deploy, "torch_fake_quant_fp8_linear", None),
+                getattr(torch.ops.auto_deploy, "torch_quant_fp8_linear", None),
+            ),
+        ),
+        FP8TPShardingInfo,
+    ),
+    (
+        lambda n: is_op(
+            n,
+            (
+                getattr(torch.ops.auto_deploy, "torch_fake_quant_fp4_linear", None),
+                getattr(torch.ops.auto_deploy, "torch_quant_fp4_linear", None),
+            ),
+        ),
+        FP4TPShardingInfo,
+    ),
+]
1-1
: Add NVIDIA Apache-2.0 header
Required for non-test source files.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
136-142
: Use Python 3.8-compatible typing
Replace X | None with Optional[X] and import Optional.
-from typing import Tuple
+from typing import Optional, Tuple
@@
-def _fp4_ref_repl_2(
+def _fp4_ref_repl_2(
     x: torch.Tensor,
     w_fp4: torch.Tensor,
-    bias: torch.Tensor | None,
+    bias: Optional[torch.Tensor],
259-269
: Finalize graph after rewrites
Eliminate dead code and recompile for stability and downstream passes.
-    num_matches = patterns.apply(gm.graph)
+    num_matches = patterns.apply(gm.graph)
+    if num_matches:
+        gm.graph.eliminate_dead_code()
+        gm.recompile()
1-1
: Add NVIDIA Apache-2.0 header
Required for source files.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
174-209
: Return skipped=True when no matches; recompile on changes
Align semantics across transforms.
@@ class LinearQuantizationFromConfig(BaseTransform):
-        info = TransformInfo(
-            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
-        )
-        return gm, info
+        if num_matches:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
+        return gm, TransformInfo(
+            skipped=(num_matches == 0),
+            num_matches=num_matches,
+            is_clean=False,
+            has_valid_shapes=True,
+        )
@@ class BMMQuantizationFromConfig(BaseTransform):
-        info = TransformInfo(
-            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
-        )
-        return gm, info
+        if num_matches:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
+        return gm, TransformInfo(
+            skipped=(num_matches == 0),
+            num_matches=num_matches,
+            is_clean=False,
+            has_valid_shapes=True,
+        )
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)
1-1
: Add NVIDIA Apache-2.0 header
Required for source files.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
438-446
: Return type annotation doesn't reflect None possibility
The function can return (None, "simple"); update the typing or avoid None returns.
-def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
+from typing import Optional
+def get_scales_and_type_from_node(node: Node) -> Tuple[Optional[Dict[str, Node]], str]:
12-13
: Set explicit float dtype for FP4 tables. Avoid implicit int64 for e2m1_values; ensure the math stays in floating point.
-e2m1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
-e2m1_values = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6])
+e2m1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5], dtype=torch.float32)
+e2m1_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
+                            0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], dtype=torch.float32)
31-38
: Remove unused parameter from _dequant_weight_fp8 and update call site. out_features is not used. Simplify the signature and callers.
-def _dequant_weight_fp8(
-    weight_fp8: torch.Tensor,
-    weight_scale: torch.Tensor,
-    out_features: int,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    return weight_fp8.to(dtype) * weight_scale
+def _dequant_weight_fp8(
+    weight_fp8: torch.Tensor,
+    weight_scale: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return weight_fp8.to(dtype) * weight_scale
And adjust usage below:
-    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, out_features, in_dtype)
+    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, in_dtype)
133-157
: Add basic shape assertions to dequant path. Fail fast on malformed inputs to ease debugging.
 def _dequantize_nvfp4(
@@
 ) -> torch.Tensor:
     device = quantized_t.device
     N, K = orig_shape
+    assert K % 16 == 0, "NVFP4 dequant requires K to be a multiple of 16."
+    assert quantized_t.shape[-2] == N and quantized_t.shape[-1] == K // 2, \
+        f"quantized_t has shape {tuple(quantized_t.shape[-2:])}, expected {(N, K//2)}"
159-206
: FP8 fake op: minor robustness, ensure scales are on the same device. Move scales to input.device to avoid a device mismatch if they are provided as CPU tensors.
-    s_in = _expect_single_scale(input_scale, "input_scale")
-    s_w = _expect_single_scale(weight_scale, "weight_scale")
+    s_in = _expect_single_scale(input_scale, "input_scale").to(input.device)
+    s_w = _expect_single_scale(weight_scale, "weight_scale").to(input.device)
208-281
: FP4 eager path: validate scale list lengths and devices. Guard early and pin scales to the device.
-    if len(weight_scale) < 2 or weight_scale[0] is None or weight_scale[1] is None:
+    if len(weight_scale) < 2 or weight_scale[0] is None or weight_scale[1] is None:
         raise ValueError(
             "NVFP4 needs weight_scale[0] (per-block vector) and weight_scale[1] (alpha)."
         )
-    cutlass_qscale = weight_scale[0]
-    alpha = weight_scale[1]
+    cutlass_qscale = weight_scale[0].to(input.device)
+    alpha = weight_scale[1].to(input.device)
11-14
: Optional: treat FP tables/constants as UPPER_SNAKE_CASE. Matches repo style for constants.
-# FP4 tables (E2M1)
-e2m1_bounds = ...
-e2m1_values = ...
+# FP4 tables (E2M1)
+E2M1_BOUNDS = ...
+E2M1_VALUES = ...
(Propagate the symbol rename locally.)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
142-151
: Defensive buffer loading to avoid hard failures. If any expected buffer is missing, skip fusion for this group instead of raising.
-    for weight_key in keys_unfused:
-        key = weight_key.rsplit(".", 1)[0]
-        for scale_name in flat_scale_names:
-            buffer_name = key + "." + scale_name
-            scales.setdefault(scale_name, []).append(gm.get_buffer(buffer_name))
+    try:
+        for weight_key in keys_unfused:
+            key = weight_key.rsplit(".", 1)[0]
+            for scale_name in flat_scale_names:
+                buffer_name = key + "." + scale_name
+                scales.setdefault(scale_name, []).append(gm.get_buffer(buffer_name))
+    except (AttributeError, KeyError) as e:
+        ad_logger.warning(f"Missing quant buffers for {keys_unfused}, skipping fusion: {e}")
+        return
166-173
: Prefer graph.get_attr over create_node('get_attr') for buffers. Keeps FX graph construction consistent with the parameter fetch above.
-    scale_getattrs: Dict[str, Node] = {
-        name: gm.graph.create_node("get_attr", f"{key_fused}_{name}")
-        for name in flat_scale_names
-    }
+    scale_getattrs: Dict[str, Node] = {
+        name: gm.graph.get_attr(f"{key_fused}_{name}")
+        for name in flat_scale_names
+    }
270-291
: FP8 fuse_rule: scale aggregation via max is OK; ensure device alignment. A small tweak avoids a device mismatch when stacking scales.
-    new_weight_scale = torch.max(torch.stack(weight_scale))
+    ws = [s.to(weights[0].device) for s in weight_scale]
+    new_weight_scale = torch.max(torch.stack(ws))
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (11)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (1 hunks)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (7 hunks)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (5 hunks)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (2 hunks)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (5 hunks)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (6 hunks)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1 hunks)
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1 hunks)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (7 hunks)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py
🚧 Files skipped from review as they are similar to previous changes (3)
- tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
- tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures
Files:
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)
Files:
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
📚 Learning: 2025-08-27T16:22:10.642Z
Learnt from: Fridah-nv
PR: NVIDIA/TensorRT-LLM#7227
File: tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py:94-100
Timestamp: 2025-08-27T16:22:10.642Z
Learning: When there are inconsistent operator detection methods (like custom_op() vs target_op()), removing one method and standardizing on the other is often cleaner than supporting both methods simultaneously.
Applied to files:
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧬 Code graph analysis (6)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (4)
  tensorrt_llm/_torch/auto_deploy/models/factory.py (1): ModelFactory (15-207)
  tensorrt_llm/_torch/auto_deploy/shim/interface.py (1): CachedSequenceInterface (12-70)
  tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (3): ADPatternMatcherPass (96-102), register_ad_pattern (134-217), apply (99-102)
  tensorrt_llm/_torch/auto_deploy/transform/interface.py (4): BaseTransform (139-376), SharedConfig (51-57), TransformInfo (108-133), TransformRegistry (379-407)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
  tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (4): is_op (183-206), extract_param_names_from_lin_node (149-170), get_op_overload_packet (173-180), is_linear_op (240-252)
  tensorrt_llm/_torch/auto_deploy/shim/interface.py (1): CachedSequenceInterface (12-70)
  tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py (1): cuda_memory_tracker (10-26)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4)
  tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (10): FakeFP8Linear (259-275), forward (74-94), forward (106-108), forward (130-135), forward (154-159), forward (184-203), forward (220-222), forward (232-234), forward (247-253), forward (272-275)
  tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2): FP8TPShardingInfo (331-375), SplitDimension (178-182)
  tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1): is_linear_op (240-252)
  tensorrt_llm/_torch/modules/linear.py (1): split_dim (48-49)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
  tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7): QuantizationImpl (68-142), cutlass_fp4_scale_to_modelopt_fp4_scale (43-55), modelopt_fp4_scale_to_cutlass_fp4_scale (31-40), scale_names (114-116), scale_names (157-158), scale_names (208-209), scale_names (357-358)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (4)
  tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3): build_custom_args_for_linear (126-128), build_custom_args_for_linear (292-299), build_custom_args_for_linear (338-345)
  tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (10): build_custom_args_for_linear (139-142), build_custom_args_for_linear (165-177), build_custom_args_for_linear (230-238), target_op (104-106), target_op (147-148), target_op (199-200), target_op (347-348), QuantizationImpl (68-142), create (72-101), should_skip_quantization (403-416)
  tensorrt_llm/_torch/auto_deploy/transform/interface.py (4): TransformRegistry (379-407), register (385-392), BaseTransform (139-376), get (395-397)
  tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2): is_linear_op (240-252), is_bmm_op (255-262)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
  tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3): build_custom_args_for_linear (128-130), build_custom_args_for_linear (292-299), build_custom_args_for_linear (340-347)
🔇 Additional comments (6)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (2)
82-93
: Add FP8 MLP test module: looks good. The class mirrors MLP correctly using FakeFP8Linear; the forward shape/activation path matches expectations.
331-333
: Param additions LGTM. Including FP8MLP in sharding tests and focusing pattern-detection on FP8 is fine for this PR scope.
Also applies to: 349-353
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1)
31-38
: dequant_weight_fp8 call sites verified; no updates needed. The only invocation of _dequant_weight_fp8 is in
- tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py:186
  weight_deq = _dequant_weight_fp8(weight_quantized, s_w, out_features, in_dtype)
which exactly matches the updated parameter order and count. No other references were found across the codebase, so no further changes are required.
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
86-97
: check_same_children: implementation is sound. Converts users to a stable list before iteration and enforces unanimous type; a good helper for safe fusion.
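For reference, a rough reconstruction of the helper's core (signature assumed from the surrounding review context):
from typing import Callable
from torch.fx import Node

def check_same_children(node: Node, is_desired_child: Callable[[Node], bool]) -> bool:
    # Snapshot users into a list so later graph edits don't invalidate the
    # iteration, then require every consumer to satisfy the predicate.
    users = list(node.users)
    return all(is_desired_child(u) for u in users)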
201-205
: Bias-less constraint is explicit; good. The bias==None guard simplifies fusion guarantees and matches the non-quant path semantics.
316-337
: FP4 fuse_rule: concatenating weight scales assumes identical layout. Document/verify that per-block vectors are already in fused order; otherwise, derive the order via recorded split sizes.
Would you confirm that weight_scale vectors are laid out as [out_i blocks] per weight, matching cat(weights, dim=0)? If not, we should reorder before concat.
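To illustrate the layout assumption being questioned (shapes invented; 16 is the NVFP4 block size used elsewhere in this PR):
import torch

k, block = 64, 16
w_a, w_b = torch.randn(32, k), torch.randn(48, k)
# One per-block scale per 16-wide block, flattened row-major per weight.
s_a = torch.rand(32 * (k // block))
s_b = torch.rand(48 * (k // block))

fused_w = torch.cat([w_a, w_b], dim=0)   # (80, 64)
fused_s = torch.cat([s_a, s_b], dim=0)   # lines up with fused_w's rows only if each
# scale vector is already ordered as [row-major blocks of its own weight]
assert fused_s.numel() == fused_w.shape[0] * (k // block)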
"Weight shape is not divisible for block size for block quantiation." | ||
) |
Fix user-facing typo in assert message.
“quantiation” → “quantization”.
- "Weight shape is not divisible for block size for block quantiation."
+ "Weight shape is not divisible for block size for block quantization."
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
"Weight shape is not divisible for block size for block quantiation." | |
) | |
raise ValueError( | |
"Weight shape is not divisible for block size for block quantization." | |
) |
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py around lines 57 to
58, fix the typo in the assertion message: change "quantiation" to
"quantization" so the user-facing message reads "Weight shape is not divisible
for block size for block quantization." Update only the string literal to
correct the spelling.
    elif model_cls == FP8MLP:
        model = model_cls(num_features, num_features, bias=bias).to("cuda")
    else:
🛠️ Refactor suggestion
Avoid dtype casting FP8 weights during pattern-detection path
In _run_pattern_detection_job, FP8MLP currently falls into the generic else branch (Lines 210-220) which calls .to(dtype=torch.float16) and may upcast FakeFP8Linear weights, risking export/pattern failures. Mirror the special-case used in _run_job.
Apply:
@@
- if model_cls == GQA_Block:
+ if model_cls == GQA_Block:
...
- else:
+ elif model_cls == FP8MLP:
+ model = model_cls(num_features, num_features, bias=bias).to("cuda")
+ else:
model = model_cls(num_features, num_features, bias=bias).to(
device="cuda", dtype=torch.float16
)
Also applies to: 286-306
🤖 Prompt for AI Agents
In
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
around lines 117-119 (and similarly around lines 286-306), the FP8MLP case is
falling into the generic else that calls .to(dtype=torch.float16), which will
cast FakeFP8Linear weights and break pattern-detection/export; add a
special-case branch for model_cls == FP8MLP that mirrors the handling used in
_run_job by moving the model to CUDA without changing dtype (e.g., .to("cuda")
only) and ensure no .to(dtype=torch.float16) is applied to FP8MLP in both
locations so FakeFP8Linear weights remain untouched during pattern detection.
) -> torch.Tensor:
    """
    Reference (eager) implementation for multiple quant formats via `format_type`.
    For FP4:
IIUC, nvfp4 does dual scaling. How is that represented? Is that s_w2?
also, is this op handling mxfp4 as well?
nvfp4 does dual scaling. how is that represented? is that s_w2?
Yes, the two scales for nvfp4 are q_per_block_scale_w and s_w2.
This op does not handle MXFP4 because it has quite a different format than NVFP4: MXFP4 does not have the second scale, its block size is 32, and its scale factor format is E8M0 rather than FP8.
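A minimal sketch of the NVFP4 dual-scaling idea described in this thread (names follow the discussion; real kernels pack FP4 values and FP8 block scales, which is omitted here):
import torch

def nvfp4_dequant_sketch(w_vals: torch.Tensor,               # decoded E2M1 values, shape (N, K)
                         q_per_block_scale_w: torch.Tensor,  # one scale per 16-wide block
                         s_w2: torch.Tensor) -> torch.Tensor:  # per-tensor scale
    n, k = w_vals.shape
    blocks = w_vals.view(n, k // 16, 16)
    per_block = q_per_block_scale_w.view(n, k // 16, 1)
    # First-level (per-block) scale, then second-level (per-tensor) scale.
    return (blocks * per_block).view(n, k) * s_w2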
I'll rename this op to NVFP4
Thanks @Fridah-nv. Did you get a chance to test this on a sharded quantized model? Maybe llama 70B FP8 and/or llama4-fp8?
    stage: post_load_fusion
  fuse_fp4_gemms:
    stage: post_load_fusion
  fuse_fp8_gemms:
Do we need to specify fusion for different datatypes separately in the config? Maybe internally "fuse_gemms" can decide which fusions to perform? Alternatively, maybe consider a similar approach to sharding (sharding_dims: ['tp', 'ep', 'dp']), e.g.:
fuse_gemms:
stage: sharding
dtypes: ['fp16', 'fp8', 'fp4']
I see your point. Is there a need to keep these quantization special classes at the same abstract class level across different transforms?
For fusions, by default we should enable transforms for all dtypes because there can be mixed-dtype GEMMs in the graph. We cannot dispatch to one of the fuse_gemms implementations at the transform level. If we want to put the different dtype cases inside one Transform, it still needs to iterate through all types, and we need some abstract class to inherit from. In this case, it's simpler to add a QuantizationMixin class to the Transform class.
I think sharding's case is different because TP, DP, and EP sharding don't share much common code or utilities.
        return all(is_desired_child(u) for u in users)

class QuantizationFusionMixin:
Should this class be AbstractBaseClass?
@@ -99,6 +114,8 @@ def _run_job(
            hidden_size=num_features,
            num_key_value_heads=num_key_value_heads,
        ).to(device="cuda", dtype=torch.float16)
    elif model_cls == FP8MLP:
        model = model_cls(num_features, num_features, bias=bias).to("cuda")
Maybe we can drop dtype=torch.float16 entirely from the "non-quantized" models, and then we wouldn't need a separate elif for FP8MLP?
# (MLP, "torch_dist_all_reduce"), | ||
(FP8MLP, "torch_dist_all_reduce"), | ||
# (nn.Linear, "torch_dist_all_gather"), | ||
# (GQA_Block, "torch_dist_all_reduce"), |
Why are the remaining model cases commented out?
I was just testing, forgot to revert it...
Yes, I tested TP sharding for FP8 and NVFP4 e2e. I'm planning to test quantized MoE and BMM models for a sanity check.
This PR includes these major changes:
- quantize_linear_from_config now quantizes the linear nodes into torch fake quant ops like torch.ops.auto_deploy.torch_fake_quant_fp8_linear instead of torch_quant_fp8_linear (this is the op we previously mapped to; we can consider renaming these ops).
- Fusion passes now handle quantized ops with quantized fusion transforms (fuse_fp4_gemms, fuse_fp8_gemms); each quantization format inherits from QuantizationFusionMixin and implements fuse_rule for its weights and scalers. A sketch of this pattern follows the list.
- Similarly, sharding passes handle quantized ops with specialized sharding info (FP8TPShardingInfo, FP4TPShardingInfo) that contains the specific implementation for each quantization format. The appropriate type of TPShardingInfo is dispatched when we add a new object with TPShardingInfo.from_node.
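A rough sketch of the inheritance pattern described above (class and method names taken from this summary and the review comments; the actual contracts in the PR may differ):
from typing import Dict, List, Tuple
import torch

class QuantizationFusionMixin:
    """Shared GEMM-fusion plumbing; each quantization format supplies its rule."""

    scale_names: List[str] = []

    def fuse_rule(
        self, weights: List[torch.Tensor], **scales: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Return the fused weight plus the fused scale buffers."""
        raise NotImplementedError

class FuseFP8Gemms(QuantizationFusionMixin):
    scale_names = ["input_scale", "weight_scale"]

    def fuse_rule(self, weights, *, input_scale, weight_scale):
        # FP8 sketch: aggregate weight scales via max (mirroring the review above),
        # requantize each weight to the shared scale, and concatenate along dim 0.
        new_ws = torch.max(torch.stack(weight_scale))
        fused = torch.cat(
            [w.to(torch.float32) * s / new_ws for w, s in zip(weights, weight_scale)],
            dim=0,
        ).to(torch.float8_e4m3fn)
        # Assumes the grouped GEMMs share one activation scale (hypothetical).
        return fused, {"input_scale": input_scale[0], "weight_scale": new_ws}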
Description
Test Coverage
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user-friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id
(OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test
(OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
--disable-fail-fast
(OPTIONAL) : Disable fail fast on build/tests/infra failures.
--skip-test
(OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx"
(OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe"
(OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp"
(OPTIONAL) : Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test
(OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test
(OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test
(OPTIONAL) : Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
--post-merge
(OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"
(OPTIONAL) : Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log
(OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug
(OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
kill
kill
Kill all running builds associated with pull request.
skip
skip --comment COMMENT
Skip testing for latest commit on pull request.
--comment "Reason for skipping build/test"
is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.