
Conversation

@yilin-void (Collaborator) commented Jul 23, 2025

DeepEP diff: https://github.com/deepseek-ai/DeepEP/compare/7b15af835942675df041eca2dcb9930b880287e1...edf3ea2b086a393d3163bf2773eab69d9191cc01?expand=1

Summary by CodeRabbit

  • New Features

    • Introduced support for low-latency dispatch of FP4 (4-bit floating point) data, enabling enhanced processing for specific model operations.
  • Bug Fixes

    • Corrected tensor dimension assignments and updated dispatch handling for improved data processing accuracy.
  • Chores

    • Updated the DeepEP dependency to a newer version.

@coderabbitai bot (Contributor) commented Jul 23, 2025

📝 Walkthrough

"""

Walkthrough

The changes update the DeepEP dependency version in the CMake configuration, add a new low_latency_dispatch_fp4 method to the VariableLengthLowLatencyBuffer class for FP4 data dispatch, and modify the fused MoE wide expert module to correctly handle tensor dimensions and use the new FP4 dispatch method.

Changes

  • cpp/tensorrt_llm/deep_ep/CMakeLists.txt: Updated the DeepEP dependency commit hash to a newer version.
  • tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py: Added a low_latency_dispatch_fp4 method to VariableLengthLowLatencyBuffer to support FP4 dispatch.
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py: Replaced the dispatch call with low_latency_dispatch_fp4; updated tensor shape assertions and reshaping logic to handle FP4 data.

Sequence Diagram(s)

sequenceDiagram
    participant ForwardChunk
    participant VariableLengthLowLatencyBuffer
    participant DeepEPBuffer

    ForwardChunk->>VariableLengthLowLatencyBuffer: low_latency_dispatch_fp4(hidden_states, scales, topk_idx, ...)
    VariableLengthLowLatencyBuffer->>DeepEPBuffer: low_latency_dispatch_fp4(hidden_states, scales, topk_idx, ...)
    DeepEPBuffer-->>VariableLengthLowLatencyBuffer: (recv_hidden_states, recv_scales, recv_expert_count, handle, event, hook)
    VariableLengthLowLatencyBuffer-->>ForwardChunk: (recv_hidden_states, recv_scales, recv_expert_count, handle)
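To make the diagram concrete, here is a minimal, hedged Python sketch of the caller-side flow. Only the participant and value names come from the diagram; the function name dispatch_fp4 and its parameter list are illustrative assumptions, not the PR's actual code:

    def dispatch_fp4(buffer, hidden_states, scales, topk_idx,
                     num_max_dispatch_tokens_per_rank, num_experts):
        """Caller-side view of the diagram above (illustrative sketch)."""
        recv_hidden_states, recv_scales, recv_expert_count, handle = \
            buffer.low_latency_dispatch_fp4(
                hidden_states, scales, topk_idx,
                num_max_dispatch_tokens_per_rank, num_experts)
        # The event and hook produced by the underlying DeepEPBuffer stay
        # inside the wrapper; only these four values reach the caller.
        return recv_hidden_states, recv_scales, recv_expert_count, handle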

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~15 minutes

Suggested reviewers

  • yuantailing
  • hyukn

Poem

In the warren where tensors leap and play,
A new FP4 dispatch hops in today!
With scales and states, it swiftly flows,
Through DeepEP paths the data goes.
A hash is nudged, a buffer's grown—
The code is ready, carrots thrown! 🥕
"""



📜 Recent review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a854cc6 and d9df261.

📒 Files selected for processing (3)
  • cpp/tensorrt_llm/deep_ep/CMakeLists.txt (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (2 hunks)
✅ Files skipped from review due to trivial changes (1)
  • cpp/tensorrt_llm/deep_ep/CMakeLists.txt
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py

165-165: Line too long (128 > 120)

(E501)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

605-605: Line too long (132 > 120)

(E501)

🔇 Additional comments (3)
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (1)

157-174: LGTM: Well-implemented FP4 dispatch method.

The new low_latency_dispatch_fp4 method follows the same pattern as the existing low_latency_dispatch method and correctly handles the additional scales parameter for FP4 operations. The implementation properly asserts preconditions, calls the underlying buffer method, and returns the expected values including the new recv_scales.
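For readers without the diff open, a minimal sketch of the pattern being described, assuming the same buffer-initialization precondition as the existing low_latency_dispatch method (the exact assertion in the PR may differ):

    def low_latency_dispatch_fp4(self, hidden_states, scales, topk_idx,
                                 num_max_dispatch_tokens_per_rank, num_experts):
        # Assumed precondition: the underlying DeepEP buffer already exists.
        assert self.buffer is not None
        recv_hidden_states, recv_scales, recv_expert_count, handle, event, hook = \
            self.buffer.low_latency_dispatch_fp4(
                hidden_states, scales, topk_idx,
                num_max_dispatch_tokens_per_rank, num_experts)
        # event and hook are dropped; callers consume the first four values.
        return recv_hidden_states, recv_scales, recv_expert_count, handle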

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (2)

593-599: LGTM: Proper FP4 tensor validation.

The assertions correctly validate the FP4 tensor format requirements:

  • Ensures both tensors are uint8 (appropriate for FP4 storage)
  • Validates hidden_size alignment (divisible by 32)
  • Confirms expected tensor shapes for FP4 format

The validation logic is consistent with FP4 processing requirements.
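A hedged sketch of what such checks look like; the names (x for packed FP4 activations, x_sf for scaling factors) follow the review, and the exact shape relations in the PR vary slightly between the reviewed commits:

    import torch

    def check_fp4_inputs(x: torch.Tensor, x_sf: torch.Tensor, hidden_size: int) -> None:
        assert x_sf is not None                    # FP4 requires scaling factors
        assert x.dtype == torch.uint8              # two FP4 values packed per byte
        assert x_sf.dtype == torch.uint8
        assert hidden_size % 32 == 0               # alignment requirement
        assert x.shape[1] == hidden_size // 2      # assumed packed activation width
        assert x_sf.shape[1] == hidden_size // 16  # assumed: one scale per 16 values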


620-622: LGTM: Correct tensor reshaping for FP4 processing.

The tensor reshaping operations properly transform the 3D tensors returned from the FP4 dispatch into the 2D format expected by the downstream fused_moe operation. The reshaping logic is consistent with the processing pipeline requirements.
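A hedged sketch of that flattening, assuming the dispatch output layout is [num_local_experts, max_tokens_per_expert, width]; the dimension names are assumptions:

    # Collapse the first two dimensions so the downstream fused_moe
    # operation sees 2D [total_tokens, width] inputs.
    x = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
    x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1], x_sf.shape[2])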



@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (1)

157-174: LGTM! Well-implemented FP4 dispatch method with minor formatting issue.

The new low_latency_dispatch_fp4 method follows the established patterns and correctly extends the functionality to support FP4 data with scales. The implementation is consistent with the existing low_latency_dispatch method.

However, please fix the line length issue on line 166:

-        recv_hidden_states, recv_scales, recv_expert_count, handle, event, hook = \
-            self.buffer.low_latency_dispatch_fp4(hidden_states, scales, topk_idx, num_max_dispatch_tokens_per_rank, num_experts)
+        recv_hidden_states, recv_scales, recv_expert_count, handle, event, hook = \
+            self.buffer.low_latency_dispatch_fp4(
+                hidden_states, scales, topk_idx, 
+                num_max_dispatch_tokens_per_rank, num_experts)
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2b0fa24 and 04410f2.

📒 Files selected for processing (2)
  • cpp/tensorrt_llm/deep_ep/CMakeLists.txt (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (1 hunks)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py

166-166: Line too long (128 > 120)

(E501)

🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py

166-166: Line too long (128 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
cpp/tensorrt_llm/deep_ep/CMakeLists.txt (1)

1-1: DeepEP commit hash verified

  • Commit 5ac330ba4157b46c439e7575a9b558d2b01224e6 exists in the DeepEP repo.
  • Merge message confirms “LL dispatch FP4”.
  • FP4 functionality is present in tests/test_low_latency_fp4.py.

This update is consistent with the PR objectives. Approved.

tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (1)

166-166: DeepEP buffer method verified

The low_latency_dispatch_fp4 method is present in the DeepEP codebase at the specified commit:

  • In csrc/deep_ep.hpp (line 146): method signature
  • In csrc/deep_ep.cpp (line 1187): method definition
  • In csrc/deep_ep.cpp (line 1440): Python binding via .def

No further action is required.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1)

605-605: Consider breaking long line for better readability.

The line exceeds the 120 character limit. Consider breaking it for better readability.

-                x, x_sf, recv_expert_count, deep_ep_handle = \
-                    self.deep_ep_buffer.low_latency_dispatch_fp4(x, x_sf, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
+                x, x_sf, recv_expert_count, deep_ep_handle = \
+                    self.deep_ep_buffer.low_latency_dispatch_fp4(
+                        x, x_sf, deep_ep_topk_idx, 
+                        all_rank_max_num_tokens, self.num_slots)
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 04410f2 and 37d06c1.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

605-605: Line too long (132 > 120)

(E501)

🔇 Additional comments (3)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (3)

591-592: Critical dimension assignment fix.

The correction of swapping token_num and hidden_size assignments is appropriate. The previous assignment was incorrect as x_row (shape[0]) represents the token dimension while x_col (shape[1]) represents the hidden dimension.
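In code terms, the corrected assignment is simply the following (a sketch restating the sentence above, not the literal diff):

    x_row, x_col = x.shape[0], x.shape[1]
    token_num = x_row      # shape[0]: token dimension
    hidden_size = x_col    # shape[1]: hidden dimension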


595-595: Correct FP4 data type handling and dispatch method.

The changes correctly handle FP4 quantized data:

  • Using torch.uint8 dtype for packed FP4 values is appropriate
  • The dimension assertions are updated to match the corrected variable assignments
  • The new low_latency_dispatch_fp4 method provides a cleaner interface for FP4 dispatch

Also applies to: 598-599, 606-609


619-622: Reshape dimensions are consistent with low_latency_dispatch_fp4
The preceding assertions in fused_moe_wide_ep.py (x.shape[2] == hidden_size // 2 and x_sf.shape[2] == hidden_size // 16 // 2) confirm the 3D layouts for hidden states and scales. Flattening them to [batch_size * seq_length, feature_dim] matches the buffer call and requires no further changes.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1)

604-609: Approve FP4 dispatch method integration, but fix line length

The direct call to low_latency_dispatch_fp4 is an architectural improvement over the previous temporary packed tensor approach. However, there's a line length violation.

Apply this diff to fix the line length issue:

-                x, x_sf, recv_expert_count, deep_ep_handle = \
-                    self.deep_ep_buffer.low_latency_dispatch_fp4(x, x_sf, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
+                x, x_sf, recv_expert_count, deep_ep_handle = self.deep_ep_buffer.low_latency_dispatch_fp4(
+                    x, x_sf, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots
+                )
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 37d06c1 and dada91e.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

605-605: Line too long (132 > 120)

(E501)

🔇 Additional comments (4)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (4)

591-592: LGTM: Variable assignment correction

The swap of token_num and hidden_size assignments correctly aligns with tensor dimension semantics where x.shape[0] represents the token count and x.shape[1] represents the hidden size.


595-598: LGTM: Assertions updated for corrected variable semantics

The assertions correctly validate tensor shapes and data types using the corrected token_num and hidden_size variable assignments.


619-621: LGTM: Tensor reshaping updated for 3D structure

The reshaping operations correctly flatten the first two dimensions of the 3D tensors returned by the new FP4 dispatch method, preparing them for the fused MoE kernel.


590-626: Excellent architectural improvement for FP4 dispatch

The overall changes represent a well-executed architectural improvement:

  1. Corrected semantics: Variable assignments now properly reflect tensor dimensions (token_num = x_row, hidden_size = x_col)
  2. Cleaner architecture: Direct low_latency_dispatch_fp4 call eliminates the need for temporary packed tensors
  3. Consistent updates: All assertions, reshaping operations, and tensor handling have been updated to maintain consistency
  4. Maintained functionality: The changes preserve the existing behavior while improving code clarity

This modernization of the FP4 dispatch logic enhances maintainability and aligns with the broader DeepEP integration improvements.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1)

606-609: Consider documenting the FP4 tensor layout.

The post-dispatch assertions correctly validate the 3D tensor format, but the FP4 packing scheme (2 values per uint8) could benefit from documentation for future maintainability.

Add a comment explaining the FP4 layout:

                assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8
                assert x.dim() == 3 and x_sf.dim() == 3
+               # FP4 packing: each uint8 holds 2 FP4 values, hence hidden_size // 2
+               # Scaling factors: one per 16 FP4 values, hence hidden_size // 16
                assert x.shape[2] == hidden_size // 2 and x_sf.shape[2] == hidden_size // 16
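As a concrete, runnable illustration of that packing arithmetic (the width 7168 is an assumed example value, not taken from the PR):

    import torch

    hidden_size = 7168                 # assumed example width; divisible by 32
    num_tokens = 4                     # illustrative token count

    # Two FP4 values per uint8 byte -> packed width is hidden_size // 2.
    x = torch.zeros(num_tokens, hidden_size // 2, dtype=torch.uint8)
    # One scaling factor per group of 16 FP4 values -> hidden_size // 16.
    x_sf = torch.zeros(num_tokens, hidden_size // 16, dtype=torch.uint8)

    assert x.shape[1] == 3584          # 7168 // 2
    assert x_sf.shape[1] == 448        # 7168 // 16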
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dada91e and 46580cf.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

605-605: Line too long (132 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (2)

593-599: LGTM! Proper validation for FP4 tensor format.

The assertions correctly validate the FP4 tensor format requirements:

  • Ensures scaling factors are present when using FP4
  • Validates uint8 data type (as FP4 values are packed into uint8)
  • Confirms proper shape constraints where scaling factors are 1/16th of hidden size

619-621: LGTM! Correct tensor reshaping for fused_moe operation.

The reshape operations properly flatten the 3D tensors from the dispatch output into the 2D format expected by the downstream fused_moe operation.

yilin-void force-pushed the deep_ep/fp4_dispatch branch from 46580cf to df6b097 on July 24, 2025 at 04:08
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1)

604-605: Fix line length issue.

Line 605 exceeds the 120 character limit.

Split the function call across multiple lines:

-                x, x_sf, recv_expert_count, deep_ep_handle = \
-                    self.deep_ep_buffer.low_latency_dispatch_fp4(x, x_sf, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
+                x, x_sf, recv_expert_count, deep_ep_handle = \
+                    self.deep_ep_buffer.low_latency_dispatch_fp4(
+                        x, x_sf, deep_ep_topk_idx, 
+                        all_rank_max_num_tokens, self.num_slots)
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 46580cf and df6b097.

📒 Files selected for processing (3)
  • cpp/tensorrt_llm/deep_ep/CMakeLists.txt (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (2 hunks)
✅ Files skipped from review due to trivial changes (1)
  • cpp/tensorrt_llm/deep_ep/CMakeLists.txt
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

605-605: Line too long (132 > 120)

(E501)

tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py

166-166: Line too long (128 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (1)

157-174: LGTM! FP4 dispatch method follows the established pattern.

The new low_latency_dispatch_fp4 method correctly extends the low-latency dispatch functionality for FP4 data, properly handling separate scale tensors alongside hidden states.

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (2)

593-598: Comprehensive FP4 data validation.

The assertions properly validate FP4 data format requirements, ensuring correct tensor types and shape constraints for the dispatch operation.


606-626: Correct FP4 post-dispatch processing.

The code properly handles the dispatched FP4 data by validating output formats, reshaping tensors, and applying the necessary swizzle operation for optimal memory layout.

yilin-void force-pushed the deep_ep/fp4_dispatch branch from df6b097 to 5817c7f on July 24, 2025 at 06:48
yilin-void requested review from yuantailing and hyukn on July 24, 2025 at 06:53
yilin-void marked this pull request as ready for review on July 24, 2025 at 06:53
yilin-void requested a review from a team as a code owner on July 24, 2025 at 06:53
yilin-void requested a review from liji-nv and then removed the request on July 24, 2025 at 06:53
yilin-void force-pushed the deep_ep/fp4_dispatch branch from 5817c7f to a854cc6 on July 24, 2025 at 06:54
@yilin-void (Collaborator, PR author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #12820 [ run ] triggered by Bot

@hyukn (Collaborator) left a comment

LGTM.

@tensorrt-cicd (Collaborator) commented:

PR_Github #12820 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9554 completed with status: 'SUCCESS'

Signed-off-by: Yilin Zhang <[email protected]>
yilin-void force-pushed the deep_ep/fp4_dispatch branch from a854cc6 to d9df261 on July 28, 2025 at 02:19
@yilin-void (Collaborator, PR author) commented:

/bot reuse-pipeline

yilin-void enabled auto-merge (squash) on July 28, 2025 at 02:19
coderabbitai bot requested reviews from hyukn and yuantailing on July 28, 2025 at 02:20
@tensorrt-cicd (Collaborator) commented:

PR_Github #13133 [ reuse-pipeline ] triggered by Bot

@tensorrt-cicd (Collaborator) commented:

PR_Github #13133 [ reuse-pipeline ] completed with state SUCCESS
Reusing PR_Github #12820 for commit d9df261

yilin-void merged commit f172fac into NVIDIA:main on July 28, 2025
3 checks passed
NVShreyas pushed a commit to NVShreyas/TensorRT-LLM that referenced this pull request Jul 28, 2025
Signed-off-by: Yilin Zhang <[email protected]>
Signed-off-by: Shreyas Misra <[email protected]>
Ransiki pushed a commit to Ransiki/TensorRT-LLM that referenced this pull request Jul 29, 2025
Signed-off-by: Yilin Zhang <[email protected]>
Signed-off-by: Ransiki Zhang <[email protected]>
lancelly pushed a commit to lancelly/TensorRT-LLM that referenced this pull request Aug 6, 2025
Signed-off-by: Yilin Zhang <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>