[None][feat] Add support for Hopper MLA chunked prefill #6655
Conversation
📝 Walkthrough

Softmax statistics layout changed to store two floats per entry ([max, sum]) and propagated through kernels, host copies, checks, and runner params; kernel enumeration/gating for MLA and return_softmax adjusted; DMA Q-loop guarded for zero KV length; tests and CI updated for chunked prefill (SM90 BF16 cases); executor gating extended to allow SM90.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant CLI as Test/CLI
    participant Runner as FusedMHARunnerV2
    participant Kernel as FMHA Kernel
    participant Softmax as Softmax reducer/saver
    participant Host as Host verifier
    CLI->>Runner: run with -save-softmax (MLA/SM90 BF16 / paged/separate)
    Runner->>Kernel: launch with softmax_stats_ptr, stride = h * sizeof(float2)
    Kernel->>Softmax: compute per-token/head max & sum
    Softmax-->>Kernel: store [max,sum] into softmax_stats buffer
    Kernel-->>Runner: kernel complete
    Runner->>Host: copy 2 * b * s * h floats D2H
    Host->>Host: check_softmax_results compares max and sum (returns pair of error counts)
```
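The host-side check at the end of this flow can be sketched in NumPy (a minimal sketch with a hypothetical helper name; the real routine is check_softmax_results, and the [b, s, h, 2] layout and 2 * b * s * h copy size are taken from the walkthrough above):

```python
import numpy as np

def check_softmax_stats(stats, scores, b, s, h):
    """Compare interleaved [max, sum] softmax stats against a reference.

    `stats` mimics the device buffer copied D2H: 2 * b * s * h floats
    laid out as [b, s, h, 2], where [..., 0] is the row max and
    [..., 1] is the softmax sum. Returns a pair of error counts,
    one for max mismatches and one for sum mismatches.
    """
    stats = stats.reshape(b, s, h, 2)
    ref_max = scores.max(axis=-1)                          # [b, s, h]
    ref_sum = np.exp(scores - ref_max[..., None]).sum(-1)  # [b, s, h]
    max_errors = int((~np.isclose(stats[..., 0], ref_max)).sum())
    sum_errors = int((~np.isclose(stats[..., 1], ref_sum)).sum())
    return max_errors, sum_errors

# Tiny self-check: build stats from the same scores, expect zero errors.
b, s, h, kv = 1, 2, 3, 4
scores = np.random.default_rng(0).normal(size=(b, s, h, kv)).astype(np.float32)
m = scores.max(-1)
stats = np.stack([m, np.exp(scores - m[..., None]).sum(-1)], -1).ravel()
print(check_softmax_stats(stats, scores, b, s, h))  # (0, 0)
```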
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
Suggested labels
Suggested reviewers
Actionable comments posted: 2
🔭 Outside diff range comments (6)
cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (2)
1-17
: Missing NVIDIA copyright & license header

Per project coding-guidelines, every *.cu/.cpp/.h/.py file must start with an NVIDIA copyright header for the current year. Add the standard header block before the first include.
190-193
: Hard-coded soft-max scale constant diverges from computed bmm1_scale

bmm1_scale is computed earlier (Line 135) but never used; instead an unexplained magic value 0.072168784 is applied. This makes the reference kernel inconsistent with production code and the newly-introduced bmm1_scale parameter, reducing test fidelity.

```diff
- P[s_q * curr_kv_len + s_kv] = std::exp(P[s_q * curr_kv_len + s_kv] * 0.072168784);
+ P[s_q * curr_kv_len + s_kv] = std::exp(P[s_q * curr_kv_len + s_kv] * bmm1_scale);
```

Either use the computed scale or document why a different constant is required.
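As a side observation (not something the review states): the magic constant matches the usual 1/sqrt(d_k) attention scale for the MLA head size used in this test, 192 = 128 (nope) + 64 (rope), which suggests it is the same quantity the computed bmm1_scale represents:

```python
import math

# 0.072168784 appears to be 1 / sqrt(192), i.e. 1 / sqrt(nope + rope)
# for nope = 128 and rope = 64 -- the standard attention scale.
bmm1_scale = 1.0 / math.sqrt(128 + 64)
print(round(bmm1_scale, 9))  # 0.072168784
```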
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
1440-1465
: Add documentation for the new bmm1_scale parameter.

The method merge_attention_for_mla lacks a docstring. Please add documentation explaining the purpose of all parameters, especially the newly added bmm1_scale parameter and its role in the MLA chunked prefill scaling.

Add a docstring to document the method:

```diff
 def merge_attention_for_mla(
     self,
     merged_attn: torch.Tensor,
     temp_attn: torch.Tensor,
     softmax_stats: torch.Tensor,
     temp_softmax_stats: torch.Tensor,
     bmm1_scale: float,
     merge_op: torch.Tensor,
     metadata: TrtllmAttentionMetadata,
 ) -> None:
+    """
+    Merge attention outputs for MLA (Multi-Level Attention) with chunked prefill.
+
+    Args:
+        merged_attn: The accumulated attention output tensor.
+        temp_attn: The temporary attention output from current chunk.
+        softmax_stats: The accumulated softmax statistics (max and sum).
+        temp_softmax_stats: The temporary softmax statistics from current chunk.
+        bmm1_scale: Scaling factor for the first batch matrix multiplication.
+        merge_op: Operation indicator for merging (0=skip, 1=merge, 2=copy).
+        metadata: Attention metadata containing context information.
+    """
     assert self.is_mla_enable and self.mla_params is not None
```

cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu (1)
113-119
: Critical: Incorrect pointer arithmetic in invokeRecoverFromRA.

The function signature has been updated to use combined softmax statistics arrays, but lines 116-119 still use the old separated layout calculation. This will cause incorrect memory access since the kernel expects interleaved max/sum pairs.

Remove the incorrect pointer calculations since the kernel now handles the interleaved layout internally:

```diff
 template <typename Tout>
 void invokeRecoverFromRA(Tout* accu_output, float* accu_softmax_stats, Tout* output,
     float* softmax_stats, int b, int s, int h, int d, int* cu_seqlens, cudaStream_t stream)
 {
-    float* accu_softmax_sum = accu_softmax_stats;
-    float* accu_softmax_max = accu_softmax_stats + b * s * h;
-    float* softmax_sum = softmax_stats;
-    float* softmax_max = softmax_stats + b * s * h;
-
     int threads_per_block = 128;
```

cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (2)
574-602
: Fix type inconsistency between double and float for bmm1_scale.

The function signature uses double bmm1_scale but the helper function mergeChunkedAttentionForMLAHelper expects float bmm1_scale, causing implicit type conversion.

For consistency and to avoid precision issues, change the parameter type to float:

```diff
-void mergeChunkedAttentionForMLA(torch::Tensor& merged_attn, torch::Tensor const& temp_attn,
-    torch::Tensor& merged_softmax_stats, torch::Tensor const& temp_softmax_stats, double bmm1_scale,
-    int64_t const num_requests, torch::Tensor const& cu_q_seq_lens, int64_t const max_q_seq_len,
-    torch::Tensor const& merge_op, int64_t const num_heads, int64_t const head_size)
+void mergeChunkedAttentionForMLA(torch::Tensor& merged_attn, torch::Tensor const& temp_attn,
+    torch::Tensor& merged_softmax_stats, torch::Tensor const& temp_softmax_stats, float bmm1_scale,
+    int64_t const num_requests, torch::Tensor const& cu_q_seq_lens, int64_t const max_q_seq_len,
+    torch::Tensor const& merge_op, int64_t const num_heads, int64_t const head_size)
```
747-768
: Torch library definition type should match function signature.

The library definition uses float bmm1_scale while the actual function signature uses double bmm1_scale. This creates an interface mismatch. If the function signature is updated to use float as suggested in the previous comment, this definition would be correct; otherwise, update this definition to match. Keep this as float and change the function signature to match, as float precision is typically sufficient for scaling factors and maintains consistency with the helper function.
♻️ Duplicate comments (1)
cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (1)
805-807
: Same aliasing issue on final merge call

The final merge call repeats the duplicate-pointer pattern noted above. Fix here as well once the interface is clarified to avoid inconsistent behaviour between intermediate and final merges.
🧹 Nitpick comments (4)
cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (1)
754-755
: Unnecessary double-precision cast

std::sqrt(static_cast<double>(…)) immediately narrows to float, incurring an extra conversion. Use single-precision throughout:

```diff
- float bmm1_scale = 1.F / std::sqrt(static_cast<double>(this->mNopeSize + this->mRopeSize));
+ float bmm1_scale = 1.F / std::sqrt(static_cast<float>(this->mNopeSize + this->mRopeSize));
```

cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (1)
457-458
: Consider reusing the existing variable to avoid duplication.

The condition for isHopperBF16ContextMLA is already computed at lines 426-427. Consider reusing that variable or extracting it to a member/local variable computed once. If the variable at line 426 is in scope, reuse it:

```diff
- bool isHopperBF16ContextMLA = (mFixedParams.headSize == mFixedParams.headSizeV + 64) && isSm90
-     && mFixedParams.dataType == DATA_TYPE_BF16 && mFixedParams.headSizeV == 128;
+ // Reuse the variable from line 426 if in scope, or extract to avoid duplication
```

cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (1)
876-886
: Consider making MLA detection more robust.

While the layout validation logic is correct and improves error handling, the hard-coded dimension check (d == 192 && dv == 128) for MLA detection may be brittle and not future-proof. Consider using a more explicit parameter or configuration flag to identify MLA mode rather than relying on specific dimension values.

```diff
- bool is_MLA = (d == 192 && dv == 128);
+ // Consider adding an explicit MLA flag parameter instead of dimension-based detection
+ bool is_MLA = (d == 192 && dv == 128); // TODO: Replace with explicit MLA parameter
```

cpp/kernels/fmha_v2/setup.py (1)
2313-2313
: Fix line length violation while maintaining correct logic.

The softmax statistics pointer validation logic is correct, but the line exceeds the 120-character limit.

```diff
- il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr '
+ il_check += ('&& params.softmax_stats_ptr != nullptr '
+              if kspec.return_softmax_stats
+              else '&& params.softmax_stats_ptr == nullptr ')
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (20)

- cpp/kernels/fmha_v2/fmha_test.py (1 hunks)
- cpp/kernels/fmha_v2/setup.py (4 hunks)
- cpp/kernels/fmha_v2/src/fmha/fragment.h (3 hunks)
- cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h (1 hunks)
- cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h (1 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (7 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention.h (1 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h (1 hunks)
- cpp/kernels/fmha_v2/src/softmax_impl.h (4 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2 hunks)
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (5 hunks)
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh (1 hunks)
- cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu (4 hunks)
- cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (5 hunks)
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (3 hunks)
- tensorrt_llm/_torch/attention_backend/trtllm.py (2 hunks)
- tensorrt_llm/_torch/modules/attention.py (3 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1 hunks)
- tests/integration/defs/accuracy/test_llm_api_pytorch.py (1 hunks)
- tests/integration/test_lists/test-db/l0_h100.yml (1 hunks)
🧰 Additional context used
📓 Path-based instructions (5)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
cpp/kernels/fmha_v2/fmha_test.py
tensorrt_llm/_torch/modules/attention.py
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
tests/integration/defs/accuracy/test_llm_api_pytorch.py
tensorrt_llm/_torch/attention_backend/trtllm.py
cpp/kernels/fmha_v2/setup.py
**/*.{cpp,h,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
cpp/kernels/fmha_v2/fmha_test.py
tensorrt_llm/_torch/modules/attention.py
cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h
cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
tests/integration/defs/accuracy/test_llm_api_pytorch.py
cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu
tensorrt_llm/_torch/attention_backend/trtllm.py
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
cpp/kernels/fmha_v2/src/softmax_impl.h
cpp/kernels/fmha_v2/src/fused_multihead_attention.h
cpp/kernels/fmha_v2/src/fmha/fragment.h
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
cpp/kernels/fmha_v2/setup.py
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}
: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo)
Prefer const or constexpr variables over #defines whenever possible, as the latter are not visible to the compiler.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) are camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable: camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables: camelcase prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope and function-scope magic-number/l...
Files:
cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
cpp/kernels/fmha_v2/src/softmax_impl.h
cpp/kernels/fmha_v2/src/fused_multihead_attention.h
cpp/kernels/fmha_v2/src/fmha/fragment.h
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
**/*.{h,hpp}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Use a preprocessor guard with the guard name prefix TRTLLM_ followed by the filename, all in caps, for header files.
Files:
cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
cpp/kernels/fmha_v2/src/softmax_impl.h
cpp/kernels/fmha_v2/src/fused_multihead_attention.h
cpp/kernels/fmha_v2/src/fmha/fragment.h
**/*.cu
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
CUDA code includes code that must be compiled with a CUDA compiler, such as code with device, managed, constant, global, or kernel launching syntax.
Files:
cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
🧠 Learnings (4)
📚 Learning: in tensorrt-llm testing, it's common to have both cli flow tests (test_cli_flow.py) and pytorch api ...
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Applied to files:
tests/integration/test_lists/test-db/l0_h100.yml
cpp/kernels/fmha_v2/fmha_test.py
tests/integration/defs/accuracy/test_llm_api_pytorch.py
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
Applied to files:
cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu
📚 Learning: in tensorrt_llm/executor/worker.py, the lora adapter cache optimization logic that checks `is_adapte...
Learnt from: amitz-nv
PR: NVIDIA/TensorRT-LLM#5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.402Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks `is_adapter_in_cpu_cache()` and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.
Applied to files:
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
📚 Learning: in the tensorrt-llm waive list merging system, removed lines are always located at the end of the me...
Learnt from: yiqingy0
PR: NVIDIA/TensorRT-LLM#5198
File: jenkins/mergeWaiveList.py:0-0
Timestamp: 2025-07-22T08:33:49.109Z
Learning: In the TensorRT-LLM waive list merging system, removed lines are always located at the end of the merge waive lists, which is why the mergeWaiveList.py script uses reverse traversal - it's an optimization for this specific domain constraint.
Applied to files:
cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
🧬 Code Graph Analysis (2)
tensorrt_llm/_torch/modules/attention.py (1)
- tensorrt_llm/_utils.py (1): get_sm_version (674-676)

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2)
- tensorrt_llm/_utils.py (1): get_sm_version (674-676)
- tensorrt_llm/logger.py (1): warning (131-132)
🪛 Ruff (0.12.2)
cpp/kernels/fmha_v2/setup.py
2313-2313: Line too long (146 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (37)
tests/integration/test_lists/test-db/l0_h100.yml (1)
43-44
: LGTM! Good test coverage for chunked prefill functionality.

The addition of these two test cases provides comprehensive coverage for the test_chunked_prefill method with both kv_cache_reuse=True and kv_cache_reuse=False scenarios. This aligns well with the PR's objective of adding Hopper MLA chunked prefill support and ensures the functionality is properly tested on H100 hardware.

cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h (1)
503-503
: LGTM! Good defensive programming practice.

Adding the actual_kv_seqlen > 0 condition to the loop prevents unnecessary Q tile loading when there's no KV data to attend to. This runtime guard improves efficiency and prevents potential issues when the KV sequence length is zero or negative.

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
300-301
: LGTM! Correctly expands MLA chunked prefill support to Hopper architecture.

The change from SM100-only to both SM90 and SM100 support aligns with the PR objective of adding Hopper MLA chunked prefill support. SM90 corresponds to the Hopper architecture, so this expansion is appropriate.
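The gating described here amounts to something like the following (a hypothetical sketch of the condition only, not the actual py_executor_creator.py code):

```python
def mla_chunked_prefill_supported(sm_version: int) -> bool:
    """Sketch of the expanded gate: previously SM100 (Blackwell) only,
    now also SM90 (Hopper)."""
    return sm_version in (90, 100)

# Hopper and Blackwell pass; Ampere (SM80) does not.
print([mla_chunked_prefill_supported(v) for v in (80, 90, 100)])  # [False, True, True]
```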
cpp/kernels/fmha_v2/fmha_test.py (1)
181-197
: LGTM! Test coverage extension for chunked prefill with softmax statistics.

The added test cases properly extend coverage for the new chunked prefill feature with softmax statistics saving. The conditional logic correctly targets BF16 data type on SM90 (Hopper) architecture with specific input layouts (-paged-kv and -separate-q-k-v), which aligns with the PR's objective of adding Hopper MLA chunked prefill support. The test includes both padding mask and causal mask scenarios with the -save-softmax flag, ensuring comprehensive validation of the softmax statistics handling in chunked prefill contexts.

tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)
1433-1433
: LGTM! Hardware requirement adjustment aligns with Hopper MLA chunked prefill support.

The change from skip_pre_blackwell to skip_pre_hopper correctly adjusts the minimum hardware requirement for the chunked prefill test with "none" quantization. This aligns with the PR's objective of adding Hopper MLA chunked prefill support, enabling the test to run on SM90 (Hopper) architecture instead of requiring the newer Blackwell architecture.

cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (1)
774-777
: Duplicate arguments may hide interface misuse

invokeMergeAttnWithSoftmax is called with d_output_accum_ptr and d_softmax_sum_accum_ptr passed twice (as both destination and "old" accumulation buffers). If the new kernel signature expects separate old and new buffers, aliasing them defeats the purpose and risks unintended in-place mutations.

Confirm the intended parameter order and, if necessary, supply distinct pointers.
tensorrt_llm/_torch/modules/attention.py (3)
1195-1197
: LGTM! Hardware-specific scaling implementation is correct.

The conditional bmm1_scale computation properly applies the standard attention scaling factor (1/√d_k) for Hopper architecture while maintaining backward compatibility for other SM versions.

1261-1262
: LGTM! Consistent parameter passing for merge operation.

The bmm1_scale parameter is correctly passed to the merge_attention_for_mla method within the chunked prefill loop.

1311-1312
: LGTM! Consistent bmm1_scale parameter usage.

The second merge_attention_for_mla call correctly includes the bmm1_scale parameter, maintaining consistency with the earlier call in the chunked prefill loop.
cpp/kernels/fmha_v2/src/fused_multihead_attention.h (2)
195-195
: LGTM! Updated layout documentation reflects the new softmax stats storage.

The comment correctly describes the new layout [total_tokens_q, h, 2], which stores softmax_max and softmax_sum as pairs, improving cache locality compared to the previous [2, B, S, H] layout.

197-197
: LGTM! Accurate stride documentation for the new layout.

The comment correctly specifies the default stride as h * sizeof(float2), which properly reflects the storage of softmax max and sum as float2 pairs per head.

cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h (1)
468-478
: Memory layout changes look correct.

The updated pointer initialization and offset calculations correctly implement the new interleaved storage layout where max and sum values are stored as consecutive pairs. The stride calculation using sizeof(float) * 2 and the sum pointer offset by sizeof(float) from the max pointer are appropriate for this layout.

cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh (1)
29-33
: Function signature update looks good.

The addition of the float bmm1_scale parameter is well-placed in the function signature, and the existing comments clearly document the purpose of the softmax statistics arrays.

cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu (1)
32-55
: Kernel refactoring for combined softmax stats looks correct.

The kernel has been properly updated to handle the new interleaved storage layout where max and sum values are stored as consecutive pairs. The pointer arithmetic and offset calculations correctly use a stride of 2.
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2)
140-142
: LGTM! Correct stride calculation for paired softmax statistics.

The stride calculation correctly accounts for storing both max and sum values (2 floats) per head, and the comment clearly documents the expected memory layout.
459-465
: LGTM! Proper layout validation for MLA context support.

The logic correctly differentiates input layout requirements between standard and MLA context phases, aligning with the PR's goal of adding Hopper MLA chunked prefill support. The TODO comment appropriately tracks future work for separate QKV input layout.
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h (1)
1320-1336
: LGTM! Sum validation logic is correct.

The sum checking loop correctly accesses odd indices (2*h + 1) and validates sum values. The check for sum_ref != 1.0f makes sense here as it skips validation for placeholder values.

cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (6)
343-343
: LGTM: Correct stride update for dual-float softmax statistics.

The stride change from h * 1 to h * 2 floats correctly reflects the new layout where each head stores both max and sum values per token, aligning with the AI summary description.

1206-1209
: LGTM: Correct allocation size and comment update.

The allocation size correctly doubles to accommodate both max and sum values per token-head pair. The comment update clarifies the buffer's purpose, and the memset size matches the allocation.

1222-1223
: LGTM: Host buffer allocations match device buffer layout.

The doubled allocation sizes for host reference buffers correctly align with the device buffer changes, ensuring consistent memory layout for validation purposes.

1917-1917
: LGTM: Copy size matches new buffer layout.

The doubled copy size correctly transfers both max and sum statistics from device to host for reference validation.

2003-2003
: LGTM: Consistent copy size for kernel output validation.

The copy operation correctly transfers the full softmax statistics buffer (both max and sum) from the kernel output for validation purposes.

2024-2025
: LGTM: Enhanced validation logic for dual-component softmax statistics.

The updated validation correctly handles the new dual-float statistics by checking both max and sum error counts. The bitwise OR operation properly combines this validation result with the overall status.

cpp/kernels/fmha_v2/src/fmha/fragment.h (3)
cpp/kernels/fmha_v2/src/fmha/fragment.h (3)
1907-1907
: LGTM: Pointer initialization aligns with new paired storage layout.

The changes correctly update the pointer initialization to support the new data layout where softmax max and sum values are stored as pairs of floats per token and head. Both pointers now start from the same base address with appropriate offsets: softmax_max_ptr_ points to the max value and softmax_sum_ptr_ points to the sum value (offset by sizeof(float)).

Also applies to: 1919-1921

1941-1941
: LGTM: Store method correctly uses updated pointer variables.

The store operations now properly use softmax_sum_ptr_ for storing sum values, which is consistent with the updated pointer initialization logic in the constructor.

Also applies to: 1946-1946

1952-1952
: LGTM: Comment accurately documents the new data layout.

The updated comment correctly describes the new (total_token_q, h, 2) layout where the last dimension stores paired max and sum float values.

cpp/kernels/fmha_v2/src/softmax_impl.h (4)
151-152
: LGTM: Function signature updated to output max value.

The addition of the max_fp32 reference parameter allows the caller to retrieve the computed maximum value from the reduction, which is consistent with the broader changes to track softmax statistics as pairs.

247-248
: LGTM: Function signature updated consistently.

The addition of the max_fp32 reference parameter maintains consistency with the other reduce function overloads and enables proper extraction of maximum values during softmax reduction.

416-417
: LGTM: Function signature updated for consistency.

The addition of the max_fp32 reference parameter completes the consistent update across all reduce function template overloads, enabling unified handling of maximum value extraction.

999-1009
: LGTM: Kernel updated to store both max and sum statistics.

The changes correctly implement the new softmax statistics handling:
- max_fp32 is properly initialized and passed to the reduce function
- The storage layout [B, S, H, 2] is clearly documented and correctly implemented
- Index calculations properly interleave max and sum values per head
- Bounds checking ensures safe memory access

The implementation aligns with the broader changes to track softmax statistics as pairs across the FMHA kernels.
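The interleaving those index calculations implement can be illustrated with a small sketch (stand-in data, not the kernel code): within one token's slice, a head's pair lives at offsets 2*h (max) and 2*h + 1 (sum).

```python
# Interleaved per-token stats: [max_0, sum_0, max_1, sum_1, ...].
# Maxes sit at even offsets 2*h, sums at odd offsets 2*h + 1.
num_heads = 4
stats = [float(i) for i in range(2 * num_heads)]  # stand-in buffer

maxes = [stats[2 * h] for h in range(num_heads)]
sums = [stats[2 * h + 1] for h in range(num_heads)]
print(maxes)  # [0.0, 2.0, 4.0, 6.0]
print(sums)   # [1.0, 3.0, 5.0, 7.0]
```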
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (3)
161-163
: LGTM: Kernel signature updated with dynamic scaling parameter.

The addition of the bmm1_scale parameter enables dynamic scaling instead of hardcoded values, providing flexibility for different hardware configurations (Blackwell vs Hopper as mentioned in the comment). The parameter placement in the signature is logical and consistent with the function's data flow.

191-191
: LGTM: Dynamic scaling applied to current statistics.

The scaling is correctly applied to the max value (curr_stats.x) from the current attention statistics. The comment clearly explains the architecture-specific scaling factors (Blackwell vs Hopper), which provides valuable context for maintainers and validates the need for dynamic scaling.

399-402
: LGTM: Function signatures consistently updated across the call stack.

The changes maintain API consistency by properly propagating the bmm1_scale parameter through:
- The public invokeMergeAttnWithSoftmax function signature
- The kernel invocation call
- The template instantiation macro for all supported data types

This ensures that the dynamic scaling capability is accessible throughout the entire call hierarchy while maintaining type safety and consistency.

Also applies to: 410-413, 452-456
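The merge these kernels perform can be sketched numerically as a standard log-sum-exp combination of two chunks' [max, sum] stats and partial outputs (a minimal NumPy sketch of the technique, not the actual invokeMergeAttnWithSoftmax implementation):

```python
import numpy as np

def merge_chunks(acc_out, acc_max, acc_sum, cur_out, cur_max, cur_sum):
    """Fold a new attention chunk into the accumulator.

    Each side carries a normalized partial output plus per-row softmax
    stats [max, sum]; rescaling both sides by exp(old_max - new_max)
    makes the result equal to attention over the concatenated KV.
    """
    new_max = np.maximum(acc_max, cur_max)
    a = np.exp(acc_max - new_max)  # rescale factor for the accumulator
    b = np.exp(cur_max - new_max)  # rescale factor for the new chunk
    new_sum = a * acc_sum + b * cur_sum
    new_out = (a * acc_sum * acc_out + b * cur_sum * cur_out) / new_sum
    return new_out, new_max, new_sum

# Check against full (unchunked) attention for one query row.
rng = np.random.default_rng(0)
scores, values = rng.normal(size=6), rng.normal(size=6)

def chunk_attn(s, v):
    m = s.max()
    p = np.exp(s - m)
    return p @ v / p.sum(), m, p.sum()

out, m, t = merge_chunks(*chunk_attn(scores[:3], values[:3]),
                         *chunk_attn(scores[3:], values[3:]))
full = np.exp(scores - scores.max()) @ values / np.exp(scores - scores.max()).sum()
print(np.isclose(out, full))  # True
```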
cpp/kernels/fmha_v2/setup.py (3)
1974-1981
: LGTM! Correct mask type selection for MLA context phase.The logic correctly restricts MLA context phase (head_size=192, head_size_v=128) to only use causal mask and padding mask for chunked prefill, which aligns with the PR objective of adding Hopper MLA chunked prefill support.
1973-1982
: Excellent implementation of Hopper MLA chunked prefill support.

The changes successfully implement the required functionality:
- Mask Type Selection: Correctly restricts MLA context phase to causal and padding masks only
- Softmax Statistics: Proper validation of softmax stats pointer based on kernel configuration
- Input Layout Logic: Appropriate differentiation between normal attention and MLA context requirements
- Kernel Enumeration: Comprehensive coverage of different head sizes and configurations
The implementation is well-structured, follows existing patterns, and includes helpful comments explaining the complex logic. This aligns perfectly with the PR objective of adding Hopper MLA chunked prefill support.
Also applies to: 2654-2801
2654-2801
: Complex but well-structured kernel enumeration logic – verification required

The implementation correctly distinguishes normal attention vs MLA context when `return_softmax_stats` is enabled:
- Normal attention: only the `CONTIGUOUS_Q_KV` input layout
- MLA context: only the `Q_PAGED_KV` or `SEPARATE_Q_K_V` input layouts

However, the proposed Python verification script failed with an ImportError due to a missing `bindings` module. To confirm that MLA context kernels (head_size = 192, head_size_v = 128, return_softmax_stats = true) are generated with exactly those layouts, please:
- Build and install the Python package first: `scripts/build_wheel.py && pip install -e .`
- Re-run the verification script to inspect the `specs` list for the expected kernels.

Alternatively, you can directly inspect the conditional blocks around lines 2654–2801 in `cpp/kernels/fmha_v2/setup.py`:

```
rg -n "return_softmax_stats" -C3 cpp/kernels/fmha_v2/setup.py
```

and verify that the logic enforces the correct `input_layout` values when `return_softmax_stats` is true.

cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (1)
115-132
: LGTM - Clean integration of bmm1_scale parameter.

The addition of the `bmm1_scale` parameter to the helper function is well-integrated and maintains consistency with the kernel invocation signature.
/bot run --disable-fail-fast |
PR_Github #14246 [ run ] triggered by Bot |
PR_Github #14246 [ run ] completed with state |
758fe72
to
f6ed76e
Compare
Actionable comments posted: 0
♻️ Duplicate comments (1)
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h (1)
1304-1319
: Fix variable naming and validation logic in max checking loop.

The variable names in the max checking loop are misleading and the validation condition appears incorrect:
- Variables are named `sum` and `sum_ref` but they're checking max values at even indices
- The condition `sum_ref != 1.0f` seems inappropriate for max values

Apply this diff to fix the naming and logic:
```diff
     // Check the max
     for (int b_ = 0; b_ < b; ++b_)
     {
         for (int s_ = 0; s_ < seqlens[b_]; ++s_)
         {
             for (int h_ = 0; h_ < h; ++h_)
             {
                 uint64_t idx = (cu_seqlens[b_] + s_) * h * 2 + h_ * 2;
-                float sum = out[idx];
-                float sum_ref = ref[idx];
-                if (sum_ref != 1.0f && fabsf(sum - sum_ref) / (fabsf(sum) + fabsf(sum_ref)) > 0.01)
+                float max_val = out[idx];
+                float max_ref = ref[idx];
+                if (fabsf(max_val - max_ref) / (fabsf(max_val) + fabsf(max_ref)) > 0.01)
                 {
                     n_errors_max++;
                 }
             }
         }
     }
```
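As a quick illustration of why the bare symmetric relative error `|a-b| / (|a|+|b|)` can be fragile: for two tiny values of opposite sign it saturates at 1.0 and always "fails", which is the case an absolute-error fallback (the epsilon-guarded variant suggested later in this review) handles. A toy Python model, with illustrative values only:

```python
def rel_err(a, b, eps=1e-6):
    """Symmetric relative error with an absolute-error fallback near zero."""
    den = abs(a) + abs(b)
    return abs(a - b) if den <= eps else abs(a - b) / den

# Near zero, the plain ratio saturates at 1.0 and would always flag an error.
a, b = 1e-9, -1e-9
assert abs(a - b) / (abs(a) + abs(b)) == 1.0  # naive metric: spurious failure
assert rel_err(a, b) < 0.01                   # fallback: passes

# Away from zero, both metrics agree on the 1% threshold.
assert rel_err(100.0, 101.0) < 0.01
assert rel_err(100.0, 110.0) > 0.01
```

The same 1% threshold as the C++ check is used here, so the sketch mirrors the pass/fail behavior of the suggested metric.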
🧹 Nitpick comments (2)
cpp/kernels/fmha_v2/setup.py (2)
2313-2313
: Line length exceeds limit and logic change looks good.

The simplification of the softmax stats condition to remove the InputLayout.CONTIGUOUS_Q_KV dependency aligns well with the chunked prefill support requirements.
However, the line exceeds the 120 character limit. Consider breaking it for better readability:
```diff
-        il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr '
+        il_check += ('&& params.softmax_stats_ptr != nullptr '
+                     if kspec.return_softmax_stats
+                     else '&& params.softmax_stats_ptr == nullptr ')
```
3654-3801
: Complex but correct implementation of chunked prefill kernel enumeration.

The logic correctly distinguishes between normal attention and MLA context phases based on input layout and return_softmax requirements:
- Normal attention: requires CONTIGUOUS_Q_KV when returning softmax stats
- MLA context: requires Q_PAGED_KV or SEPARATE_Q_K_V when returning softmax stats
The kernel specifications are properly configured for different head sizes and the MLA-specific case (192, 128).
Consider improving code readability with better variable naming:
```diff
-    # for normal attention, we only need contiguous kv as input layout when returning softmax.
-    skip_combination = return_softmax and (input_layout
-                                           != InputLayout.CONTIGUOUS_Q_KV)
-    # for context mla, we need paged kv or separate qkv as input layout when returning softmax.
-    skip_mla_combination = return_softmax and (
-        input_layout != InputLayout.Q_PAGED_KV
-        and input_layout != InputLayout.SEPARATE_Q_K_V)
+    # for normal attention, we only need contiguous kv as input layout when returning softmax.
+    skip_normal_attention = return_softmax and (input_layout != InputLayout.CONTIGUOUS_Q_KV)
+    # for context mla, we need paged kv or separate qkv as input layout when returning softmax.
+    skip_mla_context = return_softmax and (
+        input_layout not in [InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V])
```
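The gating can be summarized as a small predicate over (layout, softmax, MLA-ness). The enum and helper below are a toy reconstruction for illustration, not the actual setup.py code:

```python
from enum import Enum, auto

class InputLayout(Enum):
    PACKED_QKV = auto()
    CONTIGUOUS_Q_KV = auto()
    Q_PAGED_KV = auto()
    SEPARATE_Q_K_V = auto()

def skip(return_softmax, input_layout, is_mla):
    """Toy model of which (layout, softmax) kernel combinations are skipped."""
    if not return_softmax:
        return False
    if is_mla:
        # MLA context: only paged-KV and separate-QKV return softmax stats.
        return input_layout not in (InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V)
    # Normal attention: only contiguous Q/KV returns softmax stats.
    return input_layout != InputLayout.CONTIGUOUS_Q_KV

# Normal attention keeps only CONTIGUOUS_Q_KV when softmax stats are returned.
kept = [l for l in InputLayout if not skip(True, l, is_mla=False)]
assert kept == [InputLayout.CONTIGUOUS_Q_KV]

# MLA context keeps paged-KV and separate-QKV.
kept_mla = [l for l in InputLayout if not skip(True, l, is_mla=True)]
assert kept_mla == [InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V]
```

Expressing the condition as a membership test, as the suggested rename does, makes the allowed-layout sets explicit.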
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- cpp/kernels/fmha_v2/fmha_test.py (1 hunks)
- cpp/kernels/fmha_v2/setup.py (4 hunks)
- cpp/kernels/fmha_v2/src/fmha/fragment.h (3 hunks)
- cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h (1 hunks)
- cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h (1 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (7 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention.h (1 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h (1 hunks)
- cpp/kernels/fmha_v2/src/softmax_impl.h (4 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2 hunks)
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1 hunks)
- cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu (4 hunks)
✅ Files skipped from review due to trivial changes (1)
- cpp/kernels/fmha_v2/src/fused_multihead_attention.h
🚧 Files skipped from review as they are similar to previous changes (9)
- cpp/kernels/fmha_v2/fmha_test.py
- cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
- cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h
- cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
- cpp/kernels/fmha_v2/src/fmha/fragment.h
- cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu
- cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
- cpp/kernels/fmha_v2/src/softmax_impl.h
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}
: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo).
Prefer const or constexpr variables over #defines whenever possible.
A variable that is not modified after its initialization should be declared as const.
Except 0 (used for checking signness/existence/emptiness), nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do..while, or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with the first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in a compilation target must have case-insensitive unique filenames.
All types (including class names) should use camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces should use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace should use camel case prefixed by 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace should use camel case prefixed by 's' (e.g., sMutableStaticGlobal).
Locally visible static variables should use camel case with lowercase prefix 's' as the first letter (e.g., static std::once_flag sFlag;).
Class member variables should use camel case prefixed with 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope, and function-scope magic-number/literal constants should be uppercase snake case with prefix...
Files:
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
**/*.{h,hpp}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Use a preprocessor guard in header files. The guard name must have prefix TRTLLM_ followed by the filename, all in caps, and no trailing underscore.
Files:
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
cpp/kernels/fmha_v2/setup.py
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the class docstring.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
cpp/kernels/fmha_v2/setup.py
🪛 Ruff (0.12.2)
cpp/kernels/fmha_v2/setup.py
2313-2313: Line too long (146 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (4)
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h (3)
1297-1301
: LGTM: Function signature change supports paired error reporting.

The change to return `std::pair<int, int>` for separate max and sum error counts aligns well with the broader codebase changes to handle softmax statistics as paired values.
1321-1337
: LGTM: Sum validation loop is correctly implemented.

The sum checking loop correctly:
- Uses appropriate variable names (`sum`, `sum_ref`)
- Accesses sum values at odd indices (`h_ * 2 + 1`)
- Applies consistent validation logic with the existing condition
1337-1337
: LGTM: Return statement correctly provides both error counts.

The return statement properly packages both error counts as a pair, enabling callers to distinguish between max and sum validation failures.
cpp/kernels/fmha_v2/setup.py (1)
1974-1974
: LGTM! Comment accurately reflects chunked prefill support.

The updated comment correctly documents that the MLA context phase now requires both causal and padding masks to support the new chunked prefill functionality.
ec19f49
to
82754ff
Compare
/bot run --disable-fail-fast |
PR_Github #14747 [ run ] triggered by Bot |
Actionable comments posted: 2
♻️ Duplicate comments (2)
cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (1)
753-753
: Remove unused bmm1_scale (duplicate of prior feedback).

This local is never used and was previously flagged as not needed. Please drop it to avoid warnings and noise.
```diff
-    float bmm1_scale = 1.F / std::sqrt(static_cast<double>(this->mNopeSize + this->mRopeSize));
```
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (1)
1239-1239
: Align naming with new semantics: these buffers now hold [max,sum], not just sums.

The comment and allocations correctly reflect a 2-float layout, but the variable names still say "sum". To avoid confusion, rename softmax_sum_ref_h/softmax_sum_h to softmax_stats_ref_h/softmax_stats_h.
Apply localized renames:
```diff
-// The softmax_stats_d vector is used to store the max/sum of the softmax per token
+// The softmax_stats_d vector stores {max,sum} of the softmax per token

-float* softmax_sum_ref_h = (float*) malloc(2 * b * s * h * sizeof(float));
-float* softmax_sum_h = (float*) malloc(2 * b * s * h * sizeof(float));
+float* softmax_stats_ref_h = (float*) malloc(2 * b * s * h * sizeof(float));
+float* softmax_stats_h = (float*) malloc(2 * b * s * h * sizeof(float));
```

Also adjust downstream uses outside these hunks as needed (illustrative, not a complete diff):
```cpp
// reads
FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_ref_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
// later reads
FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
// check
auto errors = check_softmax_results(softmax_stats_h, softmax_stats_ref_h, b, s, h, seqlens, cu_seqlens);
// frees
free(softmax_stats_h);
free(softmax_stats_ref_h);
```

Also applies to: 1255-1256
🧹 Nitpick comments (7)
cpp/kernels/fmha_v2/setup.py (2)
2313-2313
: Fix Ruff E501 and keep softmax-stats gating readable.

The line exceeds 120 characters and was flagged (E501). Breaking it up also improves readability while preserving the logic.
```diff
-        il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr '
+        il_check += (
+            '&& params.softmax_stats_ptr != nullptr '
+            if kspec.return_softmax_stats
+            else '&& params.softmax_stats_ptr == nullptr '
+        )
```

If Ruff is enforced in CI for this path, this change should clear the warning.
3759-3768
: Use comments instead of a stray triple-quoted string literal.

This multi-line string is not a docstring (it is not at the top of a function) and becomes a no-op statement. Prefer comments to avoid confusion and minor runtime overhead.
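As a numeric cross-check, the smem figures quoted in that string (160 KB, 256 KB, 208 KB) can be reproduced from the formula it states. The tile parameters below (q_step=64, one Q buffer, two KV buffers, two compute groups, 2-byte bf16 elements) are assumptions chosen to match the quoted figures, not values read from the kernel source:

```python
def smem_kb(d, dv, kv_step, q_step=64, q_buffers=1, kv_buffers=2,
            num_compute_groups=2, ele_size=2):
    """Shared-memory footprint in KB for the tile-size formula in the comment."""
    elems = (q_step * d * q_buffers * num_compute_groups
             + (kv_step * d + kv_step * dv) * kv_buffers)
    return elems * ele_size // 1024

# With d padded to 256: fine at kv_step=64, over Hopper's 228 KB at kv_step=128.
assert smem_kb(d=256, dv=128, kv_step=64) == 160
assert smem_kb(d=256, dv=128, kv_step=128) == 256

# Keeping d=192 (already a multiple of 128 bytes in bf16) fits kv_step=128.
assert smem_kb(d=192, dv=128, kv_step=128) == 208
```

Under these assumed parameters the arithmetic matches all three quoted sizes, which supports the comment's conclusion that avoiding the power-of-two padding is what makes kv_step=128 viable on Hopper.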
```diff
-    '''
-    smem size = (q_step * d * q_buffers * NUM_COMPUTE_GROUPS
-                 + (kv_step * d + kv_step * dv) * kv_buffers) * ele_size
-    Originally, head size is padded to next_power_of_2<d> and next_power_of_2<dv>.
-    For fp16/bf16 context MLA (d=192/dv=128), d is padded to 256, and dv remains 128,
-    if kv_step=64, then smem_size = 160 KB, it is OK but wastes much smem.
-    if kv_step=128, then smem_size = 256 KB, it is too big for Hopper (228KB smem per SM).
-    But in fact, 'next multiply of 128 bytes' is needed only, due to TMA 128B swizzle mode.
-    Then for fp16/bf16 context MLA, d remains 192 (192 * 2 = 128 * 3), and dv remains 128,
-    if kv_step = 128, then smem_size = 208 KB, smem is fully utilized.
-    '''
+    # smem size = (q_step * d * q_buffers * NUM_COMPUTE_GROUPS
+    #              + (kv_step * d + kv_step * dv) * kv_buffers) * ele_size
+    # Originally, head size is padded to next_power_of_2<d> and next_power_of_2<dv>.
+    # For fp16/bf16 context MLA (d=192/dv=128), d is padded to 256, and dv remains 128,
+    # if kv_step=64, then smem_size = 160 KB, it is OK but wastes much smem.
+    # if kv_step=128, then smem_size = 256 KB, it is too big for Hopper (228KB smem per SM).
+    # But only 'next multiple of 128 bytes' is needed due to TMA 128B swizzle.
+    # With d=192 and dv=128, kv_step=128 yields ~208 KB and fully utilizes SMEM on Hopper.
```

cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu (1)
116-120
: Remove dead code: old split pointers are unused.

These local pointers reflect the previous (sum|max) split layout but are never used after the signature change. Safe to delete.
```diff
-    float* accu_softmax_sum = accu_softmax_stats;
-    float* accu_softmax_max = accu_softmax_stats + b * s * h;
-    float* softmax_sum = softmax_stats;
-    float* softmax_max = softmax_stats + b * s * h;
```

cpp/kernels/fmha_v2/fmha_test.py (1)
186-202
: Good coverage: add save-softmax runs for SM90 BF16 MLA chunked prefill.

Both padding and causal mask runs are covered for -paged-kv and -separate-q-k-v. Consider factoring the base command to reduce duplication.
Example helper (pseudo):
```python
def run_ctx(dtype, s, epsilon, layout, extra=''):
    base = (f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 "
            f"-d 192 -dv 128 {dtype} {epsilon} {layout}")
    subprocess.run(f"{base} {extra}".strip(), shell=True, check=True)
```
468-479
: Pointer arithmetic matches per-token [max,sum] layout.

Both pointers start from the same base, with sum advanced by +sizeof(float), and token/row strides applied via softmax_stats_stride_in_bytes_. This is coherent with a 2-float-per-token layout.
Consider adding a brief comment noting the layout explicitly (token-major; [max, sum]) to aid future maintainers.
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (2)
2036-2037
: Second D2H copy size is consistent; consider the naming rename too.

The size is correct for {max,sum}. If you adopt the naming suggestion, update the identifiers here as well.
2056-2058
: Use compound assignment or logical OR for clarity when aggregating status.

Bitwise OR between int and bool works but is opaque. Prefer a compound OR with an explicit boolean.
Apply:
```diff
-    status = status | ((errors.first + errors.second) > 0);
+    status |= ((errors.first > 0) || (errors.second > 0));
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (15)
- cpp/kernels/fmha_v2/fmha_test.py (1 hunks)
- cpp/kernels/fmha_v2/setup.py (4 hunks)
- cpp/kernels/fmha_v2/src/fmha/fragment.h (3 hunks)
- cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h (2 hunks)
- cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h (1 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (7 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention.h (1 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h (1 hunks)
- cpp/kernels/fmha_v2/src/softmax_impl.h (4 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2 hunks)
- cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu (4 hunks)
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (1 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1 hunks)
- tests/integration/defs/accuracy/test_llm_api_pytorch.py (1 hunks)
- tests/integration/test_lists/test-db/l0_h100.yml (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- cpp/kernels/fmha_v2/src/fused_multihead_attention.h
🚧 Files skipped from review as they are similar to previous changes (7)
- tests/integration/test_lists/test-db/l0_h100.yml
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
- cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
- cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
- cpp/kernels/fmha_v2/src/fmha/fragment.h
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
cpp/kernels/fmha_v2/fmha_test.py
cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu
cpp/kernels/fmha_v2/src/softmax_impl.h
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h
cpp/kernels/fmha_v2/setup.py
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}
: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo).
Prefer const or constexpr variables over #defines whenever possible.
A variable that is not modified after its initialization should be declared as const.
Except 0 (used for checking signness/existence/emptiness), nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do..while, or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with the first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in a compilation target must have case-insensitive unique filenames.
All types (including class names) should use camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces should use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace should use camel case prefixed by 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace should use camel case prefixed by 's' (e.g., sMutableStaticGlobal).
Locally visible static variables should use camel case with lowercase prefix 's' as the first letter (e.g., static std::once_flag sFlag;).
Class member variables should use camel case prefixed with 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope, and function-scope magic-number/literal constants should be uppercase snake case with prefix...
Files:
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
cpp/kernels/fmha_v2/src/softmax_impl.h
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the class docstring.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
cpp/kernels/fmha_v2/fmha_test.py
cpp/kernels/fmha_v2/setup.py
**/*.{h,hpp}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Use a preprocessor guard in header files. The guard name must have prefix TRTLLM_ followed by the filename, all in caps, and no trailing underscore.
Files:
cpp/kernels/fmha_v2/src/softmax_impl.h
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h
🧬 Code Graph Analysis (2)
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (1)
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h (2)
- get_size_in_bytes (195-208)
- cuda_memcpy_d2h (1555-1563)
cpp/kernels/fmha_v2/src/softmax_impl.h (1)
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h (3)
- void (113-119)
- void (145-151)
- void (481-505)
🪛 Ruff (0.12.2)
cpp/kernels/fmha_v2/setup.py
2313-2313: Line too long (146 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (10)
cpp/kernels/fmha_v2/setup.py (2)
1974-1976
: MLA mask selection logic looks correct.

For (192,128), disabling sliding/custom masks while keeping padding and causal enabled matches the chunked prefill needs. No action needed.
3769-3801
: Approved: MLA (192×128) softmax metadata verified.

The generated cubin metadata includes the expected 192×128 softmax entries:
• File: cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
– Line 219: extern void …_192x128_softmax_tma_ws_sm90(…);
– Line 1890: { …, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_causal_softmax_tma_ws_sm90_kernel", …, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_softmax_tma_ws_sm90 },

You can re-run for confirmation:

```
rg -n "192x128.*softmax" cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
```
Everything looks correct.
cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu (2)
35-55
: Adopted [max,sum] pair layout correctly in-kernel.

The aliasing and offsets match a 2-float-per-token layout:
- max at base[offset], sum at base[offset+1]
- using offset stride of 2 per token is correct.
No action needed.
Also applies to: 74-89
129-131
: API update for reduce4ring_attention launch looks consistent.

Passing the consolidated stats buffers and the flattened s*h is consistent with the kernel's indexing.
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h (1)
489-495
: Storing raw sums (not reciprocals) aligns with recover/merge logic.

Switching to the raw p_sum and the scaled max is consistent with the new recover/merge path in ring attention. No action needed.
Please ensure the producer and consumer agree on the layout and meaning:
- fmhaRunner sets softmax_stats_stride_in_bytes to stride-by-token in bytes.
- recoverFromRingAtten uses [max,sum] pairs (confirmed).
If needed, add an assert or comment in both places to prevent future regressions.
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (3)
1951-1951
: Device-to-host copy size updated to 2*b*s*h; good.

Matches the new {max,sum} layout.
891-901
: Refactor MLA detection and clarify `-save-softmax` error messaging.

The current inline heuristic (`d == 192 && dv == 128`) for detecting the MLA configuration is brittle and should be centralized. Extract it into a clearly named helper (e.g. `isMlaConfig(d, dv, sm, dtype)`) alongside your kernel-enumeration logic. At the same time, the hard-exit error text can be made more structured and easier to understand.

Locations to address: cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp, lines 891–901.

Replace the inline `bool is_MLA = (d == 192 && dv == 128);` with a call to your new helper, and update the `fprintf(stderr, …); exit(1);` block to:

```diff
-    bool is_MLA = (d == 192 && dv == 128);
+    // Heuristic for MLA config. Centralize in isMlaConfig().
+    bool is_MLA = isMlaConfig(d, dv, sm_arch, data_type);

-    fprintf(stderr,
-        "For normal attention, Only '--contiguous-q-kv' layout supports "
-        "'-save-softmax'. For MLA only '-paged-kv' and '-separate-q-k-v' layout supports "
-        "'-save-softmax'.\n");
+    fprintf(stderr,
+        "Invalid layout for -save-softmax.\n"
+        " • Normal attention: only --contiguous-q-kv is supported.\n"
+        " • MLA: only --paged-kv or --separate-q-k-v are supported.\n");
```

Please verify:
- Tests and documentation around `-save-softmax` now expect a hard error (exit 1) instead of a silent fallback.
- Your new `isMlaConfig` helper is used consistently in fmhaRunner/factory gating.
: Stride Consistency Verified Across All Consumers

All instances of softmax-stats processing now uniformly use a 2-float per-head layout:
- Kernel params (fmhaRunner.cpp) and demo setup (fused_multihead_attention.cpp) set `softmax_stats_stride_in_bytes = sizeof(float) * 2 * h`.
- Device fragments (fmha/fragment.h, hopper/fragment.h) index using `stride_in_bytes` and offset by `+ sizeof(float)` for the second element.
- Softmax implementation (softmax_impl.h) computes indices with `h * 2` and `hi * 2 (+1)`.
- Host allocations and memcpy in tests and demos allocate `2 * b * s * h` floats.
- All gated `-save-softmax` code paths and merged-softmax kernels handle max and sum as pairs.
No remaining hard-coded single-float–per-head assumptions were found.
cpp/kernels/fmha_v2/src/softmax_impl.h (2)
1009-1011
: Initialize and pass max_fp32 for reduction — correct
- Initialization to -HUGE_VALF is appropriate.
- The reduce call now produces both sum_fp32 and max_fp32 used for stats and normalization.
154-156
: Confirmed single call site and correct max-out parameter usage
- Definitions at cpp/kernels/fmha_v2/src/softmax_impl.h lines 154–156, 250–252, 418–421
- Only invocation at line 1010:
reduce(data_fp32, mask_, params.warps_n, sum_fp32, max_fp32, attention_sink);
- No other calls of this six-argument `reduce` overload found

Masks are applied before computing max; the warp-wide reduction propagates `max_fp32` correctly; normalization and per-row stats both use the same persisted max. Everything looks consistent; no further changes needed.
PR_Github #14747 [ run ] completed with state |
82754ff
to
15f4c3b
Compare
/bot run --disable-fail-fast |
Actionable comments posted: 2
♻️ Duplicate comments (2)
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h (1)
1297-1319
: Fix variable naming and remove incorrect guard in max check.

The loop is checking the "max" statistic but uses variables named sum/sum_ref and applies an unnecessary `sum_ref != 1.0f` guard. This is misleading and can skip legitimate discrepancies.
- Rename to maxVal/maxRef for clarity.
- Remove the `!= 1.0f` condition.
- Use a robust error metric with an epsilon fallback (consistent with check_results).
Apply this diff:
```diff
     // Check the max
     for (int b_ = 0; b_ < b; ++b_)
     {
         for (int s_ = 0; s_ < seqlens[b_]; ++s_)
         {
             for (int h_ = 0; h_ < h; ++h_)
             {
                 uint64_t idx = (cu_seqlens[b_] + s_) * h * 2 + h_ * 2;
-                float sum = out[idx];
-                float sum_ref = ref[idx];
-                if (sum_ref != 1.0f && fabsf(sum - sum_ref) / (fabsf(sum) + fabsf(sum_ref)) > 0.01)
+                float maxVal = out[idx];
+                float maxRef = ref[idx];
+                float den = fabsf(maxVal) + fabsf(maxRef);
+                float relErr = den <= 1e-6f ? fabsf(maxVal - maxRef) : fabsf(maxVal - maxRef) / den;
+                if (relErr > 0.01f)
                 {
                     n_errors_max++;
                 }
             }
         }
     }
```

cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (1)
1239-1243
: Rename “sum” buffers to “stats” to reflect [max,sum] layout and avoid confusionThe host buffers now hold both max and sum per token, but variables are still named softmax_sum_* which is misleading. Rename them to softmax_stats_* consistently.
Apply this diff:
```diff
- // The softmax_stats_d vector is used to store the max/sum of the softmax per token
+ // The softmax_stats_d vector stores [max, sum] of the softmax per token

- float* softmax_sum_ref_h = (float*) malloc(2 * b * s * h * sizeof(float));
- float* softmax_sum_h = (float*) malloc(2 * b * s * h * sizeof(float));
+ float* softmax_stats_ref_h = (float*) malloc(2 * b * s * h * sizeof(float));
+ float* softmax_stats_h = (float*) malloc(2 * b * s * h * sizeof(float));

- FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_sum_ref_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
+ FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_ref_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));

- FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_sum_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
+ FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));

- auto errors = check_softmax_results(softmax_sum_h, softmax_sum_ref_h, b, s, h, seqlens, cu_seqlens);
+ auto errors = check_softmax_results(softmax_stats_h, softmax_stats_ref_h, b, s, h, seqlens, cu_seqlens);
```
And update frees accordingly:
```diff
- free(softmax_sum_h);
- free(softmax_sum_ref_h);
+ free(softmax_stats_h);
+ free(softmax_stats_ref_h);
```
Also applies to: 1255-1257, 1951-1951, 2036-2037, 2056-2058
🧹 Nitpick comments (9)
cpp/kernels/fmha_v2/setup.py (1)
2314-2314
: Line too long (Ruff E501); wrap the softmax-stats predicate for readability. The inline ternary makes this line exceed the 120-char guideline flagged by Ruff. Extract or wrap the clause.
Apply this diff to split the predicate:
```diff
- il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr '
+ softmax_ptr_check = (
+     '&& params.softmax_stats_ptr != nullptr '
+     if kspec.return_softmax_stats
+     else '&& params.softmax_stats_ptr == nullptr '
+ )
+ il_check += softmax_ptr_check
```
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (2)
902-906
: Clarify the FP8 error message. "Currently fp8 kernel doesn't support fp8." is confusing; clarify that save-softmax is unsupported for FP8.
Apply this diff:
```diff
- fprintf(stderr, "Currently fp8 kernel doesn't support fp8.\n");
+ fprintf(stderr, "-save-softmax is not supported for FP8 (E4M3) at the moment.\n");
```
2-2
: Update copyright year to the current year. Guidelines require the current year in headers. Please include 2025.
cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu (6)
35-39
: Fix shape comments to match the flattened indexing. The kernel flattens [s_block, h] into a single dimension (s_block * h), but the comments say "b x s_block x h x d" and "b x s_block x h x 2", which is misleading. Update the comments to reflect the flattened layout, or adjust the indexing accordingly.
Apply this diff:
```diff
- Tout* __restrict__ accu_output, // b x s_block x h x d
- float* __restrict__ accu_softmax_stats, // b x s_block x h x 2 (max/sum)
+ Tout* __restrict__ accu_output, // b x (s_block * h) x d (flattened [s_block, h])
+ float* __restrict__ accu_softmax_stats, // b x (s_block * h) x 2 (max/sum)

- Tout* __restrict__ output, // b x s_block x h x d
- float* __restrict__ softmax_stats, // b x s_block x h x 2 (max/sum)
+ Tout* __restrict__ output, // b x (s_block * h) x d (flattened [s_block, h])
+ float* __restrict__ softmax_stats, // b x (s_block * h) x 2 (max/sum)
```
49-55
: Pointer aliasing is subtle; prefer explicit [base + 0/1] indexing for readability. Shifting the base pointer by +1 and then using base indices that are multiples of 2 works, but it is error-prone. Index explicitly into the 2-wide tuple to avoid off-by-one mistakes.
Apply this diff:
```diff
- float* accu_softmax_sum = accu_softmax_stats + 1;
- float* accu_max = accu_softmax_stats;
- float* softmax_sum = softmax_stats + 1;
- float* max = softmax_stats;
+ // Access pattern: stats[base + 0] -> max, stats[base + 1] -> sum

- uint64_t lm_start_offset_ = lm_start_offset + s_ * 2;
- float my_accu_ss = accu_softmax_sum[lm_start_offset_] == 0.0 ? 1.0 : accu_softmax_sum[lm_start_offset_];
- float my_ss = softmax_sum[lm_start_offset_] == 0.0 ? 1.0 : softmax_sum[lm_start_offset_];
- float cur_max = (accu_max[lm_start_offset_] > max[lm_start_offset_]) ? accu_max[lm_start_offset_]
-     : max[lm_start_offset_];
- float scale1 = exp(accu_max[lm_start_offset_] - cur_max);
- float scale2 = exp(max[lm_start_offset_] - cur_max);
+ uint64_t base = lm_start_offset + s_ * 2;
+ float my_accu_ss = (accu_softmax_stats[base + 1] == 0.0f) ? 1.0f : accu_softmax_stats[base + 1];
+ float my_ss = (softmax_stats[base + 1] == 0.0f) ? 1.0f : softmax_stats[base + 1];
+ float cur_max = fmaxf(accu_softmax_stats[base + 0], softmax_stats[base + 0]);
+ float scale1 = expf(accu_softmax_stats[base + 0] - cur_max);
+ float scale2 = expf(softmax_stats[base + 0] - cur_max);
  float cur_softmax_sum = my_accu_ss * scale1 + my_ss * scale2;
  if (cur_softmax_sum == 0)
      cur_softmax_sum = 1.0;
  scaled_my_ss1_ = scale1 * my_accu_ss / cur_softmax_sum;
  scaled_my_ss2_ = scale2 * my_ss / cur_softmax_sum;
- accu_softmax_sum[lm_start_offset_] = cur_softmax_sum;
- accu_max[lm_start_offset_] = cur_max;
+ accu_softmax_stats[base + 1] = cur_softmax_sum;
+ accu_softmax_stats[base + 0] = cur_max;
```
Also applies to: 74-89
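The rescaling in this loop follows the standard online-softmax merge: each partial (max, sum) pair is rescaled to the combined max before the sums are added. A minimal host-side sketch of that math (function and variable names are illustrative, not taken from the kernel):

```python
import math

def merge_softmax_stats(max1: float, sum1: float, max2: float, sum2: float):
    """Merge two softmax partials (max, sum) into one, as in online softmax.

    Each partial summarizes exp-sums computed against its own running max;
    both are rescaled to the combined max before the sums are added.
    """
    cur_max = max(max1, max2)
    scale1 = math.exp(max1 - cur_max)
    scale2 = math.exp(max2 - cur_max)
    return cur_max, sum1 * scale1 + sum2 * scale2

# Two chunks of the same row reproduce the single-pass stats:
xs = [0.5, 2.0, -1.0, 3.0]

def partial(vals):
    m = max(vals)
    return m, sum(math.exp(v - m) for v in vals)

m12, s12 = merge_softmax_stats(*partial(xs[:2]), *partial(xs[2:]))
m, s = partial(xs)
assert m12 == m and abs(s12 - s) < 1e-12
```

This is why only the per-row max and sum need to be persisted per chunk: any number of chunks can be folded together associatively with this merge.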
80-81
: Use expf for float math on device. exp promotes to double and back; prefer expf for float to avoid unnecessary double-precision ops.
Apply this diff (already included in the previous refactor; if you keep the current indexing, still apply expf):
```diff
- float scale1 = exp(accu_max[lm_start_offset_] - cur_max);
- float scale2 = exp(max[lm_start_offset_] - cur_max);
+ float scale1 = expf(accu_max[lm_start_offset_] - cur_max);
+ float scale2 = expf(max[lm_start_offset_] - cur_max);
```
112-120
: Remove dead code in invokeRecoverFromRA. accu_softmax_sum/max and softmax_sum/max are computed but never used. Drop them to reduce confusion.
Apply this diff:
```diff
- float* accu_softmax_sum = accu_softmax_stats;
- float* accu_softmax_max = accu_softmax_stats + b * s * h;
- float* softmax_sum = softmax_stats;
- float* softmax_max = softmax_stats + b * s * h;
```
31-41
: Consider removing unused cu_seqlens, or using it to guard s_len per batch. cu_seqlens is passed to the kernel but not used. Either remove it from the signature and launch, or use it to bound s_len so invalid positions are not touched when sequence lengths vary.
Would you like me to propagate a signature cleanup (kernel and header) removing cu_seqlens, or to wire it into the loop to guard per-batch lengths?
2-2
: Update copyright year. Per guidelines, include the current year (2025) in the header.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (15)
- cpp/kernels/fmha_v2/fmha_test.py (1 hunks)
- cpp/kernels/fmha_v2/setup.py (4 hunks)
- cpp/kernels/fmha_v2/src/fmha/fragment.h (3 hunks)
- cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h (2 hunks)
- cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h (1 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (7 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention.h (1 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h (1 hunks)
- cpp/kernels/fmha_v2/src/softmax_impl.h (4 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2 hunks)
- cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu (4 hunks)
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (1 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1 hunks)
- tests/integration/defs/accuracy/test_llm_api_pytorch.py (1 hunks)
- tests/integration/test_lists/test-db/l0_h100.yml (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (11)
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
- cpp/kernels/fmha_v2/src/fused_multihead_attention.h
- cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
- tests/integration/test_lists/test-db/l0_h100.yml
- cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h
- cpp/kernels/fmha_v2/fmha_test.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- cpp/kernels/fmha_v2/src/fmha/fragment.h
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
- cpp/kernels/fmha_v2/src/softmax_impl.h
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
cpp/kernels/fmha_v2/setup.py
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}
: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo).
Prefer const or constexpr variables over #defines whenever possible.
A variable that is not modified after its initialization should be declared as const.
Except 0 (used for checking signness/existence/emptiness), nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do..while, or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with the first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in a compilation target must have case-insensitive unique filenames.
All types (including class names) should use camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces should use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace should use camel case prefixed by 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace should use camel case prefixed by 's' (e.g., sMutableStaticGlobal).
Locally visible static variables should use camel case with lowercase prefix 's' as the first letter (e.g., static std::once_flag sFlag;).
Class member variables should use camel case prefixed with 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope, and function-scope magic-number/literal constants should be uppercase snake case with prefix...
Files:
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the class docstring.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
cpp/kernels/fmha_v2/setup.py
**/*.{h,hpp}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Use a preprocessor guard in header files. The guard name must have prefix TRTLLM_ followed by the filename, all in caps, and no trailing underscore.
Files:
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
🧬 Code Graph Analysis (1)
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (1)
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h (2)
- get_size_in_bytes (195-208)
- cuda_memcpy_d2h (1555-1563)
🪛 Ruff (0.12.2)
cpp/kernels/fmha_v2/setup.py
2313-2313: Line too long (146 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (4)
cpp/kernels/fmha_v2/setup.py (3)
3654-3734
: Save-softmax gating logic for normal vs. MLA paths looks correct.
- For return_softmax=True:
- Normal attention is limited to CONTIGUOUS_Q_KV via skip_combination. Good.
- MLA-specific save-softmax is handled separately (see below). Good.
This separation avoids the pitfall of over-gating contiguous layouts when combining flags. No action needed.
3769-3801
: Verify 192×128 MLA kernels are generated and exercised. I confirmed the 192/128 spec in cpp/kernels/fmha_v2/setup.py (lines 3769–3801), but I didn't find any generated CUDA kernels matching 192x128 + _softmax for Q_PAGED_KV or Q_K_V. Please:
- Run your codegen step locally and re-search the generated/*.cu artifacts for 192x128.*_softmax.*(q_paged_kv|q_k_v).
- At runtime, ensure input_layout is set to Q_PAGED_KV or SEPARATE_Q_K_V when return_softmax=True.
- Ensure the flash_attention && warp_specialization path is taken so this spec is discoverable.
1974-1976
: Confirm mask type for MLA chunked prefill (must use PADDING, not SLIDING_OR_CHUNKED_CAUSAL).
I wasn't able to find any runtime setting of attention_mask_type for MLA chunked prefill; no SLIDING_OR_CHUNKED_CAUSAL or PADDING enum usages showed up in the codebase search. Please manually verify that when enable_chunked_prefill is true and (head_size, head_size_v) == (192, 128), the runtime path uses the PADDING mask (and never SLIDING_OR_CHUNKED_CAUSAL). Otherwise the kernels generated in cpp/kernels/fmha_v2/setup.py won't match the requested mask type.
- cpp/kernels/fmha_v2/setup.py around lines 1974–1976 (check sliding_or_chunked_causal_mask = '0')
- tensorrt_llm/runtime/model_runner_cpp.py (where enable_chunked_context is passed)
- tensorrt_llm/_torch/attention_backend/trtllm.py (the use_paged_context_fmha logic)

cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (1)
351-352
: Softmax stats stride correctly updated to two floats per head. Using get_size_in_bytes(h * 2, DATA_TYPE_FP32) aligns with the new [max, sum] per-token layout. No concerns.
/bot run --disable-fail-fast |
Actionable comments posted: 0
♻️ Duplicate comments (1)
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (1)
891-901
: Avoid magic numbers and snake_case; use constexpr and camelCase for MLA detection. Replace the hard-coded 192/128 and is_MLA with constexpr constants and isMla. It improves readability and aligns with the naming guidelines.
Apply this diff:
```diff
- bool is_MLA = (d == 192 && dv == 128);
- if (((!is_MLA) && input_layout != Attention_input_layout::CONTIGUOUS_Q_KV)
-     || (is_MLA && input_layout != Attention_input_layout::Q_PAGED_KV
+ constexpr size_t kMlaQd = 192;
+ constexpr size_t kMlaDv = 128;
+ bool isMla = (d == kMlaQd && dv == kMlaDv);
+ if (((!isMla) && input_layout != Attention_input_layout::CONTIGUOUS_Q_KV)
+     || (isMla && input_layout != Attention_input_layout::Q_PAGED_KV
          && input_layout != Attention_input_layout::SEPARATE_Q_K_V))
  {
-     fprintf(stderr,
-         "For normal attention, Only '--contiguous-q-kv' layout supports "
-         "'-save-softmax'. For MLA only '-paged-kv' and '-separate-q-k-v' layout supports "
-         "'-save-softmax'.\n");
+     fprintf(stderr,
+         "Unsupported layout for -save-softmax. Normal attention requires '--contiguous-q-kv'. "
+         "MLA requires either '--paged-kv' or '--separate-q-k-v'.\n");
      exit(1);
  }
```
🧹 Nitpick comments (4)
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (4)
1239-1239
: Clarify the layout/order in the comment. Call out the exact host-visible shape and order to avoid ambiguity.
```diff
- // The softmax_stats_d vector is used to store the max/sum of the softmax per token
+ // softmax_stats_d stores two FP32 values per token per head in [max, sum] order.
+ // Host-visible logical shape: [total_q_tokens, h, 2]; per-token stride = h * 2 floats.
```
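For illustration, the indexing implied by this [max, sum] layout can be sketched as follows (stats_index is a hypothetical helper, assuming a row-major [total_q_tokens, h, 2] buffer; it is not code from the PR):

```python
def stats_index(token: int, head: int, h: int) -> int:
    """Flat index of the [max, sum] pair for (token, head).

    Layout: [total_q_tokens, h, 2], row-major; per-token stride = h * 2 floats.
    Returns the index of the max; the corresponding sum lives at index + 1.
    """
    return token * h * 2 + head * 2

# With h = 4 heads: token 3, head 2 -> max at 28, sum at 29.
idx = stats_index(3, 2, 4)
assert idx == 28
```

The same arithmetic explains the doubled D2H copy size (2 * b * s * h floats) and the h * 2 per-token stride used by the runner.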
1255-1257
: Prefer RAII containers (std::vector) for host buffers; consider pinned memory for transfer efficiency. To avoid manual lifetime management and potential leaks, use std::vector. For large copies, switching to cudaMallocHost (pinned) can reduce D2H latency.
Minimal RAII change at allocation sites:
```diff
- float* softmax_stats_ref_h = (float*) malloc(2 * b * s * h * sizeof(float));
- float* softmax_stats_h = (float*) malloc(2 * b * s * h * sizeof(float));
+ std::vector<float> softmax_stats_ref_h(2 * b * s * h);
+ std::vector<float> softmax_stats_h(2 * b * s * h);
```
And update call sites (outside this hunk) to pass data():
```cpp
// around Line 1951
FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_ref_h.data(), softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
// around Line 2036
FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_h.data(), softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
// around Line 2056
auto errors = check_softmax_results(softmax_stats_h.data(), softmax_stats_ref_h.data(), b, s, h, seqlens, cu_seqlens);
// around Lines 2152-2154: remove the corresponding free() calls
```
Alternatively, if you want pinned host buffers for faster copies:
```cpp
float* softmax_stats_ref_h = nullptr;
float* softmax_stats_h = nullptr;
cudaMallocHost(&softmax_stats_ref_h, 2 * b * s * h * sizeof(float));
cudaMallocHost(&softmax_stats_h, 2 * b * s * h * sizeof(float));
// ...
cudaFreeHost(softmax_stats_h);
cudaFreeHost(softmax_stats_ref_h);
```
2036-2057
: Guard the softmax_stats D2H copy and preserve error counts
- Avoid an unnecessary D2H copy when -save-softmax is disabled.
- Preserve the numeric error count when combining with status.
```diff
- FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
+ if (save_softmax)
+ {
+     FMHA_CHECK_CUDA(cuda_memcpy_d2h(softmax_stats_h, softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
+ }
@@
- if (save_softmax)
- {
-     auto errors = check_softmax_results(softmax_stats_h, softmax_stats_ref_h, b, s, h, seqlens, cu_seqlens);
-     status = status | ((errors.first + errors.second) > 0);
- }
+ if (save_softmax)
+ {
+     auto errors = check_softmax_results(
+         softmax_stats_h, softmax_stats_ref_h, b, s, h, seqlens, cu_seqlens);
+     status |= static_cast<int>(errors.first + errors.second);
+ }
```
2152-2154
: Remove the manual frees if you adopt the RAII or pinned-host refactor; otherwise the current frees are correct. If you switch to std::vector for host buffers, drop these frees; if you switch to pinned memory, free with cudaFreeHost instead.
```diff
- free(softmax_stats_h);
- free(softmax_stats_ref_h);
+ // Not needed if using std::vector for host buffers.
+ // If using pinned host memory, call cudaFreeHost(...) here instead.
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (8 hunks)
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (0 hunks)
💤 Files with no reviewable changes (1)
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
🧰 Additional context used
📓 Path-based instructions (2)
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}
: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo).
Prefer const or constexpr variables over #defines whenever possible.
A variable that is not modified after its initialization should be declared as const.
Except 0 (used for checking signness/existence/emptiness), nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do..while, or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with the first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in a compilation target must have case-insensitive unique filenames.
All types (including class names) should use camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces should use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace should use camel case prefixed by 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace should use camel case prefixed by 's' (e.g., sMutableStaticGlobal).
Locally visible static variables should use camel case with lowercase prefix 's' as the first letter (e.g., static std::once_flag sFlag;).
Class member variables should use camel case prefixed with 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope, and function-scope magic-number/literal constants should be uppercase snake case with prefix...
Files:
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
🔇 Additional comments (2)
cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (2)
351-351
: Softmax stats stride now reflects two floats per head — LGTMUsing get_size_in_bytes(h * 2, DATA_TYPE_FP32) matches the [max, sum] layout and is consistent with the doubled footprint.
1951-1951
: D2H copy for doubled softmax stats — LGTMCopying 2 * b * s * h elements with DATA_TYPE_FP32 matches the new [max, sum] layout.
PR_Github #14883 [ run ] triggered by Bot |
/bot run --disable-fail-fast |
PR_Github #14920 [ run ] triggered by Bot |
PR_Github #14883 [ run ] completed with state |
PR_Github #14920 [ run ] completed with state |
Signed-off-by: Mingyang Jiang <[email protected]>
48590c2
to
f3fbb76
Compare
/bot run --disable-fail-fast |
Actionable comments posted: 0
♻️ Duplicate comments (2)
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h (2)
1303-1316
: Misleading variable names and an incorrect guard in the max-check loop. You're checking the "max" plane but naming the variables sum/sum_ref and skipping when sum_ref == 1.0f. That guard is inappropriate for max checks and risks hiding real mismatches. Rename, and use a robust error metric (with an epsilon fallback).
Apply this diff:
```diff
- uint64_t idx = (cu_seqlens[b_] + s_) * h * 2 + h_ * 2;
- float sum = out[idx];
- float sum_ref = ref[idx];
- if (sum_ref != 1.0f && fabsf(sum - sum_ref) / (fabsf(sum) + fabsf(sum_ref)) > 0.01)
+ uint64_t idx = (cu_seqlens[b_] + s_) * h * 2 + h_ * 2;
+ float max_val = out[idx];
+ float max_ref = ref[idx];
+ float den = fabsf(max_val) + fabsf(max_ref);
+ float relErr = den <= 1e-6f ? fabsf(max_val - max_ref)
+                             : fabsf(max_val - max_ref) / den;
+ if (relErr > 0.01f)
  {
      n_errors_max++;
  }
```
1327-1333
: The softmax sum check should not skip 1.0f; use a robust error metric. Conditioning on sum_ref != 1.0f skips exactly the common case where sums are 1.0f, hiding defects. Use the same epsilon-aware metric as check_results for consistency.
Apply this diff:
```diff
- float sum = out[idx];
- float sum_ref = ref[idx];
- if (sum_ref != 1.0f && fabsf(sum - sum_ref) / (fabsf(sum) + fabsf(sum_ref)) > 0.01)
+ float sum = out[idx];
+ float sum_ref = ref[idx];
+ float den = fabsf(sum) + fabsf(sum_ref);
+ float relErr = den <= 1e-6f ? fabsf(sum - sum_ref)
+                             : fabsf(sum - sum_ref) / den;
+ if (relErr > 0.01f)
  {
      n_errors_sum++;
  }
```
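The epsilon-aware metric suggested here can be sketched in isolation (rel_err is a hypothetical helper mirroring the proposed check, not code from the PR):

```python
def rel_err(a: float, b: float, eps: float = 1e-6) -> float:
    """Symmetric relative error with an absolute-error fallback near zero.

    When |a| + |b| is tiny, dividing by it would blow up the metric, so the
    raw absolute difference is used instead. This keeps (0, 0) an exact
    match rather than a divide-by-zero.
    """
    den = abs(a) + abs(b)
    diff = abs(a - b)
    return diff if den <= eps else diff / den

# A 1% tolerance then behaves sensibly across magnitudes:
assert rel_err(0.0, 0.0) == 0.0
assert rel_err(100.0, 101.0) < 0.01   # passes: ~0.005
assert rel_err(1.0, 1.05) > 0.01      # fails:  ~0.024
```

Unlike the sum_ref != 1.0f guard, this never silently skips a comparison.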
🧹 Nitpick comments (6)
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h (1)
468-479
: Clarify and harden the softmax stats pointer arithmetic for the [max, sum] layout. The new layout and pointer stepping look correct, but readability and safety can be improved by working in float elements instead of raw bytes and documenting the [max, sum] pair per token explicitly. Also, assert that the stride is a multiple of sizeof(float).
Apply this diff to simplify indexing and make intent explicit:
```diff
- , softmax_sum_ptr_(reinterpret_cast<char*>(params.softmax_stats_ptr))
- , softmax_stats_stride_in_bytes_(params.softmax_stats_stride_in_bytes)
+ , softmax_stats_stride_in_bytes_(params.softmax_stats_stride_in_bytes)
  {
-     softmax_max_ptr_ = reinterpret_cast<char*>(params.softmax_stats_ptr);
+     // Treat the stats buffer as float-typed; each token stores [max, sum] as two floats.
+     // Convert byte stride to float stride for typed pointer arithmetic.
+     auto* stats_base = reinterpret_cast<float*>(params.softmax_stats_ptr);
+     int const stride_floats = static_cast<int>(softmax_stats_stride_in_bytes_) / static_cast<int>(sizeof(float));
+     // Defensive: stride must align to float elements.
+     assert((softmax_stats_stride_in_bytes_ % sizeof(float)) == 0);
+
      int warp = (threadIdx.x % 128) / Cta_tile::THREADS_PER_WARP;
      int lane = threadIdx.x % Cta_tile::THREADS_PER_WARP;
      // MMA row0 index (8x4 thread layout)
      row0_ = warp * Mma_tile::M_PER_MMA / WARPS_M + (lane / 4);
      int sum_s = params.is_s_padded ? params.s * head_info.bidb : params.cu_q_seqlens[head_info.bidb];
      int token_id = sum_s * params.h + head_info.bidh;
-     size_t const bh_offset = token_id * sizeof(float) * 2 + local_q_tile_offset_ * softmax_stats_stride_in_bytes_;
-     softmax_max_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_;
-     softmax_sum_ptr_ += bh_offset + row0_ * softmax_stats_stride_in_bytes_ + sizeof(float);
+     // Each token stores 2 floats: [max, sum]
+     size_t const per_token_floats = 2;
+     size_t const bh_offset_floats = token_id * per_token_floats + static_cast<size_t>(local_q_tile_offset_) * stride_floats;
+     // Row pointers for the current tile row
+     softmax_max_ptr_f_ = stats_base + bh_offset_floats + static_cast<size_t>(row0_) * stride_floats + 0; // max at +0
+     softmax_sum_ptr_f_ = stats_base + bh_offset_floats + static_cast<size_t>(row0_) * stride_floats + 1; // sum at +1
  };
```
Outside this hunk, adjust member types:
```diff
- char* softmax_sum_ptr_ = nullptr;
- char* softmax_max_ptr_ = nullptr;
+ float* softmax_sum_ptr_f_ = nullptr;
+ float* softmax_max_ptr_f_ = nullptr;
```
And in store() use float* pointers (see related change below).
cpp/kernels/fmha_v2/setup.py (1)
2310-2316
: Break the long conditional line to satisfy linters and improve readability. Ruff flagged E501 (>120 chars). Split the softmax pointer gating across lines.
Apply this diff:
```diff
- il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr '
+ il_check += (
+     '&& params.softmax_stats_ptr != nullptr '
+     if kspec.return_softmax_stats
+     else '&& params.softmax_stats_ptr == nullptr '
+ )
```
cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu (4)
51-55
: Two-float interleaved stats indexing is subtle; consider documenting or wrapping it. Using base+1 pointers with even indices is correct for [max, sum] interleaving, but it's easy to regress. A short comment, or a small helper that computes the index of a token's max/sum, would improve maintainability.
Add a brief comment:
```diff
- float* accu_softmax_sum = accu_softmax_stats + 1;
+ // Layout: for token t, stats[2*t + 0] = max, stats[2*t + 1] = sum
+ float* accu_softmax_sum = accu_softmax_stats + 1; // even idx addresses (2*t) map to sum via base+1
```
74-89
: Use expf (or __expf) for device float math, and guard zero with an epsilon. On device, exp promotes to double and is slower; prefer expf (or __expf for fast-math). Also, compare sums against an epsilon to avoid dividing by zero.
Apply this diff:
```diff
- float scale1 = exp(accu_max[lm_start_offset_] - cur_max);
- float scale2 = exp(max[lm_start_offset_] - cur_max);
+ float scale1 = expf(accu_max[lm_start_offset_] - cur_max);
+ float scale2 = expf(max[lm_start_offset_] - cur_max);
  float cur_softmax_sum = my_accu_ss * scale1 + my_ss * scale2;
- if (cur_softmax_sum == 0)
-     cur_softmax_sum = 1.0;
+ if (fabsf(cur_softmax_sum) < 1e-6f) {
+     cur_softmax_sum = 1.0f;
+ }
```
112-120
: Remove unused local aliases to avoid confusion. accu_softmax_sum/accu_softmax_max/softmax_sum/softmax_max are computed but never used. This can mislead readers about the intended layout.
Apply this diff:
```diff
- float* accu_softmax_sum = accu_softmax_stats;
- float* accu_softmax_max = accu_softmax_stats + b * s * h;
- float* softmax_sum = softmax_stats;
- float* softmax_max = softmax_stats + b * s * h;
```
56-61
: Barrier usage appears unnecessary here. You initialize a cuda::barrier and only call arrive_and_wait() once at the end. A simple __syncthreads() (or removing the barrier entirely) would suffice; the final sync isn't needed since the kernel returns immediately.
Replace init/sync with a regular __syncthreads(), or remove both if no cross-thread dependency remains at the end.
Also applies to: 108-110
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (15)
- cpp/kernels/fmha_v2/fmha_test.py (1 hunks)
- cpp/kernels/fmha_v2/setup.py (4 hunks)
- cpp/kernels/fmha_v2/src/fmha/fragment.h (3 hunks)
- cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h (2 hunks)
- cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h (1 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (8 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention.h (1 hunks)
- cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h (1 hunks)
- cpp/kernels/fmha_v2/src/softmax_impl.h (4 hunks)
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2 hunks)
- cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu (4 hunks)
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (0 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1 hunks)
- tests/integration/defs/accuracy/test_llm_api_pytorch.py (1 hunks)
- tests/integration/test_lists/test-db/l0_h100.yml (1 hunks)
💤 Files with no reviewable changes (1)
- cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
🚧 Files skipped from review as they are similar to previous changes (10)
- cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
- cpp/kernels/fmha_v2/src/fused_multihead_attention.h
- cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
- cpp/kernels/fmha_v2/src/softmax_impl.h
- cpp/kernels/fmha_v2/src/fmha/fragment.h
- tests/integration/test_lists/test-db/l0_h100.yml
- cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
- cpp/kernels/fmha_v2/fmha_test.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
🧰 Additional context used
📓 Path-based instructions (6)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}
: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...
Files:
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Parameter names must be consistent between declarations and definitions
Files:
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu
**/*.{h,hpp,hxx,hh,cuh}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Header files must use include guards named TRTLLM__H without underscores prefix/suffix (e.g., TRTLLM_FOO_BAR_HELLO_H)
Files:
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h
cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu
cpp/kernels/fmha_v2/setup.py
**/*.{cpp,cxx,cc,cu}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,cxx,cc,cu}
: Avoid literal values except for 0, nullptr, true, false; use named constexpr for other literals
Place semicolon of empty for/while loop on a new line
Always use brace-delimited bodies for switch/while/do-for/if/else
Use inline C comments in argument lists when parameter meaning is unclear (e.g., /* checkForErrors = */ false)
Do not use assignment in subexpressions (e.g., if (x = y) ... is forbidden)
Switch on enums should enumerate all values and omit default to catch new values at compile time
Structure switch statements; prohibit fallthrough except between empty cases; each case ends with break or throw; return at end of case not allowed; put break inside braces for compound case
Prefer anonymous namespaces over static for internal linkage of functions
Every defined function must be called at least once (no unused methods)
Files:
cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
cpp/kernels/fmha_v2/setup.py
🪛 Ruff (0.12.2)
cpp/kernels/fmha_v2/setup.py
2313-2313: Line too long (146 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (4)
cpp/kernels/fmha_v2/src/fused_multihead_attention_utils.h (1)
1297-1302: Good API refactor to return separate max/sum error counts

Returning std::pair<int, int> for {n_errors_max, n_errors_sum} cleanly separates the validation concerns and aligns with the new [max, sum] layout. No issues spotted here.
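For intuition, the shape of such a checker can be sketched in a few lines of Python (the real check_softmax_results is C++ and returns std::pair<int, int>; the tolerance and names here are illustrative only):

```python
def check_softmax_results(ref_stats, out_stats, tol=1e-3):
    """Compare reference vs. computed [max, sum] pairs.

    Returns separate error counts for the max and sum components,
    mirroring the {n_errors_max, n_errors_sum} pair in the refactor.
    """
    n_errors_max = 0
    n_errors_sum = 0
    for (ref_max, ref_sum), (out_max, out_sum) in zip(ref_stats, out_stats):
        if abs(ref_max - out_max) > tol:
            n_errors_max += 1
        if abs(ref_sum - out_sum) > tol:
            n_errors_sum += 1
    return n_errors_max, n_errors_sum
```

Keeping the two counts separate makes it immediately visible whether a regression corrupted the max, the sum, or both.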
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h (1)
489-505: Confirm whether max should be stored scaled by 1/sqrt(d)

You store sum unscaled but save max as p_max[lane % 2] / sqrt_d. Is this intentional for the new [max, sum] stats semantics? If the consumer expects the raw max (pre-1/sqrt(d) scaling), this will corrupt the stats.
If scaling is unintended, change to store the raw max:
```diff
- values = p_max[lane % 2] / sqrt_d;
+ values = p_max[lane % 2];
```

And update the target pointers to typed ones if you applied the refactor:
```diff
- char* dst_ptr = (lane % 4 < 2) ? softmax_sum_ptr_ : softmax_max_ptr_;
+ float* dst_ptr = (lane % 4 < 2) ? softmax_sum_ptr_f_ : softmax_max_ptr_f_;
```

cpp/kernels/fmha_v2/setup.py (2)
1974-1977: Mask selection for MLA 192/128 looks correct

Limiting to padding and causal masks for chunked prefill is consistent with the MLA context requirements.
3654-3662: Kernel enumeration gating for return_softmax stats: logic aligns with normal vs. MLA contexts
- Normal attention: guarded by skip_combination to only emit CONTIGUOUS_Q_KV when returning softmax.
- MLA context (192/128): separately guarded by skip_mla_combination to only emit Q_PAGED_KV or SEPARATE_Q_K_V.
This separation looks intentional and correct for the two use-cases.
Also applies to: 3769-3801
PR_Github #15067 [ run ] triggered by Bot |
/bot run --disable-fail-fast |
PR_Github #15085 [ run ] triggered by Bot |
PR_Github #15067 [ run ] completed with state |
PR_Github #15085 [ run ] completed with state |
Summary by CodeRabbit
New Features
Bug Fixes
Tests
Description
add support for Hopper MLA chunked prefill
Test Coverage
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run
/bot [-h|--help]
to print this help message. See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id
(OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test
(OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
--disable-fail-fast
(OPTIONAL) : Disable fail fast on build/tests/infra failures.
--skip-test
(OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx"
(OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe"
(OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp"
(OPTIONAL) : Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test
(OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test
(OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test
(OPTIONAL) : Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
--post-merge
(OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"
(OPTIONAL) : Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log
(OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug
(OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
kill
kill
Kill all running builds associated with pull request.
skip
skip --comment COMMENT
Skip testing for latest commit on pull request.
--comment "Reason for skipping build/test"
is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.