Skip to content

Commit c9077be

Browse files
committed
feat: Add VLLM paged KV cache format support to XQA kernels
1 parent ee3cbb0 commit c9077be

File tree

10 files changed

+301
-34
lines changed

10 files changed

+301
-34
lines changed

cpp/kernels/xqa/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ set(CMAKE_CUDA_ARCHITECTURES 89-real 90a-real)
2323
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
2424

2525
option(BUILD_XQA_TESTS "Build XQA tests" OFF)
26+
set(PAGED_KV_CACHE_LAYOUT
27+
"0"
28+
CACHE STRING "Paged KV cache format (0 for XQA Original, 1 for VLLM)")
29+
add_definitions(-DPAGED_KV_CACHE_LAYOUT=${PAGED_KV_CACHE_LAYOUT})
2630

2731
# todo: remove include_directories link_directories and link libs like
2832
# CUDA::cuda_driver CUDA::cudart CUDA::nvrtc
@@ -37,7 +41,7 @@ set(CMAKE_CXX_FLAGS
3741
"${CMAKE_CXX_FLAGS} -march=haswell -Wfatal-errors -Wreturn-type -Wall -Wextra -Wno-unknown-pragmas"
3842
)
3943
set(CMAKE_CUDA_FLAGS
40-
"${CMAKE_CUDA_FLAGS} -allow-unsupported-compiler --expt-relaxed-constexpr -t 0 -res-usage"
44+
"${CMAKE_CUDA_FLAGS} -allow-unsupported-compiler --expt-relaxed-constexpr -t 0 -res-usage -DPAGED_KV_CACHE_LAYOUT=${PAGED_KV_CACHE_LAYOUT}"
4145
)
4246
set(CUDA_PTXAS_FLAGS "-warn-lmem-usage -warn-double-usage -warn-spills"
4347
)# -Werror -v

cpp/kernels/xqa/README.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ You need to install libgtest-dev and libeigen3-dev before building. To build, us
1616

1717
- ```mkdir build```
1818
- ```cd build```
19-
- ```cmake .. -DCMAKE_BUILD_TYPE=Release```
19+
- ```cmake .. -DCMAKE_BUILD_TYPE=Release -DBUILD_XQA_TESTS=ON```
2020
- ```cmake --build . -j```
2121

2222
To run unit tests, run `./unitTests`. There are a few runtime options that can be controlled with environment variables:
@@ -25,6 +25,16 @@ To run unit tests, run `./unitTests`. There are a few runtime options that can b
2525
- XQA_USE_QGMMA: On Hopper, we try to use TMA+QGMMA kernel (mha_sm90.cu) by default if possible. To force using mha.cu, set this to 0.
2626
- XQA_NB_SUB_SEQ: The number of CUDA thread blocks used to handle one K/V head. We have reasonable default but if you want to change it manually, use this variable.
2727

28+
## Support for VLLM Paged KV-Cache
29+
When `PAGED_KV_CACHE_LAYOUT=1` is enabled, XQA accepts VLLM-style paged KV-cache input: separate K and V cache pools, each using a sequence-first memory layout (batch, seq_len, head, head_elem).
30+
To build and test with this feature enabled, run the following commands:
31+
32+
- ```mkdir build```
33+
- ```cd build```
34+
- ```cmake .. -DCMAKE_BUILD_TYPE=Release -DBUILD_XQA_TESTS=ON -DPAGED_KV_CACHE_LAYOUT=1```
35+
- ```cmake --build . -j```
36+
- ```./unitTests```
37+
2838
## Generation cubins used in TensorRT-LLM
2939

3040
Run `gen_cubin.py` in the repo workspace.

cpp/kernels/xqa/defines.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,15 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
9797
#define USE_PAGED_KV_CACHE (TOKENS_PER_PAGE > 0)
9898
#endif
9999

100+
// Paged KV Cache Format
101+
// 0 - XQA Original
102+
// 1 - separate K and V cache pools, each with layout (batch, seq_len, head, head_elem) for VLLM/SGLang
103+
#ifdef USE_PAGED_KV_CACHE
104+
#ifndef PAGED_KV_CACHE_LAYOUT
105+
#define PAGED_KV_CACHE_LAYOUT 0
106+
#endif
107+
#endif
108+
100109
// don't modify
101110
#define USE_BEAM_SEARCH (BEAM_WIDTH > 1)
102111

cpp/kernels/xqa/mha.cu

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,17 +1672,33 @@ CUBIN_EXPORT __global__
16721672
uint32_t const dstHeadOffset = 0;
16731673
uint32_t const seqOffset = ctaTile.x * seqIter + warpTile.x * warpIdx.x;
16741674
#if USE_PAGED_KV_CACHE
1675+
#if PAGED_KV_CACHE_LAYOUT == 1
1676+
uint32_t const idxHeadBeg = (seqOffset % tokensPerPage) * nbKHeads + idxHeadGrp;
1677+
1678+
#else
16751679
uint32_t const idxHeadBeg = tokensPerPage * idxHeadGrp + seqOffset % tokensPerPage;
1680+
#endif
16761681
#if BEAM_WIDTH == 1
1682+
#if PAGED_KV_CACHE_LAYOUT == 1
1683+
HeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerWarpTile> const src{
1684+
cacheList.kCacheVLLM, pageIdx, nbKHeads, idxHeadBeg};
1685+
#else
16771686
HeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerWarpTile> const src{
16781687
cacheList.pool, pageIdx, nbKHeads, idxHeadBeg};
1688+
#endif
16791689
#else
1680-
IndexedHeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerWarpTile> const src{
1690+
IndexedHeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerWarpTile> const src
1691+
{
16811692
/*indices=*/smem.gemm0CacheIndir[warpIdx.x].data,
1682-
/*pool=*/cacheList.pool,
1683-
/*pageIndices=*/smem.kCachePages[warpIdx.x].data,
1684-
/*nbKHeads=*/nbKHeads,
1685-
/*offset=*/idxHeadBeg};
1693+
#if PAGED_KV_CACHE_LAYOUT == 1
1694+
/*pool=*/cacheList.kCacheVLLM,
1695+
#else
1696+
/*pool=*/cacheList.pool,
1697+
#endif
1698+
/*pageIndices=*/smem.kCachePages[warpIdx.x].data,
1699+
/*nbKHeads=*/nbKHeads,
1700+
/*offset=*/idxHeadBeg
1701+
};
16861702
#endif
16871703
#else
16881704
uint32_t const idxHeadBeg = cacheKSeqBaseOffset + seqOffset;
@@ -1991,17 +2007,33 @@ CUBIN_EXPORT __global__
19912007
uint32_t const seqOffset = ctaTile.x * seqIter + warpTile.x * nbXTilesPerXIter * xIter
19922008
+ cacheVTileSeqStride * vIter + cacheVTileSeqLen * warpGrpIdx;
19932009
#if USE_PAGED_KV_CACHE
2010+
#if PAGED_KV_CACHE_LAYOUT == 1
2011+
uint32_t const idxHeadBeg = (seqOffset % tokensPerPage) * nbKHeads + idxHeadGrp;
2012+
2013+
#else
19942014
uint32_t const idxHeadBeg = tokensPerPage * idxHeadGrp + seqOffset % tokensPerPage;
2015+
#endif
19952016
#if BEAM_WIDTH == 1
2017+
#if PAGED_KV_CACHE_LAYOUT == 1
2018+
HeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerVTile> const src{
2019+
cacheList.vCacheVLLM, pageIdx, nbKHeads, idxHeadBeg};
2020+
#else
19962021
HeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerVTile> const src{
19972022
cacheList.pool, pageIdx, nbKHeads, idxHeadBeg};
2023+
#endif
19982024
#else
1999-
IndexedHeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerVTile> const src{
2025+
IndexedHeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerVTile> const src
2026+
{
20002027
/*indices=*/smem.gemm1CacheIndir[grpLoadV ? warpGrpIdx : warpIdx.x].data,
2001-
/*pool=*/cacheList.pool,
2002-
/*pageIndices=*/smem.vCachePages[grpLoadV ? warpGrpIdx : warpIdx.x].data,
2003-
/*nbKHeads=*/nbKHeads,
2004-
/*offset=*/idxHeadBeg};
2028+
#if PAGED_KV_CACHE_LAYOUT == 1
2029+
/*pool=*/cacheList.vCacheVLLM,
2030+
#else
2031+
/*pool=*/cacheList.pool,
2032+
#endif
2033+
/*pageIndices=*/smem.vCachePages[grpLoadV ? warpGrpIdx : warpIdx.x].data,
2034+
/*nbKHeads=*/nbKHeads,
2035+
/*offset=*/idxHeadBeg
2036+
};
20052037
#endif
20062038
#else
20072039
uint32_t const idxHeadBeg = cacheVSeqBaseOffset + seqOffset;
@@ -2637,7 +2669,11 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
26372669
InputHead const* q,
26382670
#endif
26392671
#if USE_PAGED_KV_CACHE
2672+
#if PAGED_KV_CACHE_LAYOUT == 1
2673+
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
2674+
#else
26402675
GMemCacheHead* pool, // global pool of pages
2676+
#endif
26412677
KVCachePageIndex const*
26422678
kvCachePageList, // device pointer. shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
26432679
#else
@@ -2703,7 +2739,11 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
27032739
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
27042740
#if USE_PAGED_KV_CACHE
27052741
uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
2742+
#if PAGED_KV_CACHE_LAYOUT == 1
2743+
KVCacheList<true> const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq};
2744+
#else
27062745
KVCacheList<true> const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq};
2746+
#endif
27072747
cudaLaunchKernelEx(&launchCfg, kernel_mha,
27082748
#if SPEC_DEC
27092749
qSeqLen, nbKHeads, headGrpSize, qCuSeqLens,

cpp/kernels/xqa/mha.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
102102
InputHead const* q,
103103
#endif
104104
#if USE_PAGED_KV_CACHE
105+
#if PAGED_KV_CACHE_LAYOUT == 1
106+
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
107+
#else
105108
GMemCacheHead* pool, // global pool of pages
109+
#endif
106110
KVCachePageIndex const*
107111
kvCachePageList, // device pointer. shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
108112
#else
@@ -137,7 +141,11 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
137141
InputHead const* q,
138142
#endif
139143
#if USE_PAGED_KV_CACHE
144+
#if PAGED_KV_CACHE_LAYOUT == 1
145+
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
146+
#else
140147
GMemCacheHead* pool, // global pool of pages
148+
#endif
141149
KVCachePageIndex const*
142150
kvCachePageList, // device pointer. shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
143151
#else

cpp/kernels/xqa/mhaUtils.cuh

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,17 @@ struct HeadPtr
8080

8181
__device__ inline Head* operator+(uint32_t i) const
8282
{
83+
#if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE
84+
auto const pageIdx = pageIndices[nbPages == 1 ? 0U : i / tokensPerPage];
85+
return (pageIdx & (1U << 31))
86+
? nullptr
87+
: pool + (tokensPerPage * nbKHeads * pageIdx + offset + (i % tokensPerPage) * nbKHeads);
88+
#else
8389
assert(nbPages == 1 || offset % tokensPerPage == 0);
8490
auto const pageIdx = pageIndices[nbPages == 1 ? 0U : i / tokensPerPage];
8591
return (pageIdx & (1U << 31)) ? nullptr
8692
: pool + (tokensPerPage * nbKHeads * pageIdx + offset + i % tokensPerPage);
93+
#endif
8794
}
8895
};
8996

@@ -239,7 +246,12 @@ struct KVCacheList;
239246
template <>
240247
struct KVCacheList<true>
241248
{
249+
#if PAGED_KV_CACHE_LAYOUT == 1
250+
GMemCacheHead* kCacheVLLM;
251+
GMemCacheHead* vCacheVLLM;
252+
#else
242253
GMemKVCacheHead* pool;
254+
#endif
243255
KVCachePageIndex const* kvCachePageList; // shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
244256
SeqLenDataType const* seqLenList; // shape: [batchSize][beamWidth] (for compatibility)
245257
uint32_t maxNbPagesPerSeq;
@@ -289,9 +301,13 @@ __device__ inline Vec<KVCachePageIndex, nbLoadedPages> getPage(KVCacheList<true>
289301
for (uint32_t i = 0; i < nbLoadedPages; i++)
290302
{
291303
uint32_t const idxPage = idxPageBeg + i;
304+
#if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE
305+
ret[i] = (idxPage < nbPages ? cacheList.kvCachePageList[maxNbPagesPerSeq * idxReq + idxPage] : kBAD_PAGE_INDEX);
306+
#else
292307
ret[i] = (idxPage < nbPages ? cacheList.kvCachePageList[beamWidth * 2 * maxNbPagesPerSeq * idxReq
293308
+ 2 * maxNbPagesPerSeq * idxBeam + maxNbPagesPerSeq * (isK ? 0U : 1U) + idxPage]
294309
: kBAD_PAGE_INDEX);
310+
#endif
295311
}
296312
return ret;
297313
}

0 commit comments

Comments
 (0)