[None][feat] Add vLLM KV Pool support for XQA kernel #6013
Conversation
Force-pushed from 98d80f2 to 6709190.
📝 Walkthrough

This update introduces support for an alternative "VLLM-style" paged KV cache layout in the XQA kernel module. It adds a new compile-time macro and build configuration, modifies kernel and utility code to support both the original and new cache layouts, updates function signatures and data structures, and expands documentation and tests to cover the new layout.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant BuildSystem
    participant Kernel
    participant DeviceMemory
    User->>BuildSystem: Configure with -DPAGED_KV_CACHE_LAYOUT=0 or 1
    BuildSystem->>Kernel: Compile with selected layout macro
    User->>Kernel: Launch MHA (pass separate K/V pointers if layout=1)
    Kernel->>DeviceMemory: Access KV cache using layout-specific logic
    Kernel-->>User: Return results
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~90 minutes
Actionable comments posted: 3
🧹 Nitpick comments (4)
cpp/kernels/xqa/README.md (1)

28-28: Fix heading punctuation style. Remove the trailing colon from the heading to follow Markdown style conventions.

```diff
-## Support for VLLM Paged KV-Cache:
+## Support for VLLM Paged KV-Cache
```

cpp/kernels/xqa/mha.cu (1)

1675-2746: Consider adding compile-time assertions for tensor layout assumptions. The implementation correctly handles the VLLM KV cache format throughout. To improve robustness, consider adding static assertions to verify assumptions about tensor dimensions and memory layout, especially around the indexing calculations.

For example, near the indexing calculations:

```diff
+// Verify tensor layout assumptions at compile time
+static_assert(tokensPerPage > 0, "tokensPerPage must be positive");
+static_assert(sizeof(GMemCacheHead) == sizeof(void*), "Pointer size assumption");
 #if PAGED_KV_CACHE_FORMAT == 1
 uint32_t const idxHeadBeg = (seqOffset % tokensPerPage) * nbKHeads + idxHeadGrp;
```

Also ensure the unit tests mentioned in the PR objectives thoroughly cover:
- Tensor dimension handling with various batch sizes and sequence lengths
- Memory addressing correctness for both K and V caches
- Performance comparison between traditional and VLLM formats
cpp/kernels/xqa/test/test.cpp (1)

1425-1432: Good validation of configuration combinations. The filter correctly skips invalid combinations where `paged_kv_cache_format != 0` when paged KV cache is disabled. Consider translating the Chinese comment to English for consistency:

```diff
-// 过滤无意义的组合:不使用分页KV缓存时,格式应该为0
+// Filter out invalid combinations: format should be 0 when not using paged KV cache
```

cpp/kernels/xqa/mhaUtils.cuh (1)
294-311: Simplified page indexing for vLLM format. The linear page index access correctly implements the vLLM format's 2D layout. Note that the `isK` parameter is unused in the vLLM format case since K and V share the same page indices.

Consider adding a comment to clarify that `isK` is intentionally unused for the vLLM format:

```diff
 #if PAGED_KV_CACHE_FORMAT == 1 && USE_PAGED_KV_CACHE
+    // For vLLM format, K and V share the same page indices, so isK is not used
     ret[i] = (idxPage < nbPages ? cacheList.kvCachePageList[maxNbPagesPerSeq * idxReq + idxPage] : kBAD_PAGE_INDEX);
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- cpp/kernels/xqa/CMakeLists.txt (2 hunks)
- cpp/kernels/xqa/README.md (2 hunks)
- cpp/kernels/xqa/defines.h (1 hunks)
- cpp/kernels/xqa/mha.cu (4 hunks)
- cpp/kernels/xqa/mha.h (2 hunks)
- cpp/kernels/xqa/mhaUtils.cuh (3 hunks)
- cpp/kernels/xqa/mha_sm90.cu (10 hunks)
- cpp/kernels/xqa/tensorMap.cpp (1 hunks)
- cpp/kernels/xqa/test/refAttention.h (1 hunks)
- cpp/kernels/xqa/test/test.cpp (18 hunks)
🧰 Additional context used
🪛 markdownlint-cli2 (0.17.2)
cpp/kernels/xqa/README.md
28-28: Trailing punctuation in heading
Punctuation: ':'
(MD026, no-trailing-punctuation)
🔇 Additional comments (43)
cpp/kernels/xqa/CMakeLists.txt (1)

45-45: Good practice: ensuring macro consistency across compilation units. Correctly adds the `PAGED_KV_CACHE_FORMAT` macro to CUDA compiler flags, ensuring the preprocessor definition is available during both host and device code compilation.

cpp/kernels/xqa/defines.h (1)

100-106: Well-structured conditional macro definition. The implementation correctly:
- Guards the definition with `USE_PAGED_KV_CACHE` to ensure it's only relevant for paged caches
- Provides a sensible default value (0) for backward compatibility
- Allows override from the build system via `#ifndef`
- Documents the meaning of each value clearly
cpp/kernels/xqa/README.md (2)

19-19: Improved documentation completeness. Good addition of the `-DBUILD_XQA_TESTS=ON` flag to the build command, making the instructions complete for users who want to build and run tests.

28-37: Comprehensive documentation for the new feature. Excellent addition documenting the VLLM KV-Cache support. The instructions are clear and complete, providing both build and test commands that align with the implementation changes.
cpp/kernels/xqa/test/refAttention.h (1)

53-57: Correct implementation of dual indexing strategies. The conditional indexing correctly handles both memory layouts:
- Original format: `tokensPerPage * nbHeads * pageIdx + tokensPerPage * idxHead + i % tokensPerPage` assumes [nbHeads][tokensPerPage] grouping within a page (head-first)
- VLLM format: `nbHeads * tokensPerPage * pageIdx + (i % tokensPerPage) * nbHeads + idxHead` assumes [tokensPerPage][nbHeads] grouping (sequence-first)

This aligns with the PR objective of supporting the tensor dimension layout change from `[bs, heads, tokens, head_dim]` to `[bs, tokens, heads, head_dim]`.
cpp/kernels/xqa/tensorMap.cpp (1)

61-77: Correctly implements tensor map layouts for both formats. The conditional implementation properly handles the dimension reordering:

VLLM layout (format 1):
- Global dims: `{headElems, nbKHeads, tokensPerPage, 1U << 31}`
- Box dims: `{partElems, 1, min(tokensPerPage, nbTokensPerTile), 1}`

XQA original layout (format 0):
- Global dims: `{headElems, tokensPerPage, nbKHeads, 1U << 31}`
- Box dims: `{partElems, min(tokensPerPage, nbTokensPerTile), 1, 1}`

The strides are calculated correctly for each layout, ensuring proper memory access patterns that align with the sequence-first vs. head-first tensor organization.
cpp/kernels/xqa/mha.h (2)

105-109: LGTM! Clean conditional parameter implementation. The conditional compilation correctly introduces separate K and V cache pointers for the VLLM format while maintaining backward compatibility with the original pool-based approach.

144-148: Consistent implementation with launchMHA. The same conditional parameter pattern is correctly applied to `launchHopperF8MHA`, maintaining consistency across both kernel launch functions.
1675-1681
: Correct index calculation for VLLM format.The conditional indexing correctly implements the tensor layout change from
[bs, heads, tokens, head_dim]
to[bs, tokens, heads, head_dim]
for the VLLM KV cache format. The calculation(seqOffset % tokensPerPage) * nbKHeads + idxHeadGrp
properly reflects the tokens-before-heads memory layout.
1683-1688
: Appropriate K cache pointer selection.The conditional compilation correctly selects between
cacheList.kCacheVLLM
for the VLLM format andcacheList.pool
for the original format when BEAM_WIDTH == 1.
2010-2016
: Consistent V cache indexing with K cache.The V cache indexing follows the same pattern as K cache, maintaining consistency in the tensor layout transformation for the VLLM format.
2017-2023
: Symmetric V cache pointer handling.The V cache pointer selection mirrors the K cache implementation, correctly using
cacheList.vCacheVLLM
for VLLM format.
2025-2036
: Consistent beam search handling for V cache.The V cache beam search case follows the same pattern as K cache, maintaining implementation consistency.
2672-2676
: Function signature properly updated.The launchMHA implementation correctly accepts the separate K and V cache pointers when PAGED_KV_CACHE_FORMAT == 1, matching the header declaration.
1690-1701
: Ignore internal K/V handling – IndexedHeadPtr only stores the passed-in pool pointerIndexedHeadPtrImpl is a thin wrapper that holds the single
pool
pointer andpageIndices
you provide; it doesn’t attempt to distinguish K vs. V caches internally. Since you’re already selectingcacheList.kCacheVLLM
(orcacheList.pool
) via thePAGED_KV_CACHE_FORMAT
guard, no further changes are needed here.
2742-2746
: Approve KVCacheList initialization for both cache formatsThe aggregate initialization matches the
KVCacheList<true>
member layout under eachPAGED_KV_CACHE_FORMAT
setting:
- When
PAGED_KV_CACHE_FORMAT == 1
, you pass(kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq)
to initialize all five members.- Otherwise, you pass
(pool, kvCachePageList, seqLen, maxNbPagesPerSeq)
to initialize the four-member layout.No changes required.
cpp/kernels/xqa/test/test.cpp (18)

220-222: LGTM! The assertion correctly enforces the requirement that K and V heads must be equal when using the vLLM paged KV cache format.

220-226: Correct calculation for separate K/V cache pools. The total cache heads calculation correctly accounts for the vLLM format, where K and V caches are stored in separate pools, each with `nbKHeads * maxSeqLen * batchSize` heads.

248-255: Simplified page indexing for vLLM format. The page count calculation correctly reflects the vLLM format's 2D page layout `[batchSize][nbPagesPerSeq]` versus the traditional 4D layout.

293-299: Proper allocation of separate K/V cache buffers. The code correctly allocates separate buffers for K and V caches when using the vLLM format, aligning with the architecture's requirement for distinct cache pools.

314-321: Correct type casting for different page list layouts. The conditional type casting properly handles the vLLM format's 2D page list layout versus the traditional 4D layout.

322-340: Well-structured page list initialization. The initialization logic correctly handles both formats, with clear comments distinguishing the vLLM format's batch-wise filling from the original linear filling approach.

358-364: Updated debug output for separate cache pools. The verbose output correctly displays separate K and V cache pointers when using the vLLM format.

456-462: Consistent test data generation for separate caches. The code correctly generates test data for both K and V caches when using the vLLM format.

499-506: Proper page shuffling for different layouts. The shuffle operation correctly handles the different page list layouts by using appropriate pointer arithmetic.

568-599: Correct implementation of vLLM tensor indexing. The `cacheHeadAt` function properly implements the vLLM format's indexing scheme `[batch, token, head, head_dim]` and correctly returns from separate K or V cache pools based on the `isK` parameter.

624-630: Proper memory prefetching for separate caches. The code correctly prefetches both K and V cache buffers to the device when using the vLLM format.

721-726: Kernel launch correctly passes separate cache pointers. The kernel invocation properly provides both K and V cache pointers when using the vLLM format, matching the updated kernel signatures.

820-826: Clear file naming for cache debugging. The code correctly saves K and V caches to separate files when using the vLLM format, facilitating debugging of the separate cache pools.

969-989: Proper cache sequence initialization for reference checking. The reference checking code correctly initializes separate cache sequences for K and V when using the vLLM format, with appropriate page index access patterns.

1333-1336: Extended NVRTC test coverage for cache formats. The test lambda correctly adds the `paged_kv_cache_format` parameter to enable testing of both traditional and vLLM cache formats.

1356-1356: Correct NVRTC compilation flag for cache format. The compilation option properly defines the `PAGED_KV_CACHE_FORMAT` macro for NVRTC compilation.

1412-1412: MLA-specific test for vLLM format. The test case appropriately validates the vLLM cache format with MLA architecture on SM120.

1432-1450: Comprehensive test coverage for both cache formats. The expanded test loop correctly iterates over both cache format values and passes the parameter to the test function, ensuring thorough testing of both traditional and vLLM formats.
cpp/kernels/xqa/mhaUtils.cuh (2)

81-94: Correct pointer arithmetic for the vLLM tensor layout. The modified pointer arithmetic correctly implements the vLLM memory layout where tokens are laid out before heads: `[batch, token, head, head_dim]`. The calculation `(i % tokensPerPage) * nbKHeads` properly reflects this ordering.

247-255: Well-structured cache list for the vLLM format. The specialization correctly defines separate K and V cache pointers for the vLLM format, with clear naming that indicates their purpose. Using the `GMemCacheHead*` type is appropriate for the separate cache pools.
660-664
: LGTM! Kernel signature correctly updated for vLLM KV pool support.The conditional compilation properly adds separate tensor map parameters (
tensorMapVLLMK
andtensorMapVLLMV
) whenPAGED_KV_CACHE_FORMAT == 1
, which aligns with the PR's objective of supporting vLLM's separate K and V cache pools.
745-756
: LGTM! Tensor map prefetching correctly handles both formats.The prefetch logic properly handles both the vLLM format (prefetching two separate tensor maps) and the original format (single tensor map), while using warp election to ensure thread-safe execution.
1343-1353
: LGTM! K tile loader correctly uses K-specific tensor map.The KVTilePartLoader constructor is properly initialized with
tensorMapVLLMK
when using vLLM format, maintaining clear separation between K and V cache handling.
1418-1428
: LGTM! V tile loader correctly uses V-specific tensor map.The V tile loader initialization properly uses
tensorMapVLLMV
when using vLLM format, maintaining consistency with the K tile loader pattern.
1790-1794
: LGTM! Base offset calculation correctly reflects vLLM's separate K/V pool layout.The simplified offset calculation for vLLM format (
idxReq * cacheList.maxNbPagesPerSeq
) properly reflects the separate K and V pool architecture, removing the interleaved layout calculations used in the original format.
3141-3232
: LGTM! Host launcher properly supports dual K/V cache pools.The launcher correctly:
- Accepts separate
kCacheVLLM
andvCacheVLLM
parameters for vLLM format- Creates distinct tensor maps for K and V caches
- Maintains backward compatibility with the original single-pool format
3255-3259
: LGTM! Kernel launch correctly passes tensor maps.The
cudaLaunchKernelEx
call properly passes either the dual tensor maps (vLLM format) or single tensor map (original format), completing the integration with the modified kernel signature.
Force-pushed from 6709190 to 08229f5.
Actionable comments posted: 0
♻️ Duplicate comments (2)
cpp/kernels/xqa/mha_sm90.cu (2)

814-831: Critical issue persists: incorrect tensor dimension ordering for vLLM KV cache. This is the same critical issue identified in previous reviews. The tensor dimension ordering for TMA loads under `PAGED_KV_CACHE_FORMAT == 1` is incorrect. The code currently places `idxHeadGrp` before `offset`, but vLLM's KV cache layout expects tokens before heads.

The dimension ordering should be:

```diff
-DimsLE<4>{partElems * idxPart, idxHeadGrp, offset, (uint32_t) pages[0]}
+DimsLE<4>{partElems * idxPart, offset, idxHeadGrp, (uint32_t) pages[0]}
```

And similarly for the multi-page case on lines 826-827.

864-873: Critical issue persists: `getHead()` method not implemented for vLLM format. This is the same critical issue identified in previous reviews. The `getHead()` method throws an assertion and traps when `PAGED_KV_CACHE_FORMAT == 1`, but this method is still called in the code (lines 1386, 1394, 1460) for writing new tokens to the cache. This will cause runtime failures when using the vLLM format.

The implementation needs to be completed to support vLLM's memory layout, or the calling code needs to be modified to avoid this method for the vLLM format.
🧹 Nitpick comments (1)
cpp/kernels/xqa/README.md (1)

28-37: Remove trailing punctuation from heading and improve formatting. The section provides clear instructions for the new VLLM feature, but there's a minor formatting issue with the heading.

```diff
-## Support for VLLM Paged KV-Cache:
+## Support for VLLM Paged KV-Cache
```

The colon at the end of the heading violates markdown conventions, as indicated by the static analysis tool.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- cpp/kernels/xqa/CMakeLists.txt (2 hunks)
- cpp/kernels/xqa/README.md (2 hunks)
- cpp/kernels/xqa/defines.h (1 hunks)
- cpp/kernels/xqa/mha.cu (4 hunks)
- cpp/kernels/xqa/mha.h (2 hunks)
- cpp/kernels/xqa/mhaUtils.cuh (3 hunks)
- cpp/kernels/xqa/mha_sm90.cu (10 hunks)
- cpp/kernels/xqa/tensorMap.cpp (1 hunks)
- cpp/kernels/xqa/test/refAttention.h (1 hunks)
- cpp/kernels/xqa/test/test.cpp (18 hunks)
✅ Files skipped from review due to trivial changes (1)
- cpp/kernels/xqa/defines.h
🚧 Files skipped from review as they are similar to previous changes (6)
- cpp/kernels/xqa/CMakeLists.txt
- cpp/kernels/xqa/test/refAttention.h
- cpp/kernels/xqa/tensorMap.cpp
- cpp/kernels/xqa/mha.cu
- cpp/kernels/xqa/mhaUtils.cuh
- cpp/kernels/xqa/mha.h
🪛 markdownlint-cli2 (0.17.2)
cpp/kernels/xqa/README.md
28-28: Trailing punctuation in heading
Punctuation: ':'
(MD026, no-trailing-punctuation)
🔇 Additional comments (23)
cpp/kernels/xqa/README.md (1)

19-19: LGTM: Build instruction improvement. Adding `-DBUILD_XQA_TESTS=ON` by default makes sense and improves the developer experience by enabling unit tests in the standard build process.

cpp/kernels/xqa/test/test.cpp (16)
220-225: LGTM: Proper conditional compilation for cache format. The conditional logic correctly handles the different memory layouts for the two KV cache formats. The assertion `assert(nbKHeads == nbVHeads)` is appropriate for the VLLM format, as mentioned in the PR objectives.

248-254: LGTM: Correct page allocation logic for VLLM format. The page allocation correctly reflects the simplified structure in the VLLM format, where pages are indexed by `[batch][page]` instead of the original `[batch][beam][kv][page]` structure.

293-298: LGTM: Proper memory buffer allocation for split KV format. The conditional allocation of separate K and V cache buffers for the VLLM format is implemented correctly, maintaining the same total memory usage while organizing it differently.

314-321: LGTM: Correct page list casting and indexing. The type casting and pointer arithmetic correctly handle the different page list layouts between the two formats.

323-340: LGTM: Page initialization logic correctly handles both formats. The page initialization properly handles the different indexing schemes. The VLLM format uses a simpler 2D layout, while the original format maintains the 4D structure.

456-461: LGTM: Parallel cache generation respects format differences. The thread-based cache generation correctly handles both unified and split cache buffer formats, ensuring proper initialization of the K and V caches separately when needed.

500-505: LGTM: Page shuffling adapted for format differences. The shuffling logic correctly operates on the appropriate data structures for each format, maintaining randomization while respecting the different memory layouts.

532-537: LGTM: Zero-fill logic handles split buffers correctly. The zero-fill initialization properly handles both unified and split cache buffer scenarios.

573-598: LGTM: Cache accessor function handles both formats correctly. The `cacheHeadAt` lambda correctly implements the different memory addressing schemes for both formats. The VLLM format uses separate K and V buffers with simplified indexing, while maintaining compatibility with the original format.

624-629: LGTM: Device prefetching updated for split buffers. The prefetching logic correctly handles both unified and split cache buffer scenarios for GPU memory migration.

721-725: LGTM: Kernel launch parameters updated for split KV format. The kernel launch correctly passes separate K and V cache pointers when using the VLLM format, maintaining backward compatibility with the original unified buffer approach.

820-825: LGTM: Data saving logic handles split format correctly. The debug data saving functionality correctly handles both unified and split cache buffer formats for analysis purposes.

970-988: LGTM: Reference checking updated for split cache format. The reference checking logic correctly constructs separate `CacheSeq` objects for K and V caches when using the VLLM format, ensuring proper validation of the split cache implementation.

1333-1335: LGTM: NVRTC test function signature updated correctly. The test lambda signature correctly adds the new `paged_kv_cache_format` parameter while maintaining all existing parameters.

1356-1356: LGTM: Compile-time macro properly passed to NVRTC. The new macro is correctly passed to the NVRTC compilation process, enabling conditional compilation of the different cache formats.

1425-1447: LGTM: Test iteration logic correctly filters invalid combinations. The nested loop structure correctly iterates over all parameter combinations while properly filtering out invalid configurations where paged KV cache is disabled but the format is non-zero. This prevents testing meaningless combinations.
cpp/kernels/xqa/mha_sm90.cu (6)
660-664: LGTM: Kernel signature correctly supports separate tensor maps for vLLM format. The conditional compilation properly adds separate `tensorMapVLLMK` and `tensorMapVLLMV` parameters for `PAGED_KV_CACHE_FORMAT == 1` while maintaining backward compatibility with the original `tensorMap` parameter.

745-756: LGTM: Tensor map prefetching correctly implemented for vLLM format. The conditional prefetching of separate `tensorMapVLLMK` and `tensorMapVLLMV` follows proper CUDA patterns and maintains backward compatibility.

343-353: LGTM: KVTilePartLoader constructor correctly adapted for vLLM format. The constructor properly accepts separate tensor maps for K and V caches when using the vLLM format. The simplified `baseOffset` calculation for the vLLM format (line 791) is correct since K and V caches are stored separately.

Also applies to: 418-428

790-794: LGTM: BaseOffset calculation correctly simplified for vLLM format. The simplified `baseOffset` calculation for the vLLM format is appropriate since K and V caches are stored in separate pools, eliminating the need for the K/V distinction in the offset.

3140-3259: LGTM: Host launcher correctly implemented for vLLM format. The host function properly:
- Accepts separate `kCacheVLLM` and `vCacheVLLM` pointers for the vLLM format
- Creates separate tensor maps using `makeTensorMapForPagedKVCache`
- Constructs `KVCacheList` with separate cache pointers
- Passes the correct tensor maps to the kernel launch

The implementation maintains backward compatibility and follows consistent patterns with the kernel changes.

3254-3259: LGTM: Kernel launch parameters correctly match signature changes. The conditional kernel launch parameters properly pass separate tensor maps for the vLLM format while maintaining backward compatibility with the original single tensor map approach.
Force-pushed from d844680 to 3355ecb.
Actionable comments posted: 0
🧹 Nitpick comments (1)
cpp/kernels/xqa/README.md (1)

28-37: Fix heading punctuation and improve documentation. The new section clearly explains how to build and test with VLLM KV cache support. However, there's a minor formatting issue with the heading punctuation.

Apply this diff to fix the heading punctuation:

```diff
-## Support for VLLM Paged KV-Cache:
+## Support for VLLM Paged KV-Cache
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- cpp/kernels/xqa/CMakeLists.txt (2 hunks)
- cpp/kernels/xqa/README.md (2 hunks)
- cpp/kernels/xqa/defines.h (1 hunks)
- cpp/kernels/xqa/mha.cu (4 hunks)
- cpp/kernels/xqa/mha.h (2 hunks)
- cpp/kernels/xqa/mhaUtils.cuh (3 hunks)
- cpp/kernels/xqa/mha_sm90.cu (10 hunks)
- cpp/kernels/xqa/tensorMap.cpp (1 hunks)
- cpp/kernels/xqa/test/refAttention.h (1 hunks)
- cpp/kernels/xqa/test/test.cpp (18 hunks)
🚧 Files skipped from review as they are similar to previous changes (8)
- cpp/kernels/xqa/CMakeLists.txt
- cpp/kernels/xqa/test/refAttention.h
- cpp/kernels/xqa/tensorMap.cpp
- cpp/kernels/xqa/mha.h
- cpp/kernels/xqa/mhaUtils.cuh
- cpp/kernels/xqa/mha.cu
- cpp/kernels/xqa/mha_sm90.cu
- cpp/kernels/xqa/defines.h
🧰 Additional context used
🪛 markdownlint-cli2 (0.17.2)
cpp/kernels/xqa/README.md
28-28: Trailing punctuation in heading
Punctuation: ':'
(MD026, no-trailing-punctuation)
🔇 Additional comments (9)
cpp/kernels/xqa/test/test.cpp (8)

220-225: LGTM: Conditional cache head calculation is correct. The conditional logic properly handles the different cache head calculations for the VLLM format vs. the original format. The assertion ensures the required constraint that K and V heads must be equal when using the new format.

248-254: LGTM: Page list structure correctly adapted for VLLM format. The conditional page list structure properly reflects the VLLM format's simplified 2D layout compared to the original 4D layout. The memory allocation calculations are correctly adjusted.

293-298: LGTM: Separate K/V cache buffers correctly implemented. The conditional allocation of separate `cacheKHeads` and `cacheVHeads` buffers properly implements the VLLM format's requirement for split K/V cache pools.

322-340: LGTM: Page list initialization correctly handles both formats. The conditional initialization properly implements the VLLM format's 2D page list structure while preserving the original format's behavior. The batch-first indexing aligns with the tensor layout changes.

456-461: LGTM: Cache buffer operations correctly handle separate K/V pools. The conditional logic properly extends data generation, zero filling, and prefetching operations to work with separate K and V cache buffers in the VLLM format.

568-599: LGTM: Cache head accessor correctly implements VLLM format indexing. The `cacheHeadAt` function properly handles the different memory layouts and indexing schemes for both formats. The VLLM format's indexing correctly reflects the tensor layout change and properly selects between K and V cache buffers.

970-988: LGTM: Reference checking correctly models VLLM format. The conditional creation of separate `CacheSeq` objects for K and V caches properly mirrors the kernel's behavior in the VLLM format, ensuring accurate reference calculations for validation.

1333-1454: LGTM: NVRTC tests comprehensively cover new format with proper filtering. The NVRTC test updates correctly add the `paged_kv_cache_format` parameter and compilation options. The filtering logic appropriately prevents invalid combinations while ensuring both formats are tested across different configurations.

cpp/kernels/xqa/README.md (1)

19-19: LGTM: Build instructions updated to enable tests by default. Good improvement to include `-DBUILD_XQA_TESTS=ON` in the standard build instructions, making it easier for developers to build and run tests.
Actionable comments posted: 0
🔭 Outside diff range comments (1)
cpp/kernels/xqa/test/test.cpp (1)

970-988: Action required: correct reference attention calls for the paged (VLLM) layout. The reference test is currently constructing `CacheSeq<true,false>` for both K/V but still calling the non-paged reference kernel (`refAttention<InputElem>`, which defaults to `<false,false>`). This means your validation never exercises the paged logic.

Please update as follows:

In cpp/kernels/xqa/test/test.cpp (around lines 970-988):
- Under `#if PAGED_KV_CACHE_LAYOUT == 1`, change

```cpp
refOutput = refAttention<InputElem>( … );
```

to

```cpp
refOutput = refAttention<InputElem, true, false>( … );
```

- If you also test beam-search mode, invoke

```cpp
refOutput = refAttention<InputElem, true, true>( … );
```

In cpp/kernels/xqa/test/refAttention.cpp, add explicit paged instantiations after the existing ones:

```cpp
INSTANTIATE_refAttention(InputElem, true, false);
INSTANTIATE_refAttention(InputElem, true, true);
```

Verify in refAttention.h that the function template is declared as

```cpp
template <typename Prec, bool isPaged = false, bool useBeamSearch = false>
Eigen::Matrix<…> refAttention(…);
```

(or adjust accordingly) so that the above calls compile.

These changes ensure your reference implementation truly validates both non-paged and VLLM (paged) cache layouts with separate K/V sequences.
♻️ Duplicate comments (2)
cpp/kernels/xqa/mha_sm90.cu (2)
1814-1831: Critical: Incorrect tensor dimension ordering for vLLM KV cache.

The dimension ordering in the TMA loads is incorrect for vLLM's KV cache layout. vLLM uses a `[num_layers, 2, num_tokens, num_kv_heads, head_size]` layout, where tokens should be indexed before heads.

The current implementation:
DimsLE<4>{partElems * idxPart, idxHeadGrp, offset, (uint32_t) pages[i]}
Should be:
DimsLE<4>{partElems * idxPart, offset, idxHeadGrp, (uint32_t) pages[i]}

This issue will cause incorrect memory accesses and data corruption when using the vLLM layout.
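To make the ordering concrete, here is a small host-side sketch of the per-page element offsets implied by the two stride orders. The names (`nbKVHeads`, `tokensPerPage`, `headElems`) are illustrative stand-ins for the kernel's actual page geometry, not the real identifiers:

```cpp
#include <cassert>
#include <cstdint>

// vLLM per-page layout [num_tokens, num_kv_heads, head_size]:
// moving to the next token strides over all heads.
uint32_t vllmPageOffset(uint32_t token, uint32_t head, uint32_t nbKVHeads, uint32_t headElems)
{
    return (token * nbKVHeads + head) * headElems;
}

// Head-major per-page layout [num_kv_heads, num_tokens, head_size]:
// moving to the next token strides over one head only.
uint32_t headMajorPageOffset(uint32_t token, uint32_t head, uint32_t tokensPerPage, uint32_t headElems)
{
    return (head * tokensPerPage + token) * headElems;
}
```

Swapping `offset` (tokens) and `idxHeadGrp` in the TMA dims is exactly the difference between these two stride orders; with the wrong order, a loaded tile lands at a different page offset than the one vLLM wrote to.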
1864-1873: Critical: `getHead()` method is not implemented for vLLM format.

The `getHead()` method throws an assertion and traps when `PAGED_KV_CACHE_LAYOUT == 1`, but this method is called in the code (e.g., lines 1386, 1394, 1460) to write new tokens to the cache. This will cause runtime failures when using vLLM format. The implementation needs to be completed to support vLLM's memory layout or the functionality needs to be properly documented as unsupported.
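For illustration only, a minimal sketch of what a vLLM-layout head lookup could compute, using hypothetical struct and member names (the real cache-head and loader types in mha_sm90.cu differ), following the same token-major per-page indexing:

```cpp
#include <cassert>
#include <cstdint>

// Illustrative stand-in for one cache head worth of elements.
struct CacheHead
{
    float elems[128];
};

// Hypothetical lookup into one K (or V) pool under the vLLM layout,
// where each page is laid out as [tokens, kv_heads, head_size].
struct VllmHeadLookup
{
    CacheHead* pool;
    uint32_t tokensPerPage;
    uint32_t nbKVHeads;

    CacheHead& getHead(uint32_t pageIdx, uint32_t tokenInPage, uint32_t idxKVHead) const
    {
        uint32_t const headsPerPage = tokensPerPage * nbKVHeads;
        return pool[pageIdx * headsPerPage + tokenInPage * nbKVHeads + idxKVHead];
    }
};
```

This is a sketch under stated assumptions, not the missing implementation itself; the actual fix must match the kernel's page-index and smem plumbing.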
🧹 Nitpick comments (1)
cpp/kernels/xqa/README.md (1)
28-37: Comprehensive documentation for the new VLLM KV-Cache feature. The new section clearly explains how to enable and test the VLLM Paged KV-Cache layout. The build instructions are complete and include all necessary flags. However, address the markdown linting issue:

-## Support for VLLM Paged KV-Cache:
+## Support for VLLM Paged KV-Cache
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- cpp/kernels/xqa/CMakeLists.txt (2 hunks)
- cpp/kernels/xqa/README.md (2 hunks)
- cpp/kernels/xqa/defines.h (1 hunks)
- cpp/kernels/xqa/mha.cu (4 hunks)
- cpp/kernels/xqa/mha.h (2 hunks)
- cpp/kernels/xqa/mhaUtils.cuh (3 hunks)
- cpp/kernels/xqa/mha_sm90.cu (10 hunks)
- cpp/kernels/xqa/tensorMap.cpp (1 hunks)
- cpp/kernels/xqa/test/refAttention.h (1 hunks)
- cpp/kernels/xqa/test/test.cpp (18 hunks)
🧠 Learnings (1)
cpp/kernels/xqa/mha_sm90.cu (1)
Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.374Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache()
and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.
🪛 markdownlint-cli2 (0.17.2)
cpp/kernels/xqa/README.md
28-28: Trailing punctuation in heading
Punctuation: ':'
(MD026, no-trailing-punctuation)
🚧 Files skipped from review as they are similar to previous changes (7)
- cpp/kernels/xqa/CMakeLists.txt
- cpp/kernels/xqa/defines.h
- cpp/kernels/xqa/tensorMap.cpp
- cpp/kernels/xqa/test/refAttention.h
- cpp/kernels/xqa/mha.h
- cpp/kernels/xqa/mhaUtils.cuh
- cpp/kernels/xqa/mha.cu
🔇 Additional comments (20)
cpp/kernels/xqa/README.md (1)

19-19: Good improvement to enable tests by default. Adding `-DBUILD_XQA_TESTS=ON` to the default build command makes it easier for developers to build and run tests without having to remember the flag.

cpp/kernels/xqa/test/test.cpp (14)
220-225: Correct conditional logic for cache head count calculation. The implementation properly separates the total cache head count calculation for the new VLLM layout, ensuring K and V heads are handled separately when `PAGED_KV_CACHE_LAYOUT == 1`.

248-254: Proper page list structure adaptation for VLLM layout. The conditional compilation correctly adapts the page list structure:
- VLLM layout uses 2D indexing: `pageList[batchIdx][pageIdx]`
- Original layout uses 4D indexing: `pageList[batchIdx][beamIdx][kvIdx][pageIdx]`

This matches the expected memory layout changes described in the PR objectives.

293-298: Appropriate memory buffer separation for K and V caches. The conditional allocation of separate `cacheKHeads` and `cacheVHeads` buffers aligns with the VLLM KV pool architecture where K and V caches are stored in distinct blocks rather than interleaved format.

314-321: Consistent page list pointer handling across layouts. The page list reinterpretation and argument preparation correctly handles both the VLLM format (2D array) and original format (4D array), ensuring proper memory addressing for each layout.

456-461: Correct conditional data generation for separate K/V caches. The thread-based data generation properly handles the separated K and V cache buffers in the VLLM layout, ensuring both buffers receive the same amount of generated data.

532-537: Consistent zero-fill handling for both cache layouts. The zero-fill logic correctly handles both the separated K/V cache buffers (VLLM) and the combined cache buffer (original), ensuring proper initialization in both cases.
568-598: Complex but correct cache head accessor logic. The `cacheHeadAt` lambda function properly handles the different indexing schemes:
- VLLM layout: Separate K/V buffers with simplified page indexing
- Original layout: Combined buffer with beam/KV indexing

The tensor dimension changes from `[bs, heads, tokens, head_dim]` to `[bs, tokens, heads, head_dim]` are correctly implemented in the VLLM branch.
624-629: Proper memory prefetching for separated buffers. The prefetch logic correctly handles both cache layouts, ensuring all allocated memory buffers are properly prefetched to the device.

721-725: Kernel launch parameter adaptation for VLLM layout. The kernel launch correctly passes separate K and V cache pointers when using the VLLM layout, matching the kernel signature changes described in the PR objectives.

820-825: Appropriate data saving for debugging/validation. The conditional data saving logic properly handles both cache layouts for debugging and validation purposes, maintaining compatibility with existing tooling.

1333-1335: Enhanced NVRTC test parameter coverage. The addition of the `paged_kv_cache_layout` parameter to the NVRTC compilation test lambda ensures that the new layout option is tested during compilation validation.

1356-1356: Proper compile-time macro definition for NVRTC. The addition of the `PAGED_KV_CACHE_LAYOUT` macro definition in the NVRTC compilation options ensures the test covers the new conditional compilation paths.

1425-1449: Comprehensive test coverage for new layout parameter. The nested loop structure now includes `paged_kv_cache_layout` parameter testing with proper filtering logic to skip invalid combinations. The condition `!use_paged_kv_cache && paged_kv_cache_layout != 0` correctly prevents testing non-zero layout values when paged KV cache is disabled.
323-340: Page index initialization logic verified. The sequential (VLLM) and linear (original) filling implementations both produce the correct global page indices for their respective layouts, and they align with all observed access patterns in test.cpp, refAttention.h, and the kernel code. No changes required.

cpp/kernels/xqa/mha_sm90.cu (5)
660-664: LGTM: Kernel signature correctly supports separate K/V tensor maps. The conditional compilation properly handles vLLM's separate K and V cache tensor maps while maintaining backward compatibility with the original unified tensor map approach.

745-756: LGTM: Tensor map prefetching correctly handles separate K/V maps. The conditional prefetching of both `tensorMapVLLMK` and `tensorMapVLLMV` is properly implemented and maintains the performance optimization for the new vLLM layout.

790-794: LGTM: BaseOffset calculation correctly adapts to vLLM's separate K/V pool architecture. The simplified baseOffset calculation for `PAGED_KV_CACHE_LAYOUT == 1` correctly reflects vLLM's separate K and V cache pools, eliminating the need for K/V interleaving offset calculations.

1343-1353: LGTM: KVTilePartLoader constructor calls correctly use separate tensor maps. The constructor calls properly pass the appropriate tensor maps (`tensorMapVLLMK` for the K loader, `tensorMapVLLMV` for the V loader) when using the vLLM layout, maintaining clean separation between K and V cache handling. Also applies to: 1418-1428.

3140-3259: LGTM: Host function correctly implements separate K/V cache support. The host function changes properly:
- Accept separate `kCacheVLLM` and `vCacheVLLM` parameters for the vLLM layout
- Create separate tensor maps for K and V caches using `makeTensorMapForPagedKVCache`
- Pass the appropriate tensor maps to the kernel launch
- Maintain backward compatibility through conditional compilation

The implementation correctly bridges the host-device interface for vLLM's separate K/V cache architecture.
d33eb32 to 8566621 Compare
Actionable comments posted: 0
♻️ Duplicate comments (2)
cpp/kernels/xqa/mha_sm90.cu (2)
1814-1831: Critical: Tensor dimension ordering still incorrect for vLLM KV cache. The TMA load operations continue to place heads before tokens in the vLLM layout, contradicting vLLM's expected `[num_layers, 2, num_tokens, num_kv_heads, head_size]` tensor format where tokens should come before heads.

1864-1873: Critical: getHead() method remains unimplemented for vLLM format. The `getHead()` method still contains only an assertion and trap for vLLM format, but this method is required for writing new tokens to the cache. This will cause runtime failures when using vLLM format.
🧹 Nitpick comments (1)
cpp/kernels/xqa/README.md (1)
28-37: Remove trailing colon from heading and improve documentation clarity. The documentation clearly explains the new VLLM Paged KV-Cache feature. However, there's a minor style issue with the heading punctuation.

-## Support for VLLM Paged KV-Cache:
+## Support for VLLM Paged KV-Cache

Consider also clarifying what "split-wise KV-pool and sequence-first memory layout" means for users who may not be familiar with these terms.
🚧 Files skipped from review as they are similar to previous changes (7)
- cpp/kernels/xqa/CMakeLists.txt
- cpp/kernels/xqa/defines.h
- cpp/kernels/xqa/test/refAttention.h
- cpp/kernels/xqa/tensorMap.cpp
- cpp/kernels/xqa/mhaUtils.cuh
- cpp/kernels/xqa/mha.h
- cpp/kernels/xqa/mha.cu
🔇 Additional comments (22)
cpp/kernels/xqa/README.md (1)

19-19: LGTM! Good addition of default test building. The addition of `-DBUILD_XQA_TESTS=ON` to the default build command makes it easier for developers to build with tests enabled by default.

cpp/kernels/xqa/test/test.cpp (14)

248-254: LGTM! Page list size calculation correctly handles the new layout. The conditional logic appropriately reduces the page list size for the VLLM layout since it doesn't need separate beam and KV dimensions in the page indexing.

293-298: Good separation of K and V cache buffers for the new layout. The conditional allocation of separate `cacheKHeads` and `cacheVHeads` buffers aligns with the vLLM KV Pool architecture described in the PR objectives.

456-461: LGTM! Proper conditional initialization of separate cache buffers. The thread-based initialization correctly handles both the original unified cache format and the new separate K/V cache format.

532-537: Consistent zero-filling logic for both cache layouts. The zero-fill initialization properly handles both unified and separate cache buffer layouts.

624-629: LGTM! Proper prefetching for separate cache buffers. The conditional prefetching correctly handles both unified and separate K/V cache buffer scenarios.

721-725: Good adaptation of kernel launch parameters. The kernel launch correctly passes separate K and V cache pointers when the new layout is enabled, which aligns with the PR objectives.

820-825: LGTM! Consistent data saving for both layouts. The conditional saving of K and V data separately for the new layout versus combined KV data for the original layout is appropriate.
1333-1336: Good addition of paged_kv_cache_layout parameter to compilation tests. The NVRTC test lambda now correctly accepts and uses the new `paged_kv_cache_layout` parameter, ensuring the new layout gets proper compile-time testing.

1425-1430: LGTM! Proper filtering of invalid parameter combinations. The logic correctly skips testing non-zero `paged_kv_cache_layout` values when `use_paged_kv_cache` is false, preventing invalid configuration testing.

1431-1448: Comprehensive test coverage for the new layout. The nested loops ensure that the new `paged_kv_cache_layout` parameter is tested across all relevant combinations of input types, cache formats, head dimensions, and beam widths.
323-339: Initialization logic verified for both layouts. The sequential index assignment correctly covers all elements in each layout:
- `PAGED_KV_CACHE_LAYOUT == 1` uses a 2D nested loop (`pageList[batch][page] = pageIdx++`) over `batchSize × nbPagesPerSeq`.
- Else uses a flattened pointer (`(&pageList[0][0][0][0])[i] = i`) over `beamWidth × 2 × nbPagesPerSeq`.

No changes required.
573-598: Cache indexing verified – no issues found. I've cross-checked the new formulas in test.cpp against the patterns in refAttention.h. They line up exactly for both layouts:
- VLLM layout (`PAGED_KV_CACHE_LAYOUT == 1`): `pageIdx * tokensPerPage * nbKHeads + (pos % tokensPerPage) * nbKHeads + idxKVHead`
- Original layout: `tokensPerPage * (nbKHeads * pageIdx + idxKVHead) + (pos % tokensPerPage)`

No inconsistencies were found and no changes are needed here.
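The two formulas above can be checked mechanically. A standalone sketch (variable names follow the review: `pageIdx`, `pos`, `idxKVHead`, `tokensPerPage`, `nbKHeads`; the concrete geometry values below are made up for illustration):

```cpp
#include <cassert>
#include <cstdint>

// Head index under the vLLM layout (PAGED_KV_CACHE_LAYOUT == 1):
// per-page order is [tokens, heads], so heads vary fastest.
uint32_t vllmHeadIdx(uint32_t pageIdx, uint32_t pos, uint32_t idxKVHead,
                     uint32_t tokensPerPage, uint32_t nbKHeads)
{
    return pageIdx * tokensPerPage * nbKHeads + (pos % tokensPerPage) * nbKHeads + idxKVHead;
}

// Head index under the original layout: per-page order is [heads, tokens],
// so tokens vary fastest.
uint32_t origHeadIdx(uint32_t pageIdx, uint32_t pos, uint32_t idxKVHead,
                     uint32_t tokensPerPage, uint32_t nbKHeads)
{
    return tokensPerPage * (nbKHeads * pageIdx + idxKVHead) + (pos % tokensPerPage);
}
```

Both map one page onto the same `tokensPerPage * nbKHeads` range of head slots; only the order of the token and head strides differs, and page starts (`pos % tokensPerPage == 0`, `idxKVHead == 0`) coincide.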
970-988: No changes needed: CacheSeq construction matches expected specializations.
- For `PAGED_KV_CACHE_LAYOUT == 1`, `CacheSeq<true,false>` is correctly initialized with `.pool` → `GMemCacheHead const*`, `.pageIndices` → `KVCachePageIndex*`, and `.nbHeads`/`.idxHead` matching the `<true,false>` specialization.
- For the other layout, `CacheSeq<true,true>` is likewise provided the exact fields (`.pool`, `.pageIndices`, `.maxNbPages`, `.nbHeads`, `.idxHead`, `.cacheIndir`) required by its specialization.

All constructor arguments align with their struct definitions; no fixes required.
220-225: Double-check cache head count for the new PAGED_KV_CACHE_LAYOUT. Under

#if PAGED_KV_CACHE_LAYOUT == 1
    assert(nbKHeads == nbVHeads);
    uint32_t const totalNbCacheHeads = nbKHeads * maxSeqLen * batchSize;
#else
    uint32_t const totalNbCacheHeads = (nbKHeads + nbVHeads) * maxSeqLen * beamWidth * batchSize;
#endif

we're asserting `nbKHeads == nbVHeads` and then only multiplying by `nbKHeads`. Since K and V are stored in separate buffers (`cacheKHeads` and `cacheVHeads`) and factored into memory-traffic and size calculations (e.g. `cacheBytes`, `pageListBytes`), please verify that omitting the "×2" (or `(nbKHeads + nbVHeads)`) here doesn't undercount one of the caches.

Areas to confirm:
- Are both `cacheKHeads` and `cacheVHeads` being sized (`validElemsPerKHead` elements per head) based on this same `totalNbCacheHeads`?
- Does `cacheBytes = cacheElemSize * totalNbCacheElems` correctly include both K and V?
- Should the paged layout also double its page list (i.e. `totalNbPages = nbPagesPerSeq * 2 * batchSize`)?

If you intend to keep them separate but counted together, consider:

#if PAGED_KV_CACHE_LAYOUT == 1
-    uint32_t const totalNbCacheHeads = nbKHeads * maxSeqLen * batchSize;
+    uint32_t const totalNbCacheHeads = (nbKHeads + nbVHeads) * maxSeqLen * batchSize;
#else
    uint32_t const totalNbCacheHeads = (nbKHeads + nbVHeads) * maxSeqLen * beamWidth * batchSize;
#endif

…and adjust downstream buffers/metrics accordingly.
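The concern can be sanity-checked numerically: if each of the two separate pools holds `nbKHeads * maxSeqLen * batchSize` heads, any metric derived from a single `totalNbCacheHeads` must be doubled to cover K plus V. A sketch with illustrative function names (not from the codebase):

```cpp
#include <cassert>
#include <cstdint>

// Head count of one vLLM pool (K or V); two such pools exist.
uint64_t vllmHeadsPerPool(uint64_t nbKHeads, uint64_t maxSeqLen, uint64_t batchSize)
{
    return nbKHeads * maxSeqLen * batchSize;
}

// Head count of the original combined buffer (K and V interleaved).
uint64_t origHeadsTotal(uint64_t nbKHeads, uint64_t nbVHeads, uint64_t maxSeqLen,
                        uint64_t beamWidth, uint64_t batchSize)
{
    return (nbKHeads + nbVHeads) * maxSeqLen * beamWidth * batchSize;
}
```

With `nbKHeads == nbVHeads` and `beamWidth == 1`, two vLLM pools hold exactly as many heads as the original combined buffer, so any aggregate such as `cacheBytes` must use twice the per-pool count (or each pool must be sized and accounted for separately).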
cpp/kernels/xqa/mha_sm90.cu (7)

660-664: LGTM! Kernel signature correctly supports separate tensor maps for vLLM layout. The conditional kernel signature modification properly handles the separate K and V tensor maps required for the vLLM paged KV cache layout while maintaining backward compatibility.

745-756: LGTM! Tensor map prefetching correctly handles both layouts. The conditional prefetching logic properly handles both the original single tensor map and the separate K/V tensor maps for vLLM layout.

1343-1353: LGTM! K cache loader correctly uses appropriate tensor map. The conditional logic properly passes the K-specific tensor map for vLLM layout while maintaining compatibility with the original layout.

1418-1428: LGTM! V cache loader correctly uses appropriate tensor map. The conditional logic properly passes the V-specific tensor map for vLLM layout while maintaining compatibility with the original layout.

1790-1794: LGTM! Base offset calculation correctly reflects vLLM's separate K/V pool architecture. The simplified base offset calculation for vLLM layout correctly accounts for the fact that K and V caches are stored in entirely separate pools, eliminating the need for K/V distinction within the offset calculation.

3140-3144: LGTM! Host function correctly accepts separate K/V cache pointers for vLLM. The conditional function signature modification properly handles the separate K and V cache pointers required for vLLM's architecture while maintaining backward compatibility.

3220-3259: LGTM! Host function correctly creates separate tensor maps and launches kernel. The conditional logic properly creates separate tensor maps for K and V caches in vLLM layout and passes the appropriate parameters to the kernel launch. The implementation maintains backward compatibility with the original layout.
6d4db08 to efca057 Compare
Actionable comments posted: 0
🧹 Nitpick comments (1)
cpp/kernels/xqa/README.md (1)
28-37: Clear documentation for the new vLLM feature. The new section provides clear instructions for building and testing with the vLLM KV cache layout. The documentation correctly explains the feature's purpose and provides complete build commands.

Minor style fix needed: remove the trailing colon from the heading to comply with markdown best practices:

-## Support for VLLM Paged KV-Cache:
+## Support for VLLM Paged KV-Cache
🚧 Files skipped from review as they are similar to previous changes (8)
- cpp/kernels/xqa/test/refAttention.h
- cpp/kernels/xqa/tensorMap.cpp
- cpp/kernels/xqa/CMakeLists.txt
- cpp/kernels/xqa/defines.h
- cpp/kernels/xqa/mhaUtils.cuh
- cpp/kernels/xqa/mha.h
- cpp/kernels/xqa/mha.cu
- cpp/kernels/xqa/mha_sm90.cu
🔇 Additional comments (15)
cpp/kernels/xqa/test/test.cpp (14)

220-225: LGTM! Cache head calculation correctly handles the new layout. The conditional logic properly calculates `totalNbCacheHeads` for the vLLM layout where K and V caches are separated, requiring only `nbKHeads * maxSeqLen * batchSize` instead of the combined `(nbKHeads + nbVHeads) * maxSeqLen * beamWidth * batchSize` for the original layout.

248-254: Page list calculation is correctly adapted for vLLM layout. The logic properly handles the simplified page structure in vLLM format where pages are indexed as `[batchIdx][pageIdx]` instead of the original 4D structure `[batch][beam][kv][page]`.

293-298: Separate K and V cache buffer allocation is correctly implemented. The conditional compilation properly creates separate `cacheKHeads` and `cacheVHeads` buffers for the vLLM layout while maintaining the original unified `cacheHeads` buffer for the default layout.

314-321: Page list buffer casting handles both layouts correctly. The code properly casts the page list buffer to the appropriate structure: a simplified 2D array for the vLLM layout versus a 4D array for the original layout.
323-340: Page list initialization logic is well-structured. The initialization correctly handles both formats:
- vLLM format: Sequential page assignment per batch with clear indexing
- Original format: Linear filling as before

The comments clearly document the different approaches.
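The two initialization schemes described above can be sketched as standalone code (illustrative names and host-side `std::vector` storage; the real test fills fixed-size device buffers):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// vLLM format: 2D page list, sequential page assignment per batch.
std::vector<std::vector<uint32_t>> initVllmPageList(uint32_t batchSize, uint32_t nbPagesPerSeq)
{
    std::vector<std::vector<uint32_t>> pageList(batchSize, std::vector<uint32_t>(nbPagesPerSeq));
    uint32_t pageIdx = 0;
    for (uint32_t b = 0; b < batchSize; ++b)
    {
        for (uint32_t p = 0; p < nbPagesPerSeq; ++p)
        {
            pageList[b][p] = pageIdx++;
        }
    }
    return pageList;
}

// Original format: linear fill over the flattened [batch][beam][kv][page] array.
std::vector<uint32_t> initOrigPageList(uint32_t batchSize, uint32_t beamWidth, uint32_t nbPagesPerSeq)
{
    std::vector<uint32_t> flat(batchSize * beamWidth * 2 * nbPagesPerSeq);
    for (std::size_t i = 0; i < flat.size(); ++i)
    {
        flat[i] = static_cast<uint32_t>(i);
    }
    return flat;
}
```

Both fills produce strictly increasing global page indices; the difference is only whether the 2 (K/V) and beam dimensions participate in the addressing.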
456-461: Thread-based cache initialization handles both buffer types. The parallel cache filling logic correctly generates data for both separate K/V buffers in vLLM layout and the unified buffer in the original layout.

532-537: Zero-fill initialization correctly handles separated buffers. The zero-fill logic properly initializes both `cacheKHeads` and `cacheVHeads` when using the vLLM layout.

573-598: Cache accessor function correctly implements vLLM indexing. The `cacheHeadAt` function properly handles the vLLM layout:
- Simplified page indexing without beam/kv dimensions
- Correct tensor layout: `pageIdx * tokensPerPage * nbKHeads + (pos % tokensPerPage) * nbKHeads + idxKVHead`
- Proper selection between K and V buffers based on the `isK` parameter

This matches the expected vLLM tensor format described in the PR objectives.
624-629: Memory prefetching correctly handles separate buffers. The prefetch logic properly handles both `cacheKHeads` and `cacheVHeads` buffers when using the vLLM layout.

721-725: Kernel invocation correctly passes separate cache pointers. The kernel launch properly passes both `cacheKHeads.get()` and `cacheVHeads.get()` when using the vLLM layout, which aligns with the expected kernel signature changes.

820-825: Data saving logic handles both buffer layouts. The save functionality correctly exports separate K and V cache data when using vLLM layout, which is useful for debugging and validation.

970-988: Reference checking creates proper cache sequences for vLLM layout. The code correctly creates separate `CacheSeq` objects for K and V caches when using vLLM layout, with proper page indices and buffer references. This ensures the reference implementation matches the kernel behavior.

1333-1356: NVRTC compilation test properly includes the new parameter. The test function signature correctly adds the `paged_kv_cache_layout` parameter and passes it as a compile-time define. This ensures the runtime compilation testing covers the new layout option.

1425-1450: Test iteration logic correctly filters invalid combinations. The nested loops properly test both layout options (0 and 1) while correctly skipping invalid combinations where paged KV cache is disabled but layout is non-zero. This prevents testing nonsensical configurations.
cpp/kernels/xqa/README.md (1)

19-19: Good addition of default test build flag. Adding `-DBUILD_XQA_TESTS=ON` to the default build instructions ensures users can easily build and run the unit tests without needing to discover this flag separately.
/bot run
PR_Github #12923 [ run ] triggered by Bot
PR_Github #12923 [ run ] completed with state
/bot run
PR_Github #12967 [ run ] triggered by Bot
PR_Github #12967 [ run ] completed with state
/bot run
c9077be to c4f685d Compare
Actionable comments posted: 0
♻️ Duplicate comments (2)
cpp/kernels/xqa/mha_sm90.cu (2)
1814-1831: Critical: Tensor dimension ordering issue persists. The dimension ordering in TMA load operations is still incorrect for vLLM layout. As noted in previous reviews, vLLM expects tokens before heads in the tensor layout, but the current implementation places `idxHeadGrp` before `offset` (tokens), which breaks vLLM's expected layout. The correct ordering should be:

-DimsLE<4>{partElems * idxPart, idxHeadGrp, offset, (uint32_t) pages[0]}
+DimsLE<4>{partElems * idxPart, offset, idxHeadGrp, (uint32_t) pages[0]}

Apply this fix to both the single-page (line 1815) and multi-page (lines 1826-1827) cases.
1864-1873: Critical: getHead() method not implemented for vLLM format. As previously noted, the `getHead()` method throws an assertion and traps when `PAGED_KV_CACHE_LAYOUT == 1`, but this method is called in the code (e.g., lines 1386, 1394, 1460) to write new tokens to the cache. This will cause runtime failures when using vLLM format. The implementation needs to be completed to support vLLM's memory layout, or the calling sites need to be conditionally compiled out if this functionality is not needed for vLLM format.
🧹 Nitpick comments (1)
cpp/kernels/xqa/README.md (1)
28-28: Remove trailing punctuation from heading. According to markdown style guidelines, headings should not end with punctuation. Apply this diff to fix the heading:

-## Support for VLLM Paged KV-Cache:
+## Support for VLLM Paged KV-Cache
✅ Files skipped from review due to trivial changes (1)
- cpp/kernels/xqa/CMakeLists.txt
🚧 Files skipped from review as they are similar to previous changes (6)
- cpp/kernels/xqa/test/refAttention.h
- cpp/kernels/xqa/tensorMap.cpp
- cpp/kernels/xqa/defines.h
- cpp/kernels/xqa/mhaUtils.cuh
- cpp/kernels/xqa/mha.h
- cpp/kernels/xqa/mha.cu
🧰 Additional context used
📓 Path-based instructions (2)
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}
: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo)
Prefer const or constexpr variables over #defines whenever possible, as the latter are not visible to the compiler.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in the compilation of a target must have filenames that are case-insensitive unique.
All types (including class names) are camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable uses camel case with lowercase prefix 's' as the first letter (e.g., static std::once_flag sFlag;).
Class member variables use camel case prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope and function-scope magic-number/liter...
Files:
cpp/kernels/xqa/test/test.cpp
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
cpp/kernels/xqa/test/test.cpp
cpp/kernels/xqa/mha_sm90.cu
🧠 Learnings (3)
cpp/kernels/xqa/test/test.cpp (5)
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-28T15:39:50.377Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : Do not use C library functions, whenever possible. Use brace initialization or std::fill_n() instead of memset().
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-28T15:39:50.377Z
Learning: Applies to **/*.{h,hpp} : Use a preprocessor guard in header files. The guard name must have prefix TRTLLM_ followed by the filename, all in caps, and no trailing underscore.
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-28T15:39:50.377Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-28T15:39:50.377Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : C headers should not be used directly. Use C++ standard headers like `<cstdint>` instead of <stdint.h>.
Learnt from: moraxu
PR: #6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
cpp/kernels/xqa/mha_sm90.cu (4)
Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.402Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks `is_adapter_in_cpu_cache()` and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-28T15:39:50.377Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : When nesting preprocessor directives, use indentation after the hash mark (#).
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-28T15:39:50.377Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : Use only spaces for indentation. Do not use tabs. Indent 4 spaces at a time.
Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
cpp/kernels/xqa/README.md (1)
Learnt from: moraxu
PR: #6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
🪛 markdownlint-cli2 (0.17.2)
cpp/kernels/xqa/README.md
28-28: Trailing punctuation in heading
Punctuation: ':'
(MD026, no-trailing-punctuation)
🔇 Additional comments (21)
cpp/kernels/xqa/README.md (2)
**19-19**: LGTM! Good addition of default test flag. Adding `-DBUILD_XQA_TESTS=ON` to the standard build command makes it easier for users to build and run tests by default.

**29-37**: Excellent documentation of the new VLLM feature. The documentation clearly explains:

- What the `PAGED_KV_CACHE_LAYOUT=1` option enables
- Complete build and test commands including the new flag
- Proper integration with existing build instructions

This will help users understand and utilize the new VLLM Paged KV-Cache layout feature.
cpp/kernels/xqa/test/test.cpp (13)
**220-225**: LGTM! Proper conditional memory allocation for VLLM layout. The code correctly handles the different memory allocation requirements:

- For VLLM layout (`PAGED_KV_CACHE_LAYOUT == 1`): separate K and V heads with a simplified calculation
- For the original layout: combined K+V heads with beam width consideration

The assertion `nbKHeads == nbVHeads` is appropriate for the VLLM layout constraint.

**248-254**: Correct page list sizing for different layouts. The conditional logic properly accounts for the different page list structures:

- VLLM layout: `nbPagesPerSeq * batchSize` (2D: batch × page)
- Original layout: `nbPagesPerSeq * 2 * beamWidth * batchSize` (4D: batch × beam × kv × page)

This matches the expected memory layout differences.
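The sizing rule above can be sketched as a small helper; `pageListElems` is a hypothetical name, and the runtime `vllmLayout` flag stands in for the compile-time `PAGED_KV_CACHE_LAYOUT` macro:

```cpp
#include <cassert>
#include <cstdint>

// Illustrative sketch (not the PR's actual code) of the page-list sizing rule.
// Parameter names mirror the test code.
constexpr uint32_t pageListElems(
    uint32_t nbPagesPerSeq, uint32_t beamWidth, uint32_t batchSize, bool vllmLayout)
{
    // VLLM layout: 2D pageList[batchIdx][pageIdx] -> one entry per (batch, page).
    // Original layout: 4D (batch x beam x kv x page) -> extra factor 2 for K and V.
    return vllmLayout ? nbPagesPerSeq * batchSize
                      : nbPagesPerSeq * 2 * beamWidth * batchSize;
}
```

With `nbPagesPerSeq = 4` and `batchSize = 2`, the VLLM layout needs 8 entries while the original layout (at `beamWidth = 1`) needs 16, reflecting the extra K/V dimension.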
**293-298**: Proper conditional buffer allocation for separate K/V caches. The code correctly allocates separate `cacheKHeads` and `cacheVHeads` buffers for the VLLM layout while maintaining the original `cacheHeads` buffer for the traditional layout.

**314-340**: Well-implemented page list initialization logic. The code properly handles two different page indexing schemes:

- VLLM format: 2D array `pageList[batchIdx][pageIdx]` with sequential page assignment
- Original format: linear filling with 4D indexing

The implementation is clear and matches the expected memory layouts.
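The sequential 2D assignment for the VLLM format can be sketched as follows; `makeVllmPageList` is an illustrative helper (the real test fills raw managed buffers), shown here with a flat vector standing in for the 2D array:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical sketch of the VLLM-format page-list initialization:
// pageList[batchIdx][pageIdx] stored row-major, pages assigned sequentially.
std::vector<uint32_t> makeVllmPageList(uint32_t batchSize, uint32_t nbPagesPerSeq)
{
    std::vector<uint32_t> pageList(batchSize * nbPagesPerSeq);
    for (uint32_t b = 0; b < batchSize; ++b)
    {
        for (uint32_t p = 0; p < nbPagesPerSeq; ++p)
        {
            // sequential page assignment across the whole batch
            pageList[b * nbPagesPerSeq + p] = b * nbPagesPerSeq + p;
        }
    }
    return pageList;
}
```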
**456-461**: Consistent cache initialization across both layouts. The conditional logic properly initializes cache data for both layouts:

- VLLM layout: separate initialization of K and V cache heads
- Original layout: single combined cache head initialization

The parallel initialization pattern is maintained for both cases. Also applies to: 532-537.

**568-599**: Robust cache access abstraction with proper indexing. The `cacheHeadAt` lambda function provides a clean abstraction for cache access with proper conditional logic:

- VLLM layout: separate indexing for K/V caches with simplified page calculation
- Original layout: combined cache with traditional indexing

The implementation correctly handles the different tensor layouts and returns appropriate references.
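The VLLM-side indexing behind such a helper can be sketched as an offset computation; names and the exact stride order are illustrative assumptions based on the layout described in this review (tokens outermost within a page), not the test's actual signature:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch: element offset of a (page, token, head) triple into
// a VLLM-style K or V pool, where each page is laid out as
// [tokensPerPage, nbHeads, headDim]. K and V use separate pools, so no
// K/V index appears in the arithmetic.
constexpr uint64_t vllmCacheHeadOffset(uint32_t idxPage, uint32_t idxToken,
    uint32_t idxHead, uint32_t nbHeads, uint32_t tokensPerPage, uint32_t headDim)
{
    return (uint64_t{idxPage} * tokensPerPage + idxToken) * nbHeads * headDim
        + uint64_t{idxHead} * headDim;
}
```

The same triple would index either `cacheKHeads` or `cacheVHeads` depending on an `isK` flag, which is the branch the `cacheHeadAt` lambda encapsulates.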
**624-629**: Proper memory prefetching for both layouts. The prefetching logic correctly handles both cache buffer configurations, ensuring optimal GPU memory access patterns for both the VLLM and original layouts.

**721-725**: Correct kernel launch parameter adaptation. The code properly passes separate K/V cache pointers (`cacheKHeads.get(), cacheVHeads.get()`) for the VLLM layout while maintaining the original single cache pointer for the traditional layout.

**820-825**: Appropriate data saving logic for debugging. The conditional data saving properly handles the different cache layouts, saving separate K and V data for the VLLM layout and combined KV data for the original layout.

**970-988**: Proper reference checking setup for VLLM layout. The code correctly constructs separate `CacheSeq` objects for K and V caches when using the VLLM layout, with appropriate page indices and head parameters. This ensures accurate reference checking against the expected behavior.
**1333-1356**: Enhanced NVRTC test with paged KV cache layout parameter. The test lambda now properly accepts the `paged_kv_cache_layout` parameter and passes it as a compile-time define, enabling comprehensive testing of both layout variants.

**1425-1430**: Proper test parameter validation. The code correctly filters invalid combinations where the paged KV cache is disabled but the layout parameter is non-zero. This prevents meaningless test configurations and ensures test coverage accuracy.
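The filter reduces to a one-line predicate; `isValidCombo` is a hypothetical name for the condition described above:

```cpp
#include <cassert>

// Sketch of the test-combination filter: a non-zero layout value only makes
// sense when the paged KV cache is enabled at all.
constexpr bool isValidCombo(bool usePagedKvCache, int pagedKvCacheLayout)
{
    return usePagedKvCache || pagedKvCacheLayout == 0;
}
```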
**1412-1412**: No change needed for the MLA test call. The call `test(0, 2, 576, 128, true, 1, 1, tensorrt_llm::kernels::mla_sm120_cu_content, 12, 0)` already matches the lambda signature defined on lines 1334-1335, `(input_fp16, cache_enum, head_dim, head_grp_size, use_paged_kv_cache, paged_kv_cache_layout, beam_width, source_file, compileMajor, compileMinor)`, passing `paged_kv_cache_layout=1` as the 6th parameter, which is correct.

cpp/kernels/xqa/mha_sm90.cu (6)
**660-664**: LGTM: Clean kernel signature extension for vLLM support. The conditional compilation approach maintains backward compatibility while adding the necessary separate tensor maps for K and V caches in the vLLM layout. The parameter naming is clear and consistent.

**745-756**: LGTM: Correct tensor map prefetching for both layouts. The conditional prefetching logic appropriately handles both the original single tensor map and the new separate K/V tensor maps for the vLLM layout, which is necessary for optimal TMA performance.

**1341-1354**: LGTM: Correct tensor map selection for the K cache loader. The conditional logic correctly selects `tensorMapVLLMK` for the vLLM layout and falls back to the original `tensorMap` for backward compatibility.

**1416-1429**: LGTM: Correct tensor map selection for the V cache loader. The conditional logic correctly selects `tensorMapVLLMV` for the vLLM layout and falls back to the original `tensorMap` for backward compatibility.

**1790-1794**: LGTM: Correct baseOffset calculation for the vLLM layout. The baseOffset calculation correctly removes the K/V differentiation for the vLLM layout since K and V caches are separate, using only the request-based offset.
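The baseOffset simplification can be illustrated with a small sketch. The arithmetic below is one plausible reading of the interleaved addressing described in this review (a factor of 2 plus a +1 offset for V in the original layout); the function name and parameters are assumptions, not the kernel's actual code:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch of the page-list base offset per request.
// Original layout: K and V page runs are interleaved per request, so the
// request index is scaled by 2 and V pages sit one run after K pages.
// vLLM layout: K and V live in separate pools, so isV drops out entirely.
constexpr uint64_t pageListBaseOffset(
    uint32_t idxReq, uint32_t nbPagesPerReq, bool isV, bool vllmLayout)
{
    return vllmLayout ? uint64_t{idxReq} * nbPagesPerReq
                      : (uint64_t{idxReq} * 2 + (isV ? 1 : 0)) * nbPagesPerReq;
}
```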
**3140-3259**: LGTM: Comprehensive host-side implementation for vLLM support. The host function changes correctly implement support for vLLM's separate K/V cache layout:

- The function signature appropriately adds separate cache pointers for the vLLM layout
- KVCacheList construction is properly adapted for both layouts
- Tensor map creation uses the correct cache pointers
- Kernel launch parameters are correctly passed
- Backward compatibility is maintained through conditional compilation

The implementation follows a consistent pattern and appears functionally correct.
Please address the nitpick comment from coderabbit regarding the README. |
c4f685d
to
69ddf0f
Compare
/bot run |
Actionable comments posted: 0
♻️ Duplicate comments (2)
cpp/kernels/xqa/mha_sm90.cu (2)
**1814-1831**: Critical: Incorrect tensor dimension ordering for the vLLM KV cache. The dimension ordering in TMA loads is incorrect for the vLLM layout. According to vLLM's KV cache format `[num_layers, 2, num_tokens, num_kv_heads, head_size]`, tokens should be indexed before heads, but the current code places heads (`idxHeadGrp`) before tokens (`offset`). Apply this fix to correct the dimension ordering:

```diff
 #if PAGED_KV_CACHE_LAYOUT == 1
-    tma::loadAsync(&dst, tensorMap, DimsLE<4>{partElems * idxPart, idxHeadGrp, offset, (uint32_t) pages[0]}, bar);
+    tma::loadAsync(&dst, tensorMap, DimsLE<4>{partElems * idxPart, offset, idxHeadGrp, (uint32_t) pages[0]}, bar);
 #else
     tma::loadAsync(&dst, tensorMap, DimsLE<4>{partElems * idxPart, offset, idxHeadGrp, (uint32_t) pages[0]}, bar);
 #endif
```

And similarly for the multi-page case:

```diff
 #if PAGED_KV_CACHE_LAYOUT == 1
-    tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap,
-        DimsLE<4>{partElems * idxPart, idxHeadGrp, 0, (uint32_t) pages[i]}, bar);
+    tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap,
+        DimsLE<4>{partElems * idxPart, 0, idxHeadGrp, (uint32_t) pages[i]}, bar);
 #endif
```
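Why the coordinate order matters can be shown with plain index arithmetic: swapping the head and token coordinates corresponds to swapping which dimension carries the larger stride within a page. The helper below is an illustrative sketch, not kernel code:

```cpp
#include <cassert>
#include <cstdint>

// Offset within one page, in units of head_dim-sized vectors.
// vLLM layout:     [tokensPerPage, nbHeads, headDim]  (tokens outermost)
// original layout: [nbHeads, tokensPerPage, headDim]  (heads outermost)
constexpr uint32_t inPageOffset(uint32_t idxHead, uint32_t idxToken,
    uint32_t nbHeads, uint32_t tokensPerPage, bool vllmLayout)
{
    return vllmLayout ? idxToken * nbHeads + idxHead
                      : idxHead * tokensPerPage + idxToken;
}
```

With `nbHeads = 4` and `tokensPerPage = 8`, the pair (head 1, token 2) lands at vector 9 in the vLLM layout but at vector 10 in the original layout, so feeding the coordinates in the wrong order reads the wrong elements.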
**1864-1873**: Critical: `getHead()` method not implemented for the vLLM format. The `getHead()` method throws an assertion and traps when `PAGED_KV_CACHE_LAYOUT == 1`, but this method is still called in the code (e.g., lines 1386, 1394, 1460) to write new tokens to the cache. This will cause runtime failures when using the vLLM format. The implementation needs to be completed to support vLLM's memory layout, or the calling code needs to be made conditional so this method is never reached for the vLLM format. Consider implementing the proper address calculation for vLLM's separate K/V pool architecture, or documenting why this functionality is not needed for the vLLM format.
🧹 Nitpick comments (1)
cpp/kernels/xqa/README.md (1)

**28-37**: Fix the markdown linting issue in the heading. The documentation for the VLLM Paged KV-Cache feature is clear and helpful, but the trailing colon in the heading trips markdownlint (MD026).

```diff
-## Support for VLLM Paged KV-Cache:
+## Support for VLLM Paged KV-Cache
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
cpp/kernels/xqa/CMakeLists.txt
(2 hunks)cpp/kernels/xqa/README.md
(2 hunks)cpp/kernels/xqa/defines.h
(1 hunks)cpp/kernels/xqa/mha.cu
(4 hunks)cpp/kernels/xqa/mha.h
(2 hunks)cpp/kernels/xqa/mhaUtils.cuh
(3 hunks)cpp/kernels/xqa/mha_sm90.cu
(10 hunks)cpp/kernels/xqa/tensorMap.cpp
(1 hunks)cpp/kernels/xqa/test/refAttention.h
(1 hunks)cpp/kernels/xqa/test/test.cpp
(18 hunks)
✅ Files skipped from review due to trivial changes (1)
- cpp/kernels/xqa/CMakeLists.txt
🚧 Files skipped from review as they are similar to previous changes (6)
- cpp/kernels/xqa/defines.h
- cpp/kernels/xqa/test/refAttention.h
- cpp/kernels/xqa/tensorMap.cpp
- cpp/kernels/xqa/mhaUtils.cuh
- cpp/kernels/xqa/mha.h
- cpp/kernels/xqa/mha.cu
🔇 Additional comments (25)
cpp/kernels/xqa/README.md (1)

**19-19**: LGTM! Good improvement to enable tests by default. Adding `-DBUILD_XQA_TESTS=ON` to the standard build command makes the testing workflow more accessible to developers.

cpp/kernels/xqa/test/test.cpp (17)

**220-225**: LGTM! Proper conditional memory allocation for different layouts. The conditional allocation correctly handles the separation of K and V cache pools when `PAGED_KV_CACHE_LAYOUT == 1`, with an appropriate assertion to ensure K and V head counts match.

**248-254**: LGTM! Correct page list sizing for different layouts. The conditional sizing properly accounts for the different page list structures between the original format (4D indexing) and the VLLM format (2D indexing).

**293-298**: LGTM! Proper conditional buffer allocation. The separation of K and V cache buffers when `PAGED_KV_CACHE_LAYOUT == 1` is correctly implemented with appropriate conditional compilation.

**314-321**: LGTM! Correct page list casting for different layouts. The conditional casting properly handles the different page list structures: a 2D array for the VLLM format vs. a 4D array for the original format.

**323-340**: LGTM! Proper page list initialization for different layouts. The VLLM format uses a cleaner 2D indexing pattern `pageList[batch][page]` compared to the original 4D linear filling, which aligns with the expected memory layout changes.
**456-461**: LGTM! Consistent cache generation for split K/V buffers. The parallel generation of cache data for separate K and V buffers maintains consistency with the original single-buffer approach.

**500-505**: LGTM! Proper shuffling for different page list layouts. The shuffling correctly targets the appropriate page list structure based on the layout configuration.

**532-537**: LGTM! Consistent zero-filling for split K/V buffers. The zero-filling logic properly handles both separate K/V buffers and the combined-buffer approach.

**573-598**: LGTM! Correct cache head access logic for different layouts. The `cacheHeadAt` function properly handles the different indexing schemes and returns references from the appropriate buffer (K or V) based on the layout and the `isK` parameter.

**624-629**: LGTM! Proper prefetching for split K/V buffers. The prefetching logic correctly handles both separate K/V buffers and the combined-buffer approach.

**721-725**: LGTM! Correct kernel launch parameters for different layouts. The kernel launch properly passes separate K and V cache pointers when `PAGED_KV_CACHE_LAYOUT == 1` is enabled.

**820-825**: LGTM! Proper data saving for different layouts. The data saving logic correctly handles separate K and V cache files when the layout is enabled.

**970-988**: LGTM! Correct reference checking with separate cache sequences. The reference checking properly constructs separate `CacheSeq` objects for K and V caches when `PAGED_KV_CACHE_LAYOUT == 1`, which aligns with the split-buffer architecture.
**1333-1336**: LGTM! Enhanced NVRTC test parameters for the new layout. The addition of the `paged_kv_cache_layout` parameter to the test lambda properly extends the compilation testing to cover the new layout option.

**1425-1430**: LGTM! Proper test filtering for invalid combinations. The logic correctly skips invalid combinations where the paged KV cache is disabled but the layout parameter is non-zero, preventing meaningless test configurations.

**1356-1356**: LGTM! Correct macro definition in NVRTC compilation. The addition of the `PAGED_KV_CACHE_LAYOUT` macro definition ensures the NVRTC compilation tests properly exercise the new layout code paths.

**1412-1412**: LGTM! Specific test case for MLA with the new layout. The test case properly exercises the new VLLM layout with the MLA (Multi-Layer Attention) configuration.

cpp/kernels/xqa/mha_sm90.cu (7)
**660-664**: LGTM: Correct conditional tensor map parameters for the vLLM layout. The kernel signature correctly accepts separate tensor maps for K and V caches when using the vLLM layout (`PAGED_KV_CACHE_LAYOUT == 1`), while maintaining backward compatibility with the original single-tensor-map approach.

**745-756**: LGTM: Correct tensor map prefetching for both layouts. The conditional prefetching logic correctly handles both tensor maps for the vLLM layout while maintaining the existing single-tensor-map prefetching for the original layout.

**1343-1353**: LGTM: Correct tensor map assignment for the K cache loader. The K cache loader correctly receives `tensorMapVLLMK` for the vLLM layout and falls back to the original `tensorMap` for the standard layout.

**1418-1428**: LGTM: Correct tensor map assignment for the V cache loader. The V cache loader correctly receives `tensorMapVLLMV` for the vLLM layout and falls back to the original `tensorMap` for the standard layout.

**1790-1794**: LGTM: Correct baseOffset calculation for the vLLM layout. The baseOffset calculation correctly removes the factor of 2 and the K/V offset for the vLLM layout since K and V caches are stored separately, eliminating the need for interleaved addressing.

**3140-3144**: LGTM: Correct host function signature for the vLLM layout. The host function signature correctly accepts separate K and V cache pointers for the vLLM layout, reflecting the architectural difference from the original interleaved cache design.

**3221-3263**: LGTM: Correct tensor map creation and kernel launch for both layouts. The host-side logic correctly creates separate tensor maps for the vLLM layout and constructs the appropriate KVCacheList. The conditional kernel launch ensures the correct tensor map arguments are passed based on the cache layout.
PR_Github #13672 [ run ] triggered by Bot |
Signed-off-by: Ransiki Zhang <[email protected]>
Signed-off-by: Ransiki Zhang <[email protected]>
a027481
to
85d5e25
Compare
PR_Github #13672 [ run ] completed with state |
/bot run |
PR_Github #14152 [ run ] triggered by Bot |
PR_Github #14152 [ run ] completed with state |
Signed-off-by: Ransiki Zhang <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
Signed-off-by: Ransiki Zhang <[email protected]>
[None][feat] Add vLLM KV Pool support for XQA kernel
Description
This PR adds support for vLLM's separated KV Pool architecture to the XQA kernel. The main changes include:
Key Modifications:
- Changed the paged KV cache layout from `[bs, heads, tokens, head_dim]` to `[bs, tokens, heads, head_dim]`; this modification applies to both mha.cu and mha_sm90.cu
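The layout change above amounts to swapping the head and token strides. A minimal sketch of the two index mappings, with hypothetical names (`H` = number of heads, `T` = number of tokens, `D` = head_dim):

```cpp
#include <cassert>
#include <cstdint>

// Linear index of element (b, h, t, d) in the original layout
// [bs, heads, tokens, head_dim]: heads stride over tokens*head_dim.
constexpr uint64_t oldIdx(uint64_t b, uint64_t h, uint64_t t, uint64_t d,
    uint64_t H, uint64_t T, uint64_t D)
{
    return ((b * H + h) * T + t) * D + d;
}

// Linear index in the new layout [bs, tokens, heads, head_dim]:
// tokens now stride over heads*head_dim.
constexpr uint64_t newIdx(uint64_t b, uint64_t h, uint64_t t, uint64_t d,
    uint64_t H, uint64_t T, uint64_t D)
{
    return ((b * T + t) * H + h) * D + d;
}
```

The same logical element lands at different linear offsets under the two layouts, which is why the kernel's TMA coordinates and page-list addressing both had to change together.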
Test Coverage
Summary by CodeRabbit
New Features
Documentation
Tests