[None][feat] : Add FP8 context MLA support for SM120 #6059
Conversation
/bot run
PR_Github #11939 [ run ] triggered by Bot
PR_Github #11939 [ run ] completed with state
Force-pushed from 7eee4ec to 4339442
## 📝 Walkthrough
Support for a new boolean flag, `mFP8ContextMLA`, was introduced to enable FP8 context mode for Multi-head Latent Attention (MLA) alongside the existing FP8 context FMHA mode. This required updates to buffer size calculations, parameter passing, kernel launches, and conditional logic across the attention operator implementations, kernel headers, and related tests. Additionally, a new CUDA kernel was added to quantize input data to FP8 format within the MLA context path.
## Changes
| File(s) | Change Summary |
|-----------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------|
| **Attention Operator Source**<br>`cpp/tensorrt_llm/common/attentionOp.cpp`, `cpp/tensorrt_llm/thop/attentionOp.cpp` | Integrated `mFP8ContextMLA` flag into buffer size, parameter, and control flow logic; added robust checks for optional tensors; set MLA FP8 context flags during generation and enqueue phases |
| **Attention Operator Header**<br>`cpp/tensorrt_llm/common/attentionOp.h` | Added private member variable `mFP8ContextMLA` to `AttentionOp` class |
| **MLA Kernels Source and Header**<br>`cpp/tensorrt_llm/kernels/mlaKernels.cu`, `cpp/tensorrt_llm/kernels/mlaKernels.h` | Added `QuantizeCopyInputToFp8Kernel` CUDA kernel; extended `MlaParams` struct with `quant_scale_qkv`; updated MLA kernel invocation to launch FP8 quantization kernel conditionally |
| **Unit Tests**<br>`tests/unittest/_torch/test_attention_mla.py` | Increased FP8 accuracy tolerances in test dictionary |
## Sequence Diagram(s)
```mermaid
sequenceDiagram
participant User
participant AttentionOp
participant MLA Kernel
participant CUDA Device
User->>AttentionOp: Request context enqueue (MLA, FP8 enabled)
AttentionOp->>MLA Kernel: invokeMLARopeContext(params, ...)
MLA Kernel->>CUDA Device: applyMLARopeAndAssignQKVKernelOptContext
alt FP8 Context MLA enabled
MLA Kernel->>CUDA Device: QuantizeCopyInputToFp8Kernel (quantize input to FP8)
end
MLA Kernel-->>AttentionOp: Return
AttentionOp-->>User: Complete
```

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers
Actionable comments posted: 5
♻️ Duplicate comments (1)
cpp/tensorrt_llm/common/attentionOp.cpp (1)
749-749: Address the existing review comment: unit tests are still missing for this code path. QiJune previously requested unit tests for this FP8 context MLA path, which haven't been added yet.
🧹 Nitpick comments (3)
cpp/tensorrt_llm/kernels/mlaKernels.cu (1)
926-968: Remove the explicit `cudaStreamSynchronize` and drop the unused `headDim`.
`cudaStreamSynchronize(stream);` forces a host/device sync on every context-level call, nullifying any overlap with subsequent kernels and hurting throughput. Nothing in this path depends on a hard sync; the earlier `sync_check_cuda_error(stream)` is enough.
➜ Delete the sync or make it optional behind a debug flag.
`size_t headDim = …;` is never used and will trigger "set but not used" warnings when compiling with `-Wall`.
➜ Remove the variable.

cpp/tensorrt_llm/common/attentionOp.cpp (2)
732-754: Consider improving code clarity and consistency. The buffer size calculation logic is correct, but could benefit from:
- a consistent naming convention (e.g., `dim_*_per_head` vs `total_*_dim_all_heads`)
- a comment explaining why MLA requires a different buffer size calculation

```diff
-    int const num_total_qkv_elements
-        = max_num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
+    // MLA uses different head dimensions, requiring custom buffer size calculation
+    int const total_qkv_elements
+        = max_num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
```
277-2887: Verify the performance improvements claimed in the PR. The PR claims ~24% TTFT improvement on SM120. Please ensure performance benchmarks are included in the test suite to validate these gains and prevent regressions.
Would you like me to help create a performance test framework to track these metrics?
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- cpp/tensorrt_llm/common/attentionOp.cpp (9 hunks)
- cpp/tensorrt_llm/common/attentionOp.h (1 hunks)
- cpp/tensorrt_llm/kernels/mlaKernels.cu (2 hunks)
- cpp/tensorrt_llm/kernels/mlaKernels.h (2 hunks)
- cpp/tensorrt_llm/thop/attentionOp.cpp (4 hunks)
- tests/unittest/_torch/test_attention_mla.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
cpp/tensorrt_llm/thop/attentionOp.cpp (2)
tensorrt_llm/_torch/attention_backend/trtllm.py (2)
- host_kv_cache_pool_mapping (586-590)
- host_kv_cache_pool_pointers (579-583)

cpp/tensorrt_llm/common/attentionOp.cpp (1)
- T (1285-2054)
🔇 Additional comments (12)
cpp/tensorrt_llm/common/attentionOp.h (1)
384-384: LGTM! The addition of the `mFP8ContextMLA` flag follows the established naming convention and is properly initialized.

tests/unittest/_torch/test_attention_mla.py (1)
342-342: The tolerance adjustment for FP8 precision looks reasonable. The minor increase in tolerance values (1.875%) is appropriate for FP8 operations with the new FP8 context MLA support.
cpp/tensorrt_llm/kernels/mlaKernels.h (2)
90-91: LGTM! The addition of the `quant_scale_qkv` pointer is properly typed and initialized, following the existing pattern for quantization scale pointers.

116-119: LGTM! The kernel declaration follows proper CUDA conventions, with an appropriate template parameter for input flexibility and a fixed FP8 output type.
cpp/tensorrt_llm/thop/attentionOp.cpp (6)
152-155: Good defensive programming practice. Adding the null check before accessing the optional tensor prevents potential crashes.

210-212: LGTM! Proper fallback to `attention_window_size` when `cache_indirection` is not available.

218-233: Excellent robustness improvements. All the added checks for KV-cache-related optional tensors follow a consistent pattern and prevent null pointer dereferences.

254-259: LGTM! The compound check ensures both quant mode and tensor availability before pointer assignment.

261-266: Good consistency in output scale handling. The checks properly differentiate between FP8 and FP4 output scale scenarios.

332-334: LGTM! Consistent with the pattern of checking optional tensors before access.
cpp/tensorrt_llm/kernels/mlaKernels.cu (1)
931-944: Sanity-check the per-head dimensional assumptions. `total_k_dim_all_heads` and `total_v_dim_all_heads` are computed with `head_num`, implicitly assuming `num_kv_heads == head_num`. For MQA/GQA models that use 1 or 2 KV heads, this will over-estimate the element count and overrun the destination buffer. Please confirm that:

```cpp
total_k_dim_all_heads = effective_num_kv_heads * dim_k_per_head;
total_v_dim_all_heads = effective_num_kv_heads * dim_v_per_head;
```

or guard with an assertion.
cpp/tensorrt_llm/common/attentionOp.cpp (1)
2623-2623: Verify the XQA enablement conditions for SM120. XQA is enabled only for SM120 with generation MLA. Please confirm:
- Is XQA support exclusive to SM120, or should it also include SM121?
- Are there any other conditions that should be checked (e.g., FP8 mode)?
Signed-off-by: peaceh <[email protected]>
Description
Add FP8 context MLA support for SM120
Compared to FP8 context FMHA, FP8 context MLA needs BF16 output.
Accuracy:
The GPQA Diamond score for the DeepSeek-R1 quant wo_gemm checkpoint + FP8 context MLA + FP8 xqa-mla generation is 0.707 on SM120, which is reasonable given the baseline: the nvfp4 DeepSeek-R1 checkpoint + BF16 MLA + FP8 MLA scores 0.702 on B200.
Performance:
TTFT on SM120:
- BF16 context MLA: 121244.7586 ms
- FP8 context MLA: 97594.8747 ms
- ~24% improvement
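For reference, the arithmetic behind the headline number: 121244.7586 − 97594.8747 ≈ 23649.88 ms saved, which is ≈ 24.2% of the FP8 time (and ≈ 19.5% of the BF16 baseline), so the ~24% figure is measured relative to the FP8 run.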