
Conversation

@Fridah-nv Fridah-nv (Collaborator) commented Aug 25, 2025

This PR includes these major changes:

  1. The quantization transformation is separated into two stages: pattern matcher and post_load_fusion (see the sketch below).
  • In the pattern-matcher stage, quantize_linear_from_config now quantizes linear nodes into torch fake-quant ops such as torch.ops.auto_deploy.torch_fake_quant_fp8_linear.
  • In the post_load_fusion stage, the reference ops are mapped to the real quant implementations, e.g. torch_quant_fp8_linear (the op we previously mapped to directly; we may consider renaming these ops).
  • The fake-quant op stage further standardizes the graph for other transforms in the pipeline and lets the same quantization format (e.g. 'FP8') be represented uniformly across different sources (e.g. 'modelopt', 'compressed_tensor').
  2. Fusion passes now handle quantized ops via dedicated quantized fusion transforms (fuse_fp4_gemms, fuse_fp8_gemms); each quantization format inherits from QuantizationFusionMixin and implements fuse_rule for its weights and scales.

  3. Similarly, sharding passes handle quantized ops via specialized sharding info (FP8TPShardingInfo, FP4TPShardingInfo) that carries the implementation for each quantization format. The appropriate TPShardingInfo subclass is dispatched when a new object is added via TPShardingInfo.from_node.
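
To make the two-stage flow concrete, here is a minimal sketch (not the actual transform code) of the graph-level call before and after fuse_quant. The op names come from this PR; the exact argument packing of the reference op follows the positional-tail contract discussed in the review comments below and is otherwise an assumption.

import torch

def reference_fp8_linear(x, w_fp8, input_scale, weight_scale):
    # Stage 1 (pattern matcher): quantize_linear_from_config rewrites linear calls
    # into the reference fake-quant op with a positional tail of scale lists.
    return torch.ops.auto_deploy.torch_fake_quant_fp8_linear(
        x, w_fp8, None, [input_scale], [weight_scale], [], []
    )

def fused_fp8_linear(x, w_fp8, input_scale, weight_scale):
    # Stage 2 (post_load_fusion): fuse_quant pattern-matches the reference op above
    # and replaces it with the fused kernel entry point.
    return torch.ops.auto_deploy.torch_quant_fp8_linear(
        x, w_fp8, None, input_scale=input_scale, weight_scale=weight_scale
    )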

Description

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Signed-off-by: Fridah-nv <[email protected]>

minor

Signed-off-by: Frida Hou <[email protected]>
Signed-off-by: Frida Hou <[email protected]>

rename torch custom op

Signed-off-by: Frida Hou <[email protected]>
Signed-off-by: Frida Hou <[email protected]>

Update fusion to pass in args instead of kwargs

Signed-off-by: Frida Hou <[email protected]>

coderabbitai bot (Contributor) commented Aug 25, 2025

📝 Walkthrough

Splits config-driven quantization into separate linear and BMM transforms; adds FP8 and NVFP4 fake-quant custom ops and pattern-matcher fusion; introduces a quantization-aware fusion framework and sharding mixins (TPShardingInfo.from_node); adjusts fake-mode detection, quantization contracts, configs, and tests.

Changes

Cohort / File(s) Summary
Config updates
tensorrt_llm/_torch/auto_deploy/config/default.yaml
Renames quantize_from_config → quantize_linear_from_config; adds quantize_bmm_from_config; enables fuse_gemms; adds fuse_fp4_gemms, fuse_fp8_gemms, and fuse_quant.
Custom ops package
tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
Re-exports torch_quant via from .torch_quant import *.
Custom quant ops (FP8/NVFP4)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
New module implementing FP8 and NVFP4 fake-quant helpers, constants, helpers for scale/packing, and two custom ops (torch_fake_quant_fp8_linear, torch_fake_quant_fp4_linear) with eager/register_fake handlers; exports ops into quant op lists.
Pattern fusion transform
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
New FuseQuant transform registering AD pattern-matcher rewrites from reference fake-quant linear ops (FP8/FP4) to fused quant ops (bias/no-bias variants).
Fusion framework & GEMM fuses
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
Introduces QuantizationFusionMixin, check_same_children; adds FuseFP8Gemms, FuseFP4Gemms; reworks FuseGemms to avoid fusing quantized ops and to use the new fusion API.
Quantization transforms split
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
Splits QuantizationFromConfig into LinearQuantizationFromConfig (quantize_linear_from_config) and BMMQuantizationFromConfig (quantize_bmm_from_config); linear path inserts custom quant ops; BMM insertion returns success bool.
Sharding transform adjustments
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
Replaces TPShardingInfo constructor usage with TPShardingInfo.from_node(...) call sites.
Pattern matcher: fake-mode detection
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
Adds ad_detect_fake_mode and binds it to torch._dynamo.utils.detect_fake_mode.
Quantization utilities
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
Adds build_custom_args_for_linear contract on QuantizationImpl; maps NVFP4 to FP4 impl (non-BMM); updates FP8/FP4 impls to return fake-quant target ops and custom-args; removes legacy shard/fuse helpers.
Sharding utilities & quant mixins
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
Adds QuantizationShardingMixin, FP8TPShardingInfo, FP4TPShardingInfo, _shard_fp4_weight_scale, TP_SHARDING_RULES, and TPShardingInfo.from_node; _insert_sharded_matmul gains quantization_cb hook.
Tests: custom ops parity
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
Renamed scaling constant to SCALING_VECTOR_SIZE; adds tests comparing fused vs unified FP8 and NVFP4 linear ops.
Tests: quantization transform keys
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
Update config keys to quantize_linear_from_config and quantize_bmm_from_config.
Tests: GEMM fusion models
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
Replace FP8 test Linear with FakeFP8Linear and add fuse_fp8_gemms config entry.
Tests: import relocation
tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py
Move _shard_fp4_weight_scale import to sharding_utils.
Tests: helper utilities
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
Add FP8_MAX and FakeFP8Linear test helper (duplicate definitions present).
Tests: quant-fusion unit tests
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
New tests and tiny reference modules (TinyFP8Ref, TinyFP4Ref) verifying fusion rewrites and numerical parity.
Tests: TP sharding
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
Add FP8MLP test model, integrate FP8 path into sharding tests, adjust world_size parameterization.
Tests: debug prints
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py
Added debug prints to dump detected/expected sets in sharding pattern detection tests.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Config as default.yaml
  participant Optim as InferenceOptimizer
  participant PM as PatternMatcher
  participant QT as QuantizeTransforms
  participant FU as FusionTransforms
  participant FQ as FuseQuant

  Config->>Optim: load transforms (pattern_matcher, post_load_fusion)
  Optim->>PM: run quantize_linear_from_config
  PM->>QT: insert custom FP8/FP4 linear ops (custom_op tails)
  Optim->>PM: run quantize_bmm_from_config
  PM->>QT: insert BMM quant (if supported)
  Optim->>FU: post_load_fusion (fuse_gemms,fuse_fp8_gemms,fuse_fp4_gemms)
  FU->>FU: group by shared activation and insert fused quant GEMMs
  Optim->>FQ: post_load_fusion fuse_quant
  FQ->>FQ: pattern-match reference fake-quant -> fused quant ops
sequenceDiagram
  autonumber
  participant GM as GraphModule
  participant Mix as QuantizationFusionMixin
  participant Impl as FuseFP8/FuseFP4
  participant Attr as get_attr buffers

  GM->>Mix: locate linear nodes with same parent
  Mix->>Mix: check_same_children(all target_op?)
  alt all match
    Mix->>Impl: fuse_rule(weights, scales)
    Impl-->>Mix: fused_weight + fused_buffers
    Mix->>GM: register fused param and buffers
    GM->>Attr: create get_attr nodes for scales
    Mix->>GM: insert fused quant op with tail args
  else mismatch
    Mix-->>GM: skip group
  end
sequenceDiagram
  autonumber
  participant TP as TPShardingInfo.from_node
  participant Core as _insert_sharded_matmul
  participant QMixin as QuantizationShardingMixin.quantization_cb

  TP->>Core: apply(node) with quantization_cb
  Core->>QMixin: delegate quantization handling
  QMixin->>Core: gather buffers, compute sharded scales, register, add load hook
  Core-->>TP: sharded matmul wired with quantization-aware buffers

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

AutoDeploy

Suggested reviewers

  • suyoggupta
  • yuxianq
  • nv-guomingz


@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (6)
tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (1)

1-1: Add NVIDIA Apache-2.0 copyright header (required by repo guidelines).

All Python sources must carry the NVIDIA Apache-2.0 header. Please prepend it here.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (1)

1-1: Add NVIDIA Apache-2.0 header (tests are Python files too).

Please add the standard header at the top.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (1)

265-283: Recompile GraphModule after pattern rewrites.

ADPatternMatcherPass.apply mutates gm.graph, but the Python code of the GraphModule isn’t regenerated automatically. Recompiling avoids stale code and ensures subsequent passes (and any eager execution) see the updated graph.

-        num_matches = patterns.apply(gm.graph)
+        num_matches = patterns.apply(gm.graph)
+        if num_matches:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (3)

1-1: Add NVIDIA Apache-2.0 header.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

53-55: Fix potential IndexError when reading node.users.

user = list(node.users.keys())[0] will throw if there are zero users. Guard the access.

-        user = list(node.users.keys())[0]
-        if len(node.users) == 1 and is_quantized_op(user):
-            user.replace_all_uses_with(node)
+        if len(node.users) == 1:
+            sole_user = next(iter(node.users))
+            if is_quantized_op(sole_user):
+                sole_user.replace_all_uses_with(node)

174-211: Recompile and set skipped flag accurately in LinearQuantizationFromConfig.

  • Set skipped based on num_matches == 0.
  • Recompile after mutations to materialize changes in GraphModule code.
-        info = TransformInfo(
-            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
-        )
-        return gm, info
+        if num_matches:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
+        return gm, TransformInfo(
+            skipped=(num_matches == 0),
+            num_matches=num_matches,
+            is_clean=False,
+            has_valid_shapes=True,
+        )
🧹 Nitpick comments (21)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)

252-254: Make linear-op detection robust when custom ops are unavailable

Directly accessing torch.ops.auto_deploy.torch_fake_quant_fp{8,4}_linear can raise AttributeError on builds where these ops aren’t registered. Guard the additions so import-time doesn’t fail on older runtimes.

Apply this diff:

     if include_quantization:
         lin_ops.update(QUANT_LINEAR_OPS)
-        lin_ops.update([torch.ops.auto_deploy.torch_fake_quant_fp8_linear])
-        lin_ops.update([torch.ops.auto_deploy.torch_fake_quant_fp4_linear])
+        for _name in ("torch_fake_quant_fp8_linear", "torch_fake_quant_fp4_linear"):
+            _op = getattr(torch.ops.auto_deploy, _name, None)
+            if _op is not None:
+                lin_ops.add(_op)

1-1: Missing NVIDIA Apache-2.0 header

Per coding guidelines, prepend the NVIDIA Apache-2.0 copyright header (year 2025) to this file.

Do you want me to apply a repo-standard header template across all changed Python files?

tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (3)

37-73: Fake-mode detection: good coverage; minor resilience nits

The priority (TracingContext → active dispatch stack → inputs) looks correct and fixes fragility in upstream detection. Two small nits:

  • If inputs is None, pytree.tree_leaves(None) returns [None]; harmless but you can avoid a tiny branch cost.
  • If multiple FakeTensorModes are stacked, you’re already iterating the reversed stack (most recent first) — nice.

Optional micro-tidy:

-    flat_inputs = pytree.tree_leaves(inputs)
+    flat_inputs = () if inputs is None else pytree.tree_leaves(inputs)

75-77: Monkeypatch scope: limit to the pattern-matcher pass

Replacing torch._dynamo.utils.detect_fake_mode globally can have side effects in unrelated compilation paths. Patch only during ADPatternMatcherPass.apply() and restore afterward.

Apply this diff:

-# Replace the function used as a context manager
-torch._dynamo.utils.detect_fake_mode = ad_detect_fake_mode
+_ORIG_DETECT_FAKE_MODE = getattr(torch._dynamo.utils, "detect_fake_mode", None)

And update the pass:

 class ADPatternMatcherPass(PatternMatcherPass):
@@
     def apply(self, graph: Union[torch.fx.Graph, GraphModule]) -> int:
         """Apply pattern matcher with unsupported_input_tensor patch to bypass meta tensor check."""
-        with _patch_unsupported_input_tensor():
-            return super().apply(graph)
+        with _patch_unsupported_input_tensor():
+            # temporarily patch fake-mode detection
+            orig = getattr(torch._dynamo.utils, "detect_fake_mode", None)
+            torch._dynamo.utils.detect_fake_mode = ad_detect_fake_mode
+            try:
+                return super().apply(graph)
+            finally:
+                if orig is not None:
+                    torch._dynamo.utils.detect_fake_mode = orig

1-1: Missing NVIDIA Apache-2.0 header

Please prepend the NVIDIA Apache-2.0 header (2025) at the top of this file.

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)

1-1: Missing NVIDIA Apache-2.0 header

Add the standard NVIDIA Apache-2.0 header (2025).

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4)

112-114: Set sharded params to requires_grad=False: ensure consistency with BMM path

Here we set requires_grad=False for linear shards (good for inference). In BMMShardingInfo.apply, sharded parameters are created with requires_grad=True. For consistency across inference transforms, prefer disabling grads there too.

If you agree, I can provide a follow-up diff in the BMM path.
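
A minimal illustration of the kind of change (the actual BMMShardingInfo code isn't shown in this diff, so the helper below is hypothetical):

import torch
import torch.nn as nn

def make_inference_shard(sharded_tensor: torch.Tensor) -> nn.Parameter:
    # Hypothetical helper: wrap a sharded tensor without gradient tracking,
    # mirroring the requires_grad=False convention used on the linear path above.
    return nn.Parameter(sharded_tensor.detach(), requires_grad=False)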


391-444: FP8/FP4 TP sharding classes: interface looks sound

  • FP8: pass-through scales and no-op load hook make sense.
  • FP4: sharding both weight_scale and handling the *_scale state_dict suffix in the pre-hook is correct.

Two small nits:

  • Consider documenting required buffers explicitly in class docstrings (e.g., FP4 expects input_scale, weight_scale, alpha on the owning submodule).
  • Minor naming: sharded_uint8_weight_shape could just be weight_shape for consistency (docstring already clarifies it’s the sharded packed shape).

446-460: Resolver fallthrough handling

The except Exception: pass in _resolve_tp_cls_from_node hides real issues (e.g., NameError on missing ops). Log at debug at least, so we can diagnose resolver paths.

Minimal change:

-        except Exception:
-            pass
+        except Exception as e:  # pragma: no cover
+            ad_logger.debug(f"TP resolver predicate failed: {e}")

1-1: Missing NVIDIA Apache-2.0 header

Please add the required header (2025).

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (1)

1-1: Missing NVIDIA Apache-2.0 header in test file

Tests are also subject to the header requirement per guidelines.

I can sweep-add the header to all modified tests if desired.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

94-95: Order of post-load fusions: consider moving fuse_quant earlier.

fuse_quant rewrites fake-quant reference ops into fused kernels. Running it earlier in post-load-fusion can:

  • Increase the chance for downstream kernel-level fusions to act on the fused ops,
  • Reduce pattern mismatch after other structural fusions.

Recommendation: place fuse_quant ahead of GEMM/RMSNorm fusions unless there is a known dependency the other way around.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (2)

12-16: Remove duplicate/unused constants; use a single SCALING_VECTOR_SIZE.

  • FORMAT_FP8 and FORMAT_NVFP4 are unused here.
  • Both scaling_vector_size and SCALING_VECTOR_SIZE represent the same value (16). Keep one for consistency.
- scaling_vector_size = 16
-FORMAT_FP8 = 0
-FORMAT_NVFP4 = 1
-
-SCALING_VECTOR_SIZE = 16  # NVFP4 block size along K
+SCALING_VECTOR_SIZE = 16  # NVFP4 block size along K

And replace uses of scaling_vector_size below:

-    weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize(
-        weight, weight_scale_2, scaling_vector_size, False
-    )
+    weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize(
+        weight, weight_scale_2, SCALING_VECTOR_SIZE, False
+    )

156-201: LGTM with a minor suggestion for negative-case coverage.

This validates CUTLASS-scale + alpha wiring across dtypes with appropriate tolerances. Consider adding a negative test asserting failure when K is not a multiple of SCALING_VECTOR_SIZE to lock the contract.
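
A sketch of such a negative test, assuming the assertion on K fires before other shape checks and using the positional-tail signature discussed elsewhere in this review (input, weight, bias, input_scale list, weight_scale list, zero-point lists); the argument packing here is an assumption, not the test as it would land:

import pytest
import torch

SCALING_VECTOR_SIZE = 16  # NVFP4 block size along K

def test_fp4_linear_rejects_misaligned_k():
    # K = 24 is deliberately not a multiple of SCALING_VECTOR_SIZE.
    x = torch.randn(2, 24, dtype=torch.half, device="cuda")
    w = torch.randn(8, 24, dtype=torch.half, device="cuda")
    one = torch.tensor(1.0, device="cuda")
    with pytest.raises(AssertionError):
        torch.ops.auto_deploy.torch_fake_quant_fp4_linear(
            x, w, None, [one], [one], [], []
        )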

tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (1)

30-46: Be explicit with op overloads for consistency.

Elsewhere you used .default for FP8 pattern calls; mirror that for replacements to avoid ambiguity.

-return torch.ops.auto_deploy.torch_quant_fp8_linear(
+return torch.ops.auto_deploy.torch_quant_fp8_linear.default(
     x, w_fp8, None, input_scale=input_scale, weight_scale=weight_scale)
...
-return torch.ops.auto_deploy.torch_quant_fp8_linear(
+return torch.ops.auto_deploy.torch_quant_fp8_linear.default(
     x, w_fp8, bias, input_scale=input_scale, weight_scale=weight_scale)

And similarly for FP4:

-return torch.ops.auto_deploy.torch_quant_fp4_linear(
+return torch.ops.auto_deploy.torch_quant_fp4_linear.default(
     x, w_fp4, bias=None, input_scale=input_scale, weight_scale=weight_scale, alpha=alpha)
...
-return torch.ops.auto_deploy.torch_quant_fp4_linear(
+return torch.ops.auto_deploy.torch_quant_fp4_linear.default(
     x, w_fp4, bias=bias, input_scale=input_scale, weight_scale=weight_scale, alpha=alpha)

Also applies to: 67-84

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

86-171: Consider deferring graph recompilation to the transform level (avoid per-node overhead).

_insert_quantized_bmm and _insert_quantized_linear appropriately avoid recompiling per node. The transform-level recompile suggested above keeps performance in check while ensuring correctness. No changes needed here beyond the transform finalize steps.

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)

30-33: Consider consolidating FORMAT_* constants with existing imports

The TODO comment indicates these format constants should be imported from a common location. Since torch_quant.py also defines the same constants (FORMAT_FP8 = 0, FORMAT_NVFP4 = 1), consider importing them from there to maintain a single source of truth.

-# TODO: put the ENUMs in the same place and import it
-FORMAT_FP8 = 0
-FORMAT_NVFP4 = 1
+from ..custom_ops.torch_quant import FORMAT_FP8, FORMAT_NVFP4

143-154: Consider adding type hints for better clarity

The new methods would benefit from more specific type hints for the object return type to improve type safety and IDE support.

-def build_custom_kwargs_for_linear(
-    scale_getattrs: Dict[str, Node],
-) -> Dict[str, object]:
+def build_custom_kwargs_for_linear(
+    scale_getattrs: Dict[str, Node],
+) -> Dict[str, Union[List[Node], List]]:

194-210: Inconsistent comment: torch_fake_quant_fp8_linear vs FP8

The docstring mentions "torch_fake_quant_fp8_linear" but this is the FP8QuantizationImpl class. Also, the example pattern in the comment doesn't match the actual implementation which uses Node objects, not raw arguments.

    def build_custom_args_for_linear(  # renamed to reflect args
        scale_getattrs: Dict[str, Node],
    ) -> Tuple[object, ...]:
        """
-        Build the *positional* tail for torch_fake_quant_fp8_linear:
+        Build the *positional* tail for FP8 quantized linear:
            (..., bias, input_scale(list), weight_scale(list), input_zp(list), weight_zp(list))
-
-        We pass bias=None to match the exported pattern:
-        torch_fake_quant_fp8_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
        """
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1)

255-255: Consider using modulo check with error message

The assertion could provide more context about the actual value of K when it fails.

-assert K % 16 == 0, "NVFP4 requires K to be a multiple of 16"
+assert K % 16 == 0, f"NVFP4 requires K to be a multiple of 16, got K={K}"
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (1)

214-216: Ensure consistent node filtering

The condition checks for partial(is_op, ops=self.target_op) but the initial collection uses is_op(node, self.target_op). While functionally equivalent, consider using the same pattern for clarity.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between 788fc62 and d14d5f4.

📒 Files selected for processing (13)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (7 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (6 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent with 4 spaces; do not use tabs
Preserve module namespace when importing: from package.subpackage import foo; then use foo.SomeClass()
Python filenames use snake_case (e.g., some_file.py)
Class names use PascalCase
Function and method names use snake_case
Local variables use snake_case; prefix k for names starting with a number (e.g., k_99th_percentile)
Global variables are UPPER_SNAKE_CASE prefixed with G (e.g., G_MY_GLOBAL)
Constants are UPPER_SNAKE_CASE
Avoid shadowing variables from an outer scope
Initialize all externally visible members of a class in init
For interfaces used outside a file, prefer docstrings over comments; comments for internal code or local interfaces
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Attributes and variables can be documented inline with trailing docstrings under the class or module
Avoid using reflection when easily avoidable; prefer explicit parameters/constructs over dict(**locals())
In try/except, catch the narrowest exception types possible
For duck-typing try/except, keep try body minimal and place logic in else after attribute existence checks

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
**/*.{h,hpp,hxx,hh,c,cc,cpp,cxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA Apache-2.0 copyright header with current year to all source files

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
🧬 Code graph analysis (9)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (166-198)
  • torch_fake_quant_fp8_linear (202-212)
  • torch_fake_quant_fp4_linear (216-274)
  • torch_fake_quant_fp4_linear (278-287)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (3)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (3)
  • ADPatternMatcherPass (104-110)
  • register_ad_pattern (142-225)
  • apply (107-110)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (3)
  • BaseTransform (139-376)
  • SharedConfig (51-57)
  • TransformRegistry (379-407)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (3)
tests/unittest/_torch/auto_deploy/_utils_test/_torch_test_utils.py (3)
  • fp8_compatible (29-30)
  • fp4_compatible (33-34)
  • trtllm_ops_available (37-38)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (166-198)
  • torch_fake_quant_fp8_linear (202-212)
  • torch_fake_quant_fp4_linear (216-274)
  • torch_fake_quant_fp4_linear (278-287)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
  • fp4_global_scale (62-64)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
  • TPShardingInfo (221-264)
  • from_node (229-234)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
  • cutlass_fp4_scale_to_modelopt_fp4_scale (47-59)
  • custom_op (162-164)
  • custom_op (236-238)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (5)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (4)
  • is_op (183-206)
  • extract_param_names_from_lin_node (149-170)
  • get_op_overload_packet (173-180)
  • is_linear_op (240-254)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
  • build_custom_args_for_linear (150-153)
  • build_custom_args_for_linear (195-210)
  • build_custom_args_for_linear (291-306)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (2)
  • register (385-392)
  • BaseTransform (139-376)
tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py (1)
  • cuda_memory_tracker (10-26)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (4)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (8)
  • build_custom_args_for_linear (150-153)
  • build_custom_args_for_linear (195-210)
  • build_custom_args_for_linear (291-306)
  • custom_op (162-164)
  • custom_op (236-238)
  • QuantizationImpl (72-153)
  • create (76-105)
  • should_skip_quantization (471-484)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (4)
  • TransformRegistry (379-407)
  • register (385-392)
  • BaseTransform (139-376)
  • get (395-397)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)
  • is_linear_op (240-254)
  • is_bmm_op (257-264)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (15-207)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
  • build_custom_args_for_linear (128-130)
  • build_custom_args_for_linear (294-301)
  • build_custom_args_for_linear (340-347)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (166-198)
  • torch_fake_quant_fp8_linear (202-212)
  • torch_fake_quant_fp4_linear (216-274)
  • torch_fake_quant_fp4_linear (278-287)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7)
  • QuantizationImpl (72-153)
  • cutlass_fp4_scale_to_modelopt_fp4_scale (47-59)
  • modelopt_fp4_scale_to_cutlass_fp4_scale (35-44)
  • scale_names (118-120)
  • scale_names (173-174)
  • scale_names (246-247)
  • scale_names (425-426)
tensorrt_llm/_torch/modules/linear.py (1)
  • split_dim (48-49)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (21)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (1)

18-18: PyTorch version guard for TracingContext import

from torch._guards import TracingContext is internal and version-sensitive. If CI still runs on older 2.x, this import will fail at import-time.

Consider guarding the import:

-from torch._guards import TracingContext
+try:
+    from torch._guards import TracingContext
+except Exception:  # pragma: no cover
+    TracingContext = None  # Fallback; ad_detect_fake_mode should branch if None

And in ad_detect_fake_mode, check if TracingContext and (context := TracingContext.try_get()): ...

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2)

106-114: Good migration to node-aware TPShardingInfo factory

Switching to TPShardingInfo.from_node(n, …) enables quantization-aware subclassing. This looks correct and makes future extensions easier.


308-316: Quantization-aware two-way sharding: ensure fake-quant linears resolve to the right subclass

This path relies on TPShardingInfo.from_node(n, …) correctly detecting FP8/FP4 nodes. Currently, the resolver in sharding_utils.py only checks fused ops (torch_quant_fp{8,4}_linear). If the graph contains fake-quant ops (torch_fake_quant_fp{8,4}_linear), it will fall back to the base TP class and skip scale sharding.

I’ve proposed a fix in sharding_utils.py to include fake-quant ops in TP_SHARDING_RULES. See that comment for a concrete diff.
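
Assuming TP_SHARDING_RULES is an ordered list of (predicate, TPShardingInfo subclass) pairs, as the resolver description suggests, the addition could look roughly like the following sketch (guarded op lookups per the node_utils suggestion above; the rule format itself is an assumption):

import torch

from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
from tensorrt_llm._torch.auto_deploy.utils.sharding_utils import (
    FP4TPShardingInfo,
    FP8TPShardingInfo,
)

def _is_any_auto_deploy_op(node, op_names):
    # Guarded lookup: skip ops that aren't registered in this build.
    targets = [getattr(torch.ops.auto_deploy, name, None) for name in op_names]
    return any(t is not None and is_op(node, t) for t in targets)

# Sketch: match both the fused and the fake-quant linears so TPShardingInfo.from_node
# resolves to the quantization-aware subclasses in either graph state.
TP_SHARDING_RULES = [
    (lambda n: _is_any_auto_deploy_op(
        n, ("torch_quant_fp8_linear", "torch_fake_quant_fp8_linear")), FP8TPShardingInfo),
    (lambda n: _is_any_auto_deploy_op(
        n, ("torch_quant_fp4_linear", "torch_fake_quant_fp4_linear")), FP4TPShardingInfo),
]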

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)

65-68: Extension point for quantization-aware sharding is well designed

Adding quantization_cb to _insert_sharded_matmul is a clean way to decouple quantization specifics from the core sharding flow.


148-158: Quantization callback invocation: good integration point

The call to quantization_cb after weight/bias sharding ensures scales and load hooks are set while shapes are known. This ordering is correct.


378-389: _shard_fp4_weight_scale: shape reconstruction math—request a quick validation

The reconstruction of the original weight shape assumes:

  • weight_shape_original[dim] *= world_size
  • last dim doubled (* 2) for unpacking FP4

If the shard is along columns (dim=1), doubling the last dimension aligns with your packed uint8 scheme. Please verify this with a shape example in tests (e.g., N×K_packed with K_packed=K/2) to avoid off-by-one on block edges.

I can add a focused unit test exercising dim=0 vs dim=1 with uneven multiples of 128/16 to ensure correct slicing of scales.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (2)

76-79: Config key rename LGTM

Renaming to quantize_linear_from_config aligns with the split linear/BMM transforms. Looks good.


158-161: BMM quantization config key rename LGTM

quantize_bmm_from_config key matches the new transform; no issues spotted.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (2)

80-85: Re-enabling GEMM fusions can increase memory pressure; consider guardrails.

fuse_gemms, fuse_fp4_gemms, and fuse_fp8_gemms were previously disabled due to OOM risk. If we re-enable them by default, consider:

  • A config toggle to disable at runtime,
  • Conditioning on batch/hidden sizes,
  • Or staging behind a perf profile flag.

If recent CI runs haven’t covered large models, please schedule one to confirm memory headroom with these fusions enabled.


48-51: Transforms split verified – registry and config updated

  • Registered in quantization.py:
    @TransformRegistry.register("quantize_linear_from_config") at line 174
    @TransformRegistry.register("quantize_bmm_from_config") at line 213
  • Config example in config/default.yaml:
    quantize_linear_from_config at line 48
    quantize_bmm_from_config at line 50
    • No remaining references to the old quantize_from_config key
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (1)

114-143: LGTM: FP8 fused vs. unified parity test is tight and representative.

Good coverage with/without bias, explicit scale wiring ([in_s], [w_s]), and strict tolerances.

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

80-84: LGTM: switch to unified custom op tail args is correct.

Appending positional tail produced by build_custom_args_for_linear aligns with the new fake-quant op contracts.

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)

103-104: Good improvement: Using custom_op() for unified quantization detection

The change from target_op() to custom_op() standardizes the node matching logic for quantized operations, providing a cleaner separation between target operations and custom kernel entry points.

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2)

165-198: Well-structured FP8 quantization implementation

The reference implementation is clear and well-documented, with proper error handling for dtype mismatches and missing scales.


215-274: Complex but well-documented FP4 quantization logic

The NVFP4 implementation correctly handles the complex per-block quantization scheme with proper scale vector handling and shape transformations. The comments effectively explain the multi-step process.

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (6)

86-97: Clean utility function for verifying homogeneous children

The check_same_children function provides a robust way to ensure all child nodes match the expected type before fusion, preventing mixed-precision issues.


99-131: Excellent abstraction with the QuantizationFusionMixin

The mixin pattern effectively captures the common fusion logic while allowing subclasses to customize their specific quantization formats. The clear documentation of required attributes and methods makes this easy to extend.
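
For readers new to the pattern, a rough, illustrative shape of the mixin contract (not the actual implementation in fusion.py; names with a Sketch suffix are hypothetical):

from abc import abstractmethod
from typing import Dict, List, Tuple

import torch

class QuantizationFusionMixinSketch:
    """Illustrative only: the real QuantizationFusionMixin lives in fusion.py."""

    target_op = None  # subclasses set the quantized linear op they group and fuse

    @abstractmethod
    def fuse_rule(
        self, weights: List[torch.Tensor], scales: List[Dict[str, torch.Tensor]]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Combine per-linear weights/scales into one fused weight plus fused buffers."""

class FuseFP8GemmsSketch(QuantizationFusionMixinSketch):
    # Assumed: fusion runs before fuse_quant, so it matches the reference fake-quant op.
    target_op = torch.ops.auto_deploy.torch_fake_quant_fp8_linear

    def fuse_rule(self, weights, scales):
        # FP8 example: re-quantize every weight to the shared (max) scale and concatenate,
        # mirroring the "find max scale and re-quantize" behavior noted in this review.
        w_scales = [s["weight_scale"] for s in scales]
        fused_scale = torch.stack(w_scales).max()
        requantized = [
            (w.to(torch.float32) * s / fused_scale).to(torch.float8_e4m3fn)
            for w, s in zip(weights, w_scales)
        ]
        return torch.cat(requantized, dim=0), {"weight_scale": fused_scale}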


153-157: Good error handling with informative logging

The try-except block properly catches NotImplementedError and logs a warning with context about which operations couldn't be fused.


282-283: Add dtype validation for weight fusion

Good assertion to ensure only FP8 quantized weights are being fused. This prevents unexpected behavior with mixed precision.


266-311: Well-implemented FP8 fusion with proper scaling

The FP8 fusion correctly handles weight scale recalculation by finding the maximum scale and re-quantizing accordingly. The implementation properly preserves numerical accuracy while enabling fusion.


313-356: FP4 fusion correctly handles per-block scales

The FP4 implementation properly concatenates the per-block scale vectors along with the weights, maintaining the quantization structure needed for the fused operation.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (4)

1-1: Add NVIDIA SPDX+Apache-2.0 header (license compliance).

This source file lacks the required header.

Apply this diff at the very top:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

475-489: Replace built-in generics list[str] with typing.List[str] for Python 3.8 support.

The codebase targets Python 3.8+, which doesn’t support PEP 585 generics on built-ins.

Apply this diff:

-def should_skip_quantization(
-    node_or_name: Union[Node, str],
-    excluded_patterns: list[str],
-) -> bool:
+def should_skip_quantization(
+    node_or_name: Union[Node, str],
+    excluded_patterns: List[str],
+) -> bool:

491-518: Fix type annotations for Python 3.8; also widen return type to Optional.

Use List[str] instead of list[str]; get_scales_and_type_from_node can return None for scales.

Apply this diff:

-def extract_scales_from_node(node: Node, scale_names: list[str]) -> Dict[str, Optional[Node]]:
+def extract_scales_from_node(node: Node, scale_names: List[str]) -> Dict[str, Optional[Node]]:
@@
-def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
+def get_scales_and_type_from_node(node: Node) -> Tuple[Optional[Dict[str, Node]], str]:

312-359: Ensure consistent input_scale and alpha semantics across both load paths

ModelOpt’s branch currently computes

alpha = 1 / (s_w2 * s_in)

and leaves input_scale as s_in, whereas the HF branch uses

alpha = s_w2 * s_in  
input_scale = 1 / s_in

The custom op expects input_scale to always represent 1/x and alpha to be s_in * s_w2. Please update the ModelOpt path accordingly:

• File: tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
Static method load_hook, around lines 320–335

Proposed diff:

     # ModelOpt quantized graph path
     if weight.dtype != torch.uint8:
         …
-        state_dict[alpha_name] = 1 / (weight_scale_2 * state_dict[input_scale_name])
+        state_dict[alpha_name] = weight_scale_2 * state_dict[input_scale_name]
+        state_dict[input_scale_name] = 1 / state_dict[input_scale_name]

If there are consumers relying on the old convention, consider gating this change under a version/compatibility flag. Let me know if you’d like assistance wiring that up.

♻️ Duplicate comments (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2)

1-1: Add NVIDIA SPDX+Apache-2.0 header (license compliance).

All Python source files must start with the NVIDIA copyright header. Please add it before any imports.

Apply this diff at the top of the file:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

52-54: Use typing.Optional instead of PEP 604 unions for Python 3.8 compatibility.

The “A | B” syntax requires Python 3.10+. Project guideline targets Python 3.8+.

Apply this diff:

 def _nvfp4_get_weights_scaling_factor(
     input: torch.Tensor,
     block_size: int,
-    weights_scaling_factor_2: torch.Tensor | None = None,
+    weights_scaling_factor_2: Optional[torch.Tensor] = None,
     keep_high_precision: bool = False,
 ):
 ...
 def _quantize_nvfp4(
     input: torch.Tensor,
     block_size: int,
-    weights_scaling_factor_2: torch.Tensor | None = None,
+    weights_scaling_factor_2: Optional[torch.Tensor] = None,
 ):

Also applies to: 107-108

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)

294-304: Docstring mentions fp8 operator name in FP4 section; correct to fp4.

Copy/paste error; use torch_fake_quant_fp4_linear in the docstring and example.

Apply this diff:

-        Build the *positional* tail for torch_fake_quant_fp8_linear:
+        Build the *positional* tail for torch_fake_quant_fp4_linear:
@@
-        torch_fake_quant_fp8_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
+        torch_fake_quant_fp4_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
🧹 Nitpick comments (7)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)

39-46: Remove unused parameter from _dequant_weight_fp8 and its call site.

out_features is not used; simplify signature and call.

Apply this diff:

 def _dequant_weight_fp8(
     weight_fp8: torch.Tensor,
     weight_scale: torch.Tensor,
-    out_features: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
     return weight_fp8.to(dtype) * weight_scale
-    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, out_features, in_dtype)
+    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, in_dtype)

Also applies to: 195-195


109-117: Docstring return mismatch (function returns 2 values, docstring says 3).

Update the docstring to reflect the actual 2-tuple return: (packed_weight, q_per_block_scale).

Apply this diff:

-    Returns:
-    tuple: Contains quantized data, quantized per block scaling factor, and per block scaling factor.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: (packed_weight, q_per_block_scale)

88-101: Clarify tie-breaking mask logic for FP4 rounding.

Multiplying a boolean equality by a uint8 mask works by accident. Prefer explicit boolean logic for readability and safety.

Apply this diff for clarity:

-    # Define mask to perform rounding
-    mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8).to(device)
+    # Define boolean tie mask: True on odd indices to round up ties
+    mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8, device=device).bool()
 ...
-    round = torch.any((weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) * mask, dim=-1)
+    round = torch.any((weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) & mask, dim=-1)

217-276: NVFP4 scale semantics: input_scale/alpha naming vs. usage is inconsistent with helper contract.

The code treats input_scale as the inverse scale (inv_x) and expects alpha = s_in2 * s_w2. Ensure upstream load_hooks and defaults produce these exact semantics, or normalize inside this op to avoid subtle bugs across weight-load paths.

Would you like me to provide a normalization shim at the start of this op that accepts either representation and converts to (inv_x, alpha = s_in2*s_w2) consistently?
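
A minimal sketch of such a shim, assuming the two conventions are exactly the ones described above (expected: input_scale = 1/s_in, alpha = s_in * s_w2; ModelOpt path: input_scale = s_in, alpha = 1/(s_w2 * s_in)); the boolean flag is hypothetical, since detecting which convention was loaded is the open question:

from typing import Tuple

import torch

def normalize_fp4_scales(
    input_scale: torch.Tensor,
    alpha: torch.Tensor,
    from_modelopt_convention: bool,  # hypothetical flag; source detection is out of scope here
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (input_scale, alpha) in the convention the op expects: (1/s_in, s_in * s_w2)."""
    if not from_modelopt_convention:
        return input_scale, alpha
    # ModelOpt path (per the comment above) stores input_scale = s_in and
    # alpha = 1 / (s_w2 * s_in); inverting both recovers the expected pair.
    return 1.0 / input_scale, 1.0 / alpha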

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)

30-33: Avoid duplicating FORMAT_* enums; centralize and import.

These enums also exist in custom_ops.torch_quant. Duplicates risk divergence.

Consider moving FORMAT_FP8/NVFP4 to a single shared module (e.g., tensorrt_llm/_torch/auto_deploy/common/quant_formats.py) and import from there.
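
A minimal version of such a shared module (the path is the one suggested above and is hypothetical until created), with the two existing modules importing from it:

# tensorrt_llm/_torch/auto_deploy/common/quant_formats.py (hypothetical new module)
FORMAT_FP8 = 0
FORMAT_NVFP4 = 1

# In custom_ops/torch_quant.py and utils/quantization_utils.py the duplicated constants
# would then be replaced by a single relative import:
# from ..common.quant_formats import FORMAT_FP8, FORMAT_NVFP4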

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (1)

39-56: Avoid tracking gradients when reassigning quantized weights/bias in tests.

Wrap parameter/buffer reassignments in torch.no_grad() to be explicit and avoid autograd hooks in some envs.

Apply this diff:

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        device = self.weight.device
-        weight_scale = torch.max(torch.abs(self.weight)).to(torch.float).to(device) / FP8_MAX
-        self.weight = nn.Parameter((self.weight / weight_scale).to(torch.float8_e4m3fn))
-        self.register_buffer(
-            "input_scale", torch.tensor(1.0, device=self.weight.device, dtype=torch.float)
-        )
-        self.register_buffer("weight_scale", weight_scale)
-        if self.bias is not None:
-            self.bias = nn.Parameter(self.bias.to(torch.half))
+        device = self.weight.device
+        with torch.no_grad():
+            weight_scale = torch.max(torch.abs(self.weight)).to(torch.float).to(device) / FP8_MAX
+            self.weight = nn.Parameter((self.weight / weight_scale).to(torch.float8_e4m3fn))
+            if self.bias is not None:
+                self.bias = nn.Parameter(self.bias.to(torch.half))
+        self.register_buffer(
+            "input_scale", torch.tensor(1.0, device=self.weight.device, dtype=torch.float)
+        )
+        self.register_buffer("weight_scale", weight_scale)
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

80-86: New post-load fusion switches: fuse_gemms, fuse_fp4_gemms, fuse_fp8_gemms, and fuse_quant.

Good separation of generic and quantized GEMM fusions. Ensure docs mention these toggles for users.

Would you like me to draft a short README snippet describing these flags and the expected effects?

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between d14d5f4 and 7d1f402.

📒 Files selected for processing (5)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (7 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧠 Learnings (2)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
🧬 Code graph analysis (3)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2)
  • torch_fake_quant_fp8_linear (168-200)
  • torch_fake_quant_fp8_linear (204-214)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
  • cutlass_fp4_scale_to_modelopt_fp4_scale (47-59)
  • custom_op (166-168)
  • custom_op (240-242)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
  • build_custom_args_for_linear (128-130)
  • build_custom_args_for_linear (294-301)
  • build_custom_args_for_linear (340-347)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (168-200)
  • torch_fake_quant_fp8_linear (204-214)
  • torch_fake_quant_fp4_linear (218-276)
  • torch_fake_quant_fp4_linear (280-289)
🔇 Additional comments (4)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (3)

36-36: LGTM: stable FP8 constant for tests.

Using torch.finfo for FP8 max is appropriate and future-proof.
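
For reference, a minimal sketch of the pattern being praised (assuming the tests target torch.float8_e4m3fn):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn; no hard-coded magic number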


96-99: LGTM: switching target linear layers to FakeFP8Linear for FP8 scenarios.

This isolates the fusion test from HW-dependent kernels and exercises the FP8 fuse path.

Also applies to: 111-114, 135-137, 159-162, 191-193


286-288: LGTM: enabling fuse_fp8_gemms in test optimizer config.

Matches the intent of exercising the post_load_fusion path for FP8.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

48-51: Rename quantize_from_config → quantize_linear_from_config and add quantize_bmm_from_config: config aligns with code.

Configuration looks consistent with the split transforms; no action needed.

Signed-off-by: Frida Hou <[email protected]>

add unit test for quant fusion and quant sharding

Signed-off-by: Frida Hou <[email protected]>
@Fridah-nv Fridah-nv force-pushed the user/fridah/inherit-quant2 branch from 7d1f402 to 40ef068 on August 26, 2025 23:48
Signed-off-by: Frida Hou <[email protected]>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)

1-1: Add NVIDIA 2025 copyright header (per repo policy).

All source files must start with the NVIDIA header for the current year. Please prepend the SPDX header to comply.

Apply this diff at the top of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0

475-479: Python 3.8 compatibility: replace PEP 585 builtins with typing generics.

The repo targets Python 3.8+. Using list[str] requires 3.9+. Switch to List[str] (already imported).

-def should_skip_quantization(
-    node_or_name: Union[Node, str],
-    excluded_patterns: list[str],
-) -> bool:
+def should_skip_quantization(
+    node_or_name: Union[Node, str],
+    excluded_patterns: List[str],
+) -> bool:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

110-121: FP8MLP path casts FP8 weights to float16 during pattern detection.

In _run_pattern_detection_job the generic else branch uses .to(device="cuda", dtype=torch.float16) for all non-GQA models. When model_cls == FP8MLP, this will upcast the FP8-quantized weights to fp16 and break the FP8 reference op (it expects float8_e4m3fn). This can cause export-time failures or incorrect pattern detection.

Fix by adding a dedicated FP8MLP branch that does not cast dtype:

@@ def _run_pattern_detection_job(...):
-    else:
-        model = model_cls(num_features, num_features, bias=bias).to(
-            device="cuda", dtype=torch.float16
-        )
+    elif model_cls == FP8MLP:
+        # Keep FP8 quantized params; don't cast dtype.
+        model = model_cls(num_features, num_features, bias=bias).to("cuda")
+    else:
+        model = model_cls(num_features, num_features, bias=bias).to(
+            device="cuda", dtype=torch.float16
+        )
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1)

1-5: Add mandatory NVIDIA copyright header (2025).

Per coding guidelines, every Python source must begin with the NVIDIA copyright header for the current year.

Apply this diff at the top of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
♻️ Duplicate comments (2)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)

510-517: Update detector to also match custom_op() (align with create()).

get_scales_and_type_from_node only matches target_op(); with the new custom_op()-based path, scales won’t be found from nodes calling the custom kernels.

-def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
+def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
     """Returns a dict of scale args and quantization type string ('fp4', 'fp8', etc)."""
-    for qtype in [FP4QuantizationImpl, FP8QuantizationImpl]:
-        if is_op(node, qtype.target_op()):
+    for qtype in [FP4QuantizationImpl, FP8QuantizationImpl]:
+        if is_op(node, [qtype.custom_op(), qtype.target_op()]):
             return extract_scales_from_node(
                 node, qtype.scale_names()
             ), qtype.__name__.lower().replace("quantizationimpl", "")
     return None, "simple"

294-311: Fix copy-paste error in FP4 docstring (mentions FP8 instead of FP4).

The docstring for FP4’s build_custom_args_for_linear references torch_fake_quant_fp8_linear. It should reference the FP4 variant.

-        Build the *positional* tail for torch_fake_quant_fp8_linear:
+        Build the *positional* tail for torch_fake_quant_fp4_linear:
@@
-        torch_fake_quant_fp8_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
+        torch_fake_quant_fp4_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
🧹 Nitpick comments (17)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (6)

30-33: Stabilize and document format constants; mark as Final.

These int flags are fine, but make the intent explicit to type checkers and future readers. Also, the TODO suggests centralizing—good follow-up, but marking them Final is a quick win now.

Apply this diff here:

-# TODO: put the ENUMs in the same place and import it
-FORMAT_FP8 = 0
-FORMAT_NVFP4 = 1
+# TODO: centralize format constants in a shared module and import from there
+FORMAT_FP8: Final[int] = 0
+FORMAT_NVFP4: Final[int] = 1

Additionally (outside this hunk), add Final to the typing import on Line 2:

from typing import Dict, List, Optional, Tuple, Union, Final

88-89: Return a clear error for unsupported NVFP4 BMM rather than None.

Returning None here can lead to confusing downstream AttributeErrors. Raise a targeted NotImplementedError to fail fast with context.

-                    "NVFP4": None,  # BMM NVFP4 is not supported yet
+                    "NVFP4": None,  # BMM NVFP4 is not supported yet
                 }
-            return quantization_impl_map[quant_type_or_node]
+            impl = quantization_impl_map[quant_type_or_node]
+            if quant_type_or_node == "NVFP4" and impl is None:
+                raise NotImplementedError("BMM NVFP4 is not supported yet")
+            return impl

101-104: Broaden node detection to support both new custom_op() and legacy target_op().

Only checking custom_op() risks missing older graphs that still use target_op(). Use an OR to be backward compatible without sacrificing the new path.

-        ]:
-            if is_op(quant_type_or_node, q.custom_op()):
-                return q
+        ]:
+            if is_op(quant_type_or_node, [q.custom_op(), q.target_op()]):
+                return q

184-197: Remove commented-out format_type to avoid confusion.

Since FP8/FP4 have distinct custom kernels now, the commented format_type is misleading.

         return dict(
             input_scale=[scale_getattrs["input_scale"]],
             weight_scale=[scale_getattrs["weight_scale"]],
             input_zp=[],
             weight_zp=[],
-            # format_type=FORMAT_FP8,
         )

198-215: Docstring example: clarify bias handling or drop the example.

Doc says “We pass bias=None,” but the example shows positional args without explicitly indicating None. Either annotate which arg is None or omit the example to prevent misreads.


271-293: Remove commented-out format_type here as well (mirror FP8 change).

Keeps the contract focused on the actual inputs the kernels consume now.

         return dict(
             input_scale=[scale_getattrs["input_scale"]],
             weight_scale=[scale_getattrs["weight_scale"], scale_getattrs["alpha"]],
             input_zp=[],
             weight_zp=[],
-            # format_type=FORMAT_NVFP4,
         )
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4)

12-13: Ensure custom quant ops are registered for FP8 tests.

FP8MLP relies on FakeFP8Linear which calls torch.ops.auto_deploy.torch_fake_quant_fp8_linear. Importing the custom ops in this test file avoids order-dependent failures when running tests selectively.

Add near the top (close to other imports):

+import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401

81-92: Minor: add a short docstring and keep dtype consistent.

FP8MLP is a test-only module. Adding a one-line docstring clarifies intent. Also consider keeping inputs in half precision where appropriate, but avoid mutating the FP8 params (as is already handled in _run_job). No code change required beyond the docstring.

Example:

 class FP8MLP(nn.Module):
+    """Two-layer MLP using FakeFP8Linear modules to exercise FP8 sharding paths."""
     def __init__(self, in_features, out_features, bias=False):
         super().__init__()

304-315: Param space coverage note (optional).

Pattern-detection is only checked for world_size=[8]. If CI time permits, adding a smaller value (e.g., 2 or 4) can catch shape corner cases without significant overhead.


1-1: Missing NVIDIA copyright header.

Per the coding guidelines, prepend the NVIDIA copyright header (2025) to this file.

Add at the very top:

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (4)

1-1: Add NVIDIA copyright header.

All source files (.py included) must carry the NVIDIA copyright header.

Insert above line 1:

+ # Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.

16-24: Initialize ModelFactory base class and return quant config from the stub.

DummyFactory inherits ModelFactory but doesn’t call super().__init__, leaving base fields uninitialized. Some optimizer paths read attributes like skip_loading_weights or call get_quant_config(). Make the stub explicit and safe.

Apply this diff:

 class DummyFactory(ModelFactory):
     def __init__(self, quant_config=None):
-        self._quant_config = quant_config or {}
+        super().__init__(model="", skip_loading_weights=True)
+        self._quant_config = dict(quant_config or {})
 
+    def get_quant_config(self) -> dict:
+        return dict(self._quant_config)
+
     def _build_model(self, device: str):
-        return
+        return None
 
     def _load_checkpoint(self, model, device):
-        return
+        return None

88-130: Clarify FP4 scale semantics (‘s_in2’ vs ‘inv’).

The custom op docstring indicates input_scale[0] = s_in2 and weight_scale[1] = alpha = s_in2 * s_w2. Here you pass input_scale_2 = s_in2 and alpha = 1/(s_in2 * s_w2). That’s the reciprocal of the docstring’s alpha, but matches the internal variable names (inv_x, etc.) in the reference implementation.

Consider adding a short comment to disambiguate “inv” vs “s2” conventions, or rename buffers to reflect “inv_” semantics if that’s the intended contract. This reduces future confusion should the op’s docs or code be refactored.
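
For illustration, a tiny sketch of how the two conventions relate (values are made up; the names follow this discussion, not the op's actual buffers):

import torch

s_in2 = torch.tensor(0.02)  # per-tensor input scale (example value)
s_w2 = torch.tensor(0.05)   # per-tensor weight scale (example value)

alpha_doc = s_in2 * s_w2          # convention stated in the op docstring
alpha_inv = 1.0 / (s_in2 * s_w2)  # reciprocal convention passed by this test
assert torch.isclose(alpha_doc * alpha_inv, torch.tensor(1.0))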

Would you like me to propose a consistent naming/comment patch once you confirm which convention we standardize on?


27-35: Assertion helpers correctly detect rewrite; minor consistency nit.

_has_fused_linear_fp8 checks .default overload for the reference op, while _has_fused_linear_fp4 checks the packet (no .default). Both work due to is_op handling OpOverloadPacket, but using .default in both places improves uniformity.

Optional tweak:

-    found_ref = any(
-        is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp4_linear) for n in gm.graph.nodes
-    )
+    found_ref = any(
+        is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp4_linear.default) for n in gm.graph.nodes
+    )
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (3)

259-261: Add a short docstring for FakeFP8Linear (public test helper).

This class is imported by multiple tests; a concise docstring clarifies its contract and limits (e.g., test-only, expects FP8 path).

Apply this diff:

 class FakeFP8Linear(nn.Linear):
     def __init__(self, *args, **kwargs):
+        """
+        Test-only Linear that stores FP8-quantized weights and invokes the FP8 fake-quant op.
+        Not intended for training/autograd; biases are cast to the input dtype at runtime.
+        """
         super().__init__(*args, **kwargs)

263-265: Consider keeping float32 weights and storing FP8 weights in a buffer to reduce surprise for code expecting nn.Linear.weight to be float.

Overwriting self.weight with FP8 may confuse utilities or checkpoints that assume a float weight tensor. A low-impact alternative: keep self.weight as-is and store weight_fp8 as a buffer.

Optional refactor:

-        self.weight = nn.Parameter((self.weight / weight_scale).to(torch.float8_e4m3fn))
+        # Keep original float weight; store FP8 copy for the fake-quant op.
+        self.register_buffer("weight_fp8", (self.weight.detach() / weight_scale).to(torch.float8_e4m3fn))

and in forward:

-        return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
-            x, self.weight, bias, [self.input_scale], [self.weight_scale], [], []
-        )
+        return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
+            x, self.weight_fp8, bias, [self.input_scale], [self.weight_scale], [], []
+        )

272-275: Use .default for invoking torch_fake_quant_fp8_linear consistently

The call sites for torch.ops.auto_deploy.torch_fake_quant_fp8_linear are currently inconsistent: some invoke the op directly, while others use the .default attribute. Our convention is to always use .default when calling custom ops. Please update the following locations:

  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py:273
    Change

    return torch.ops.auto_deploy.torch_fake_quant_fp8_linear(
        x, self.weight, self.bias, [self.input_scale], [self.weight_scale], [], []
    )

    to

    return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
        x, self.weight, self.bias, [self.input_scale], [self.weight_scale], [], []
    )
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py:131
    Similarly prepend .default to the op invocation there.

This will ensure all custom-op calls follow the same pattern.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL


📥 Commits

Reviewing files that changed from the base of the PR and between 7d1f402 and 40ef068.

📒 Files selected for processing (8)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7 hunks)
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (6 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (7 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

(Same Python coding guidelines as listed in the earlier review above.)

Files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧬 Code graph analysis (4)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (11)
  • FakeFP8Linear (259-275)
  • forward (74-94)
  • forward (106-108)
  • forward (130-135)
  • forward (154-159)
  • forward (184-203)
  • forward (220-222)
  • forward (232-234)
  • forward (247-253)
  • forward (272-275)
  • MLP (97-111)
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (2)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (2)
  • forward (75-85)
  • forward (119-129)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2)
  • torch_fake_quant_fp8_linear (168-200)
  • torch_fake_quant_fp8_linear (204-214)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (7)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1)
  • run_test_transformed_gm (68-138)
tests/unittest/_torch/auto_deploy/_utils_test/_torch_test_utils.py (3)
  • fp4_compatible (33-34)
  • fp8_compatible (29-30)
  • trtllm_ops_available (37-38)
tensorrt_llm/_torch/auto_deploy/export/export.py (1)
  • torch_export_to_gm (198-284)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (15-207)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)
  • fp4_global_scale (62-64)
  • fp8_scale (67-69)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (168-200)
  • torch_fake_quant_fp8_linear (204-214)
  • torch_fake_quant_fp4_linear (218-276)
  • torch_fake_quant_fp4_linear (280-289)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
  • build_custom_args_for_linear (128-130)
  • build_custom_args_for_linear (294-301)
  • build_custom_args_for_linear (340-347)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (168-200)
  • torch_fake_quant_fp8_linear (204-214)
  • torch_fake_quant_fp4_linear (218-276)
  • torch_fake_quant_fp4_linear (280-289)
🔇 Additional comments (4)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)

147-157: Interface extension looks good; default no-ops are appropriate.

The default kwargs/args hooks are clear and safe. Good separation of concerns for per-format overrides.


165-169: Unified entry point via custom_op() is a solid choice.

Binding to the .default overload makes matching in FX robust and consistent across sites.
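
As a small sketch of the resulting matching pattern (gm is assumed to be an exported GraphModule; is_op lives in node_utils per this PR):

import torch
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op

target = torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default
fp8_nodes = [n for n in gm.graph.nodes if is_op(n, target)]  # binding to .default keeps the check unambiguous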


239-243: Custom op entry point for FP4 mirrors FP8—good consistency.

Keeps transform plumbing uniform across formats.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (1)

47-86: TinyFP8Ref looks good; ensure device/dtype handling remains consistent.

Precomputing weight_fp8/scales and using the reference op aligns with the fusion expectations. Buffers will migrate on .to("cuda"). No functional issues spotted.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)

501-509: Ensure consistent operator detection across all methods.

This is the location where the detection inconsistency manifests - the function only checks target_op() but should also check custom_op() to align with the create() method's logic.

See the fix suggested in the earlier comment for lines 94-105.

♻️ Duplicate comments (2)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)

94-105: Operator detection should be consistent between methods.

The create() method now uses custom_op() for FP4/FP8 detection (lines 98-99), but get_scales_and_type_from_node() (lines 503-504) still only checks target_op(), so scales will not be found on nodes that call the custom kernels.

Apply this diff to get_scales_and_type_from_node (around lines 503-504):

 def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
     """Returns a dict of scale args and quantization type string ('fp4', 'fp8', etc)."""
     for qtype in [FP4QuantizationImpl, FP8QuantizationImpl]:
-        if is_op(node, qtype.target_op()):
+        if is_op(node, qtype.target_op()) or is_op(node, qtype.custom_op()):
             return extract_scales_from_node(
                 node, qtype.scale_names()
             ), qtype.__name__.lower().replace("quantizationimpl", "")
     return None, "simple"

285-302: Fix copy-paste error in docstring.

The docstring incorrectly mentions "torch_fake_quant_fp8_linear" when it should be "torch_fake_quant_fp4_linear".

Apply this diff:

     @staticmethod
     def build_custom_args_for_linear(  # renamed to reflect args
         scale_getattrs: Dict[str, Node],
     ) -> Tuple[object, ...]:
         """
-        Build the *positional* tail for torch_fake_quant_fp8_linear:
+        Build the *positional* tail for torch_fake_quant_fp4_linear:
             (..., bias, input_scale(list), weight_scale(list), input_zp(list), weight_zp(list))

         We pass bias=None to match the exported pattern:
-        torch_fake_quant_fp8_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
+        torch_fake_quant_fp4_linear(args_0, args_1, args_2, [args_2_0], [args_3_0, args_3_1], [], [])
         """
🧹 Nitpick comments (3)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (2)

108-139: Consider extracting common test setup pattern.

The test has similar setup logic to test_fp8_linear, including input generation, weight quantization, and comparison assertions. Consider extracting common test utilities to reduce duplication and improve maintainability.

Additionally, skip the test gracefully when no CUDA device is available:

def _get_test_device():
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    return "cuda"

140-197: Add explicit K alignment validation and improve error message.

The assertion at line 156 correctly enforces the NVFP4 block size requirement, but could provide more helpful debugging information.

Apply this diff to improve the assertion:

-    assert K % SCALING_VECTOR_SIZE == 0
+    assert K % SCALING_VECTOR_SIZE == 0, (
+        f"K dimension ({K}) must be a multiple of SCALING_VECTOR_SIZE ({SCALING_VECTOR_SIZE}) "
+        f"for NVFP4 quantization"
+    )

Also consider adding a comment explaining why alpha is computed as 1/(s_in2*s_w2) for future maintainability:

+    # Alpha combines the input and weight per-tensor scales for the fused kernel
     alpha_fused = (1.0 / (s_in2 * s_w2)).to(torch.float32)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)

264-284: Good documentation, consider extracting common pattern.

The contract documentation clearly explains the expected argument structure. Both FP8 and FP4 implementations share a similar pattern for building kwargs - consider extracting common logic to reduce duplication.
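
One possible shape for such a shared helper (hypothetical sketch, not code from this PR; the four keyword lists mirror what both implementations already build):

from typing import Dict, List

from torch.fx import Node


def _build_fake_quant_kwargs(scale_getattrs: Dict[str, Node], weight_scale_names: List[str]) -> dict:
    """Hypothetical shared builder: both formats emit the same four lists; only the weight-scale nodes differ."""
    return dict(
        input_scale=[scale_getattrs["input_scale"]],
        weight_scale=[scale_getattrs[name] for name in weight_scale_names],
        input_zp=[],
        weight_zp=[],
    )


# FP8 would call: _build_fake_quant_kwargs(scale_getattrs, ["weight_scale"])
# FP4 would call: _build_fake_quant_kwargs(scale_getattrs, ["weight_scale", "alpha"])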

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL


📥 Commits

Reviewing files that changed from the base of the PR and between 40ef068 and f4463a5.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (6 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

(Same Python coding guidelines as listed in the earlier review above.)

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧬 Code graph analysis (2)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (3)
tests/unittest/_torch/auto_deploy/_utils_test/_torch_test_utils.py (3)
  • fp8_compatible (29-30)
  • fp4_compatible (33-34)
  • trtllm_ops_available (37-38)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (160-191)
  • torch_fake_quant_fp8_linear (195-205)
  • torch_fake_quant_fp4_linear (209-267)
  • torch_fake_quant_fp4_linear (271-280)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
  • fp4_global_scale (58-60)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
  • build_custom_args_for_linear (128-130)
  • build_custom_args_for_linear (294-301)
  • build_custom_args_for_linear (340-347)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (160-191)
  • torch_fake_quant_fp8_linear (195-205)
  • torch_fake_quant_fp4_linear (209-267)
  • torch_fake_quant_fp4_linear (271-280)
🔇 Additional comments (7)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (2)

11-11: LGTM! Clear constant naming improvement.

The constant name change from scaling_vector_size to SCALING_VECTOR_SIZE follows Python's uppercase convention for global constants.


50-50: LGTM! Consistent use of the renamed constant.

The update correctly reflects the renamed constant SCALING_VECTOR_SIZE.

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (5)

84-84: LGTM! Clear documentation for unsupported feature.

The comment clearly indicates that BMM NVFP4 is not yet supported.


143-154: LGTM! Well-structured base class methods.

The new methods provide a clean interface for custom quantization implementations to supply kernel-specific arguments.


161-165: LGTM! Clear custom operation interface.

The custom_op() method properly returns the FP8 fake quantization operator for unified kernel entry.


180-208: Methods provide consistent FP8 argument handling.

The implementation correctly builds both kwargs and args for the FP8 linear operation, maintaining consistency with the expected signature of torch_fake_quant_fp8_linear.
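
For reference, a sketch of the call shape these methods target (tensor setup is illustrative; assumes the auto_deploy custom ops are imported and a CUDA device is available):

import torch

x = torch.randn(4, 16, dtype=torch.half, device="cuda")
weight = torch.randn(32, 16, dtype=torch.half, device="cuda")
weight_scale = weight.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
weight_fp8 = (weight / weight_scale).to(torch.float8_e4m3fn)
input_scale = torch.tensor(1.0, device="cuda")

out = torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
    x, weight_fp8, None, [input_scale], [weight_scale], [], []
)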


232-236: LGTM! Consistent FP4 custom operation interface.

The custom_op() method properly returns the FP4 fake quantization operator.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

68-84: Scale buffers are never populated from load_hook output

_default_scales registers submodule buffers input_scale/weight_scale, but FP8/FP4 load_hook currently writes to weight_name + "_scale" in the state_dict, leaving module buffers at defaults (1.0/zeros). The fused/fake ops then read wrong scales.

I recommend updating the load hooks in quantization_utils.py to write into module buffer keys (e.g., f"{modprefix}.weight_scale" and, when derived, f"{modprefix}.input_scale"/f"{modprefix}.alpha"), or add a post-load hook here to transfer state_dict entries into those buffers before execution. Happy to provide a patch in quantization_utils.py (see below comment).

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)

180-187: FP8 load_hook should populate module buffers, not only aux keys

The hook currently writes weight_name + "_scale" and leaves the submodule's ".weight_scale" buffer at its default. Populate the buffer keys so the op reads the correct scales.

     def load_hook(state_dict, prefix, *args, weight_name):
         if weight_name in state_dict:
             weight = state_dict[weight_name]
             if weight.dtype != torch.float8_e4m3fn:
                 scale = fp8_scale(state_dict[weight_name])
-                state_dict[weight_name] = (state_dict[weight_name] / scale).to(torch.float8_e4m3fn)
-                state_dict[weight_name + "_scale"] = scale
+                state_dict[weight_name] = (state_dict[weight_name] / scale).to(torch.float8_e4m3fn)
+                mod_prefix = weight_name.rsplit(".", 1)[0]
+                # Ensure module buffers get correct values
+                state_dict[f"{mod_prefix}.weight_scale"] = scale
+                # input_scale often defaults to 1.0 here; set if missing
+                state_dict.setdefault(f"{mod_prefix}.input_scale", torch.tensor(1.0))

241-287: FP4 load_hook should also set module buffers

Similar issue: ensure submodule buffers (input_scale/weight_scale/alpha) are written so execution uses the right scales.

     def load_hook(state_dict, prefix, *args, weight_name):
         if weight_name in state_dict:
             input_scale_name = weight_name.rsplit(".", 1)[0] + ".input_scale"
             alpha_name = weight_name.rsplit(".", 1)[0] + ".alpha"
+            weight_scale_buf_name = weight_name.rsplit(".", 1)[0] + ".weight_scale"
             weight = state_dict[weight_name]
@@
-                state_dict[weight_name + "_scale"] = weight_scale
+                # Also populate module buffer key for execution
+                state_dict[weight_scale_buf_name] = weight_scale
@@
-                if (
+                if (
                     weight_name + "_scale_2" in state_dict
                     and weight_name + "_scale" in state_dict
                     and input_scale_name in state_dict
                     and float4_sf_dtype
                 ):
@@
-                    state_dict[weight_name + "_scale"] = (
+                    converted = (
                         torch.ops.trtllm.block_scale_interleave(
                             weight_scale.view(torch.uint8).cpu().contiguous()
                         )
                         .reshape(ori_shape)
                         .view(float4_sf_dtype)
                         .reshape(-1)
                     )
+                    state_dict[weight_name + "_scale"] = converted
+                    state_dict[weight_scale_buf_name] = converted

403-417: Exclude-pattern helper uses linear-specific extractor for BMM

For BMM nodes, extract_param_names_from_lin_node is invalid. Guard by op type and derive module name from the get_attr weight when available.

 def should_skip_quantization(
     node_or_name: Union[Node, str],
     excluded_patterns: list[str],
 ) -> bool:
@@
-    else:
-        if not (is_linear_op(node_or_name, include_quantization=False) or is_bmm_op(node_or_name)):
-            return True
-        param_name, _ = extract_param_names_from_lin_node(node_or_name)
-        modname, _, _ = param_name.rpartition(".")
+    else:
+        if is_linear_op(node_or_name, include_quantization=False):
+            param_name, _ = extract_param_names_from_lin_node(node_or_name)
+            modname, _, _ = param_name.rpartition(".")
+        elif is_bmm_op(node_or_name):
+            wt = node_or_name.args[1]
+            if getattr(wt, "op", None) == "get_attr":
+                modname, _, _ = wt.target.rpartition(".")
+            else:
+                # dynamic weight; no owning module — don't exclude
+                return False
+        else:
+            return True
♻️ Duplicate comments (5)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)

300-318: quantization_cb re-registers existing buffers => runtime error

Calling register_buffer on existing names raises; update in place when present and validate buffer existence. This mirrors prior feedback.

-        for scale_name in self.scale_names():
-            scales[scale_name] = submod.get_buffer(scale_name)
+        for scale_name in self.scale_names():
+            try:
+                buf = submod.get_buffer(scale_name)
+            except AttributeError:
+                buf = None
+            if buf is None:
+                raise RuntimeError(
+                    f"Expected buffer '{scale_name}' on module '{type(submod).__name__}'"
+                )
+            scales[scale_name] = buf
@@
-        for k, v in sharded_scales.items():
-            submod.register_buffer(k, v)
+        for k, v in sharded_scales.items():
+            if k in submod._buffers:
+                setattr(submod, k, v)
+            else:
+                submod.register_buffer(k, v)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

212-247: Mirror finalize/flag logic (prior feedback)

This is the same suggestion previously raised for the BMM path; applies here as well.

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (3)

44-46: Replace Python 3.10 union syntax with Optional for Py3.8 compatibility.

Keep repo-wide Python 3.8 support.

-    weights_scaling_factor_2: torch.Tensor | None = None,
+    weights_scaling_factor_2: Optional[torch.Tensor] = None,

(Apply in both function signatures.)

Also applies to: 99-100


81-93: Avoid boolean*uint8 mix, don’t shadow built-ins, keep inputs unmodified.

Use logical AND with a boolean mask; rename ord/round.

-    mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8).to(device)
+    mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8, device=device)
@@
-    ord = torch.searchsorted(e2m1_bounds.to(device), weight_abs, out_int32=True).to(torch.uint8)
+    ordinal = torch.searchsorted(e2m1_bounds.to(device), weight_abs, out_int32=True).to(torch.uint8)
@@
-    round = torch.any((weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) * mask, dim=-1)
-    fp4_val = (sign_bit * 0b1000 + ord + round).to(torch.uint8)
+    round_up = torch.any(
+        (weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) & mask.bool(),
+        dim=-1,
+    )
+    fp4_val = (sign_bit * 0b1000 + ordinal + round_up).to(torch.uint8)

1-1: Add NVIDIA SPDX/Apache-2.0 header (2025) at file top.

Required by project guidelines for all source files.

Apply this diff at the very beginning of the file:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
🧹 Nitpick comments (18)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)

1-1: Add NVIDIA Apache-2.0 header

Source files require the header.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

446-449: Broaden resolver to also match fused ops and guard missing symbols

Make TP_SHARDING_RULES robust across fake and fused paths and import-time safe.

-TP_SHARDING_RULES = [
-    (lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp8_linear), FP8TPShardingInfo),
-    (lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp4_linear), FP4TPShardingInfo),
-]
+TP_SHARDING_RULES = [
+    (
+        lambda n: is_op(
+            n,
+            (
+                getattr(torch.ops.auto_deploy, "torch_fake_quant_fp8_linear", None),
+                getattr(torch.ops.auto_deploy, "torch_quant_fp8_linear", None),
+            ),
+        ),
+        FP8TPShardingInfo,
+    ),
+    (
+        lambda n: is_op(
+            n,
+            (
+                getattr(torch.ops.auto_deploy, "torch_fake_quant_fp4_linear", None),
+                getattr(torch.ops.auto_deploy, "torch_quant_fp4_linear", None),
+            ),
+        ),
+        FP4TPShardingInfo,
+    ),
+]
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (3)

1-1: Add NVIDIA Apache-2.0 header

Required for non-test source files.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

136-142: Use Python 3.8-compatible typing

Replace X | None with Optional[X] and import Optional.

-from typing import Tuple
+from typing import Optional, Tuple
@@
-def _fp4_ref_repl_2(
+def _fp4_ref_repl_2(
     x: torch.Tensor,
     w_fp4: torch.Tensor,
-    bias: torch.Tensor | None,
+    bias: Optional[torch.Tensor],

259-269: Finalize graph after rewrites

Eliminate dead code and recompile for stability and downstream passes.

-        num_matches = patterns.apply(gm.graph)
+        num_matches = patterns.apply(gm.graph)
+        if num_matches:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (2)

1-1: Add NVIDIA Apache-2.0 header

Required for source files.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

174-209: Return skipped=True when no matches; recompile on changes

Align semantics across transforms.

@@ class LinearQuantizationFromConfig(BaseTransform):
-        info = TransformInfo(
-            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
-        )
-        return gm, info
+        if num_matches:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
+        return gm, TransformInfo(
+            skipped=(num_matches == 0),
+            num_matches=num_matches,
+            is_clean=False,
+            has_valid_shapes=True,
+        )
@@ class BMMQuantizationFromConfig(BaseTransform):
-        info = TransformInfo(
-            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
-        )
-        return gm, info
+        if num_matches:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
+        return gm, TransformInfo(
+            skipped=(num_matches == 0),
+            num_matches=num_matches,
+            is_clean=False,
+            has_valid_shapes=True,
+        )

Also applies to: 244-247

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)

1-1: Add NVIDIA Apache-2.0 header

Required for source files.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

438-446: Return type annotation doesn’t reflect None possibility

Function can return (None, "simple"); update typing or avoid None returns.

-def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
+def get_scales_and_type_from_node(node: Node) -> Tuple[Optional[Dict[str, Node]], str]:
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (6)

12-13: Set explicit float dtype for FP4 tables.

Pin both lookup tables to an explicit float32 dtype so the dequant math does not depend on dtype inference or the global default dtype.

-e2m1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
-e2m1_values = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6])
+e2m1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5], dtype=torch.float32)
+e2m1_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
+                            0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], dtype=torch.float32)

31-38: Remove unused parameter from _dequant_weight_fp8 and update call site.

out_features is not used. Simplify signature and callers.

-def _dequant_weight_fp8(
-    weight_fp8: torch.Tensor,
-    weight_scale: torch.Tensor,
-    out_features: int,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    return weight_fp8.to(dtype) * weight_scale
+def _dequant_weight_fp8(
+    weight_fp8: torch.Tensor,
+    weight_scale: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return weight_fp8.to(dtype) * weight_scale

And adjust usage below:

-    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, out_features, in_dtype)
+    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, in_dtype)

133-157: Add basic shape assertions to dequant path.

Fail-fast on malformed inputs to ease debugging.

 def _dequantize_nvfp4(
@@
 ) -> torch.Tensor:
     device = quantized_t.device
     N, K = orig_shape
+    assert K % 16 == 0, "NVFP4 dequant requires K to be a multiple of 16."
+    assert quantized_t.shape[-2] == N and quantized_t.shape[-1] == K // 2, \
+        f"quantized_t has shape {tuple(quantized_t.shape[-2:])}, expected {(N, K//2)}"

159-206: FP8 fake op: minor robustness, ensure scales are on same device.

Move scales to input.device to avoid device mismatch if provided as CPU tensors.

-    s_in = _expect_single_scale(input_scale, "input_scale")
-    s_w = _expect_single_scale(weight_scale, "weight_scale")
+    s_in = _expect_single_scale(input_scale, "input_scale").to(input.device)
+    s_w = _expect_single_scale(weight_scale, "weight_scale").to(input.device)

208-281: FP4 eager path: validate scale list lengths and devices.

Guard early and pin scales to device.

-    if len(weight_scale) < 2 or weight_scale[0] is None or weight_scale[1] is None:
+    if len(weight_scale) < 2 or weight_scale[0] is None or weight_scale[1] is None:
         raise ValueError(
             "NVFP4 needs weight_scale[0] (per-block vector) and weight_scale[1] (alpha)."
         )
-    cutlass_qscale = weight_scale[0]
-    alpha = weight_scale[1]
+    cutlass_qscale = weight_scale[0].to(input.device)
+    alpha = weight_scale[1].to(input.device)

11-14: Optional: treat FP tables/constants as UPPER_SNAKE_CASE.

Matches repo style for constants.

-# FP4 tables (E2M1)
-e2m1_bounds = ...
-e2m1_values = ...
+# FP4 tables (E2M1)
+E2M1_BOUNDS = ...
+E2M1_VALUES = ...

(Propagate symbol rename locally.)

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)

142-151: Defensive buffer loading to avoid hard failures.

If any expected buffer is missing, skip fusion for this group instead of raising.

-        for weight_key in keys_unfused:
-            key = weight_key.rsplit(".", 1)[0]
-            for scale_name in flat_scale_names:
-                buffer_name = key + "." + scale_name
-                scales.setdefault(scale_name, []).append(gm.get_buffer(buffer_name))
+        try:
+            for weight_key in keys_unfused:
+                key = weight_key.rsplit(".", 1)[0]
+                for scale_name in flat_scale_names:
+                    buffer_name = key + "." + scale_name
+                    scales.setdefault(scale_name, []).append(gm.get_buffer(buffer_name))
+        except (AttributeError, KeyError) as e:
+            ad_logger.warning(f"Missing quant buffers for {keys_unfused}, skipping fusion: {e}")
+            return

166-173: Prefer graph.get_attr over create_node('get_attr') for buffers.

Keeps FX graph construction consistent with parameter fetch above.

-            scale_getattrs: Dict[str, Node] = {
-                name: gm.graph.create_node("get_attr", f"{key_fused}_{name}")
-                for name in flat_scale_names
-            }
+            scale_getattrs: Dict[str, Node] = {
+                name: gm.graph.get_attr(f"{key_fused}_{name}")
+                for name in flat_scale_names
+            }

270-291: FP8 fuse_rule: scale aggregation via max is OK; ensure device alignment.

Small tweak to avoid device mismatch when stacking scales.

-        new_weight_scale = torch.max(torch.stack(weight_scale))
+        ws = [s.to(weights[0].device) for s in weight_scale]
+        new_weight_scale = torch.max(torch.stack(ws))
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL


📥 Commits

Reviewing files that changed from the base of the PR and between f4463a5 and 0974df5.

📒 Files selected for processing (11)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (7 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (6 hunks)
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (7 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

(Same Python coding guidelines as listed in the earlier review above.)

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
📚 Learning: 2025-08-27T16:22:10.642Z
Learnt from: Fridah-nv
PR: NVIDIA/TensorRT-LLM#7227
File: tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py:94-100
Timestamp: 2025-08-27T16:22:10.642Z
Learning: When there are inconsistent operator detection methods (like custom_op() vs target_op()), removing one method and standardizing on the other is often cleaner than supporting both methods simultaneously.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧬 Code graph analysis (6)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (4)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (15-207)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (3)
  • ADPatternMatcherPass (96-102)
  • register_ad_pattern (134-217)
  • apply (99-102)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (4)
  • BaseTransform (139-376)
  • SharedConfig (51-57)
  • TransformInfo (108-133)
  • TransformRegistry (379-407)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (4)
  • is_op (183-206)
  • extract_param_names_from_lin_node (149-170)
  • get_op_overload_packet (173-180)
  • is_linear_op (240-252)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py (1)
  • cuda_memory_tracker (10-26)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4)
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (10)
  • FakeFP8Linear (259-275)
  • forward (74-94)
  • forward (106-108)
  • forward (130-135)
  • forward (154-159)
  • forward (184-203)
  • forward (220-222)
  • forward (232-234)
  • forward (247-253)
  • forward (272-275)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
  • FP8TPShardingInfo (331-375)
  • SplitDimension (178-182)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_linear_op (240-252)
tensorrt_llm/_torch/modules/linear.py (1)
  • split_dim (48-49)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7)
  • QuantizationImpl (68-142)
  • cutlass_fp4_scale_to_modelopt_fp4_scale (43-55)
  • modelopt_fp4_scale_to_cutlass_fp4_scale (31-40)
  • scale_names (114-116)
  • scale_names (157-158)
  • scale_names (208-209)
  • scale_names (357-358)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (4)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
  • build_custom_args_for_linear (126-128)
  • build_custom_args_for_linear (292-299)
  • build_custom_args_for_linear (338-345)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (10)
  • build_custom_args_for_linear (139-142)
  • build_custom_args_for_linear (165-177)
  • build_custom_args_for_linear (230-238)
  • target_op (104-106)
  • target_op (147-148)
  • target_op (199-200)
  • target_op (347-348)
  • QuantizationImpl (68-142)
  • create (72-101)
  • should_skip_quantization (403-416)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (4)
  • TransformRegistry (379-407)
  • register (385-392)
  • BaseTransform (139-376)
  • get (395-397)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)
  • is_linear_op (240-252)
  • is_bmm_op (255-262)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
  • build_custom_args_for_linear (126-128)
  • build_custom_args_for_linear (292-299)
  • build_custom_args_for_linear (338-345)
🔇 Additional comments (6)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (2)

82-93: Add FP8 MLP test module: looks good

Class mirrors MLP correctly using FakeFP8Linear; forward shape/activation path matches expectations.


331-333: Param additions LGTM

Including FP8MLP in sharding tests and focusing pattern-detection on FP8 is fine for this PR scope.

Also applies to: 349-353

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1)

31-38: dequant_weight_fp8 call sites verified; no updates needed

The only invocation of _dequant_weight_fp8 is in

  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py:186
    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, out_features, in_dtype)

which exactly matches the updated parameter order and count. No other references were found across the codebase, so no further changes are required.

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)

86-97: check_same_children: implementation is sound.

Converts users to a stable list before iteration and enforces unanimous type; good helper for safe fusion.
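A minimal sketch of what such a helper can look like (hypothetical signature; the real implementation in fusion.py may differ):

from typing import Callable

import torch.fx as fx


def check_same_children(node: fx.Node, is_desired_child: Callable[[fx.Node], bool]) -> bool:
    """Return True only if `node` has users and every user passes `is_desired_child`.

    Snapshot the users into a list first so later graph mutations during fusion
    cannot invalidate the iteration.
    """
    users = list(node.users)
    if not users:
        return False
    return all(is_desired_child(u) for u in users)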


201-205: Bias-less constraint is explicit; good.

The bias==None guard simplifies fusion guarantees; matches non-quant path semantics.


316-337: FP4 fuse_rule: concatenating weight scales assumes identical layout.

Document/verify that per-block vectors are already in fused order; otherwise, derive via recorded split sizes.

Would you confirm that weight_scale vectors are laid out as [out_i blocks] per weight, matching cat(weights, dim=0)? If not, we should reorder before concat.
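A small sanity check of that assumption (a sketch with made-up shapes, assuming each per-block scale vector is the row-major flattening of [out_features, in_features // block_size]): under that layout, concatenating the scale vectors in the same order as the weights keeps rows and their scales aligned.

import torch

block = 16
w1, w2 = torch.randn(4, 32), torch.randn(6, 32)
s1 = torch.rand(4, 32 // block).flatten()  # per-block scales for w1, row-major
s2 = torch.rand(6, 32 // block).flatten()  # per-block scales for w2, row-major

fused_w = torch.cat([w1, w2], dim=0)                 # rows of w1 first, then rows of w2
fused_s = torch.cat([s1, s2], dim=0).view(fused_w.shape[0], -1)

# Rows and their scales stay aligned under this layout assumption.
assert torch.equal(fused_s[:4], s1.view(4, -1))
assert torch.equal(fused_s[4:], s2.view(6, -1))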

Comment on lines +57 to +58
"Weight shape is not divisible for block size for block quantiation."
)
Contributor

⚠️ Potential issue

Fix user-facing typo in assert message.

“quantiation” → “quantization”.

-        "Weight shape is not divisible for block size for block quantiation."
+        "Weight shape is not divisible for block size for block quantization."
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        "Weight shape is not divisible for block size for block quantiation."
-    )
+    raise ValueError(
+        "Weight shape is not divisible for block size for block quantization."
+    )
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py around lines 57 to
58, fix the typo in the assertion message: change "quantiation" to
"quantization" so the user-facing message reads "Weight shape is not divisible
for block size for block quantization." Update only the string literal to
correct the spelling.

Comment on lines +117 to 119
    elif model_cls == FP8MLP:
        model = model_cls(num_features, num_features, bias=bias).to("cuda")
    else:
Contributor

🛠️ Refactor suggestion

Avoid dtype casting FP8 weights during pattern-detection path

In _run_pattern_detection_job, FP8MLP currently falls into the generic else branch (Lines 210-220) which calls .to(dtype=torch.float16) and may upcast FakeFP8Linear weights, risking export/pattern failures. Mirror the special-case used in _run_job.

Apply:

@@
-    if model_cls == GQA_Block:
+    if model_cls == GQA_Block:
         ...
-    else:
+    elif model_cls == FP8MLP:
+        model = model_cls(num_features, num_features, bias=bias).to("cuda")
+    else:
         model = model_cls(num_features, num_features, bias=bias).to(
             device="cuda", dtype=torch.float16
         )

Also applies to: 286-306

🤖 Prompt for AI Agents
In
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
around lines 117-119 (and similarly around lines 286-306), the FP8MLP case is
falling into the generic else that calls .to(dtype=torch.float16), which will
cast FakeFP8Linear weights and break pattern-detection/export; add a
special-case branch for model_cls == FP8MLP that mirrors the handling used in
_run_job by moving the model to CUDA without changing dtype (e.g., .to("cuda")
only) and ensure no .to(dtype=torch.float16) is applied to FP8MLP in both
locations so FakeFP8Linear weights remain untouched during pattern detection.

@Fridah-nv Fridah-nv changed the title [None][autodeploy] Quantization Transforms with Inheritance [#5861][autodeploy] Quantization Transforms with Inheritance Aug 27, 2025
) -> torch.Tensor:
"""
Reference (eager) implementation for multiple quant formats via `format_type`.
For FP4:
Collaborator

IIUC, nvfp4 does dual scaling. how is that represented? is that s_w2 ?

Collaborator

@suyoggupta suyoggupta Aug 28, 2025

also, is this op handling mxfp4 as well?

Collaborator Author

nvfp4 does dual scaling. how is that represented? is that s_w2

Yes, the two scales for NVFP4 are q_per_block_scale_w and s_w2.
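For readers unfamiliar with the format, a rough illustration of the dual scaling (made-up shapes and values, not the exact op in this PR): each 16-element block of the weight carries a per-block scale, and a single global scale rescales those block scales.

import torch

block = 16
w_codes = torch.randint(-8, 8, (64, 64)).float()    # stand-in for decoded FP4 values
q_per_block_scale_w = torch.rand(64, 64 // block)   # per-block scales (FP8 in the real format)
s_w2 = torch.tensor(0.05)                           # second, global scale

# Dequantize: expand each block scale across its 16 columns, then apply both scales.
w_deq = w_codes * (q_per_block_scale_w * s_w2).repeat_interleave(block, dim=1)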

Collaborator Author

This op does not handle MXFP4 because its format is quite different from NVFP4: MXFP4 does not have the second scale, its block size is 32, and its block-scale format is E8M0 rather than FP8.

Collaborator Author

I'll rename this op to NVFP4

@suyoggupta
Collaborator

Thanks @Fridah-nv . Did you get a chance to test this on a sharded quantized model? maybe llama 70B Fp8 and/or llama4-fp8?

  stage: post_load_fusion
fuse_fp4_gemms:
  stage: post_load_fusion
fuse_fp8_gemms:
Collaborator

Do we need to specify fusion for different datatypes separately in the config? Maybe internally "fuse_gemms" can decide which fusions to perform? Alternatively, maybe consider an approach similar to sharding (sharding_dims: ['tp', 'ep', 'dp']):

fuse_gemms:
  stage: sharding
  dtypes: ['fp16', 'fp8', 'fp4']

Collaborator Author

I see your point. Is there a need to keep these quantization-specific classes at the same abstraction level across the different transforms?
For fusion, we should enable the transforms for all dtypes by default, because the graph can contain mixed-dtype GEMMs. We cannot dispatch to a single fuse_gemms implementation at the transform level. If we put the different dtype cases inside one Transform, it would still need to iterate through all types, and we would still need some abstract class to inherit from. Given that, it's simpler to add the QuantizationMixin class to the Transform class.
I think sharding's case is different because TP, DP, and EP sharding don't share much common code/utilities.
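As a hypothetical illustration of the mixed-dtype point above (invented helper name, not code from this PR): a single fuse_gemms pass would still have to bucket candidate nodes by their quantized op before applying a format-specific fuse rule.

from collections import defaultdict

import torch.fx as fx


def group_linears_by_target(gm: fx.GraphModule) -> dict:
    """Group call_function nodes by target op (e.g. fp16 vs fp8 vs fp4 linear variants)."""
    groups = defaultdict(list)
    for node in gm.graph.nodes:
        if node.op == "call_function":
            groups[node.target].append(node)
    return groups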

return all(is_desired_child(u) for u in users)


class QuantizationFusionMixin:
Collaborator

Should this class be an abstract base class (abc.ABC)?
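If the answer is yes, a minimal sketch of what that could look like (hypothetical signatures; the actual fuse_rule interface in fusion.py may differ):

from abc import ABC, abstractmethod
from typing import Dict, List

import torch


class QuantizationFusionMixin(ABC):
    """Declaring the mixin as an ABC forces every quantization format to supply fuse_rule."""

    @abstractmethod
    def fuse_rule(self, weights: List[torch.Tensor], **scales) -> Dict[str, torch.Tensor]:
        """Fuse the per-GEMM weights and scales into a single fused parameter set."""


class FuseFP8Gemms(QuantizationFusionMixin):
    def fuse_rule(self, weights, **scales):
        # Illustrative FP8 rule: concatenate weights, keep the max per-tensor weight scale.
        ws = [s.to(weights[0].device) for s in scales["weight_scale"]]
        return {
            "weight": torch.cat(weights, dim=0),
            "weight_scale": torch.max(torch.stack(ws)),
        }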

@@ -99,6 +114,8 @@ def _run_job(
             hidden_size=num_features,
             num_key_value_heads=num_key_value_heads,
         ).to(device="cuda", dtype=torch.float16)
+    elif model_cls == FP8MLP:
+        model = model_cls(num_features, num_features, bias=bias).to("cuda")
Collaborator

Maybe we can drop dtype=torch.float16 entirely from the "non-quantized" models; then we wouldn't need a separate elif for FP8MLP?

# (MLP, "torch_dist_all_reduce"),
(FP8MLP, "torch_dist_all_reduce"),
# (nn.Linear, "torch_dist_all_gather"),
# (GQA_Block, "torch_dist_all_reduce"),
Collaborator

Why are the remaining model cases commented out?

Collaborator Author

I was just testing, forgot to revert it...

@Fridah-nv Fridah-nv requested a review from meenchen August 29, 2025 19:16
@Fridah-nv
Collaborator Author

Thanks @Fridah-nv . Did you get a chance to test this on a sharded quantized model? maybe llama 70B Fp8 and/or llama4-fp8?

Yes, I tested TP sharding for FP8 and NVFP4 end to end. I'm planning to test quantized MoE and BMM models as a sanity check.
