Skip to content

Commit cfed6bf

Browse files
Ransiki authored and lancelly committed
[None][feat] Add vLLM KV Pool support for XQA kernel (NVIDIA#6013)
Signed-off-by: Ransiki Zhang <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent 48bdf55 commit cfed6bf

File tree

10 files changed

+301
-34
lines changed

10 files changed

+301
-34
lines changed

cpp/kernels/xqa/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ set(CMAKE_CUDA_ARCHITECTURES 89-real 90a-real)
2323
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
2424

2525
option(BUILD_XQA_TESTS "Build XQA tests" OFF)
26+
set(PAGED_KV_CACHE_LAYOUT
27+
"0"
28+
CACHE STRING "Paged KV cache format (0 for XQA Original, 1 for VLLM)")
29+
add_definitions(-DPAGED_KV_CACHE_LAYOUT=${PAGED_KV_CACHE_LAYOUT})
2630

2731
# todo: remove include_directories link_directories and link libs like
2832
# CUDA::cuda_driver CUDA::cudart CUDA::nvrtc
@@ -37,7 +41,7 @@ set(CMAKE_CXX_FLAGS
3741
"${CMAKE_CXX_FLAGS} -march=haswell -Wfatal-errors -Wreturn-type -Wall -Wextra -Wno-unknown-pragmas"
3842
)
3943
set(CMAKE_CUDA_FLAGS
40-
"${CMAKE_CUDA_FLAGS} -allow-unsupported-compiler --expt-relaxed-constexpr -t 0 -res-usage"
44+
"${CMAKE_CUDA_FLAGS} -allow-unsupported-compiler --expt-relaxed-constexpr -t 0 -res-usage -DPAGED_KV_CACHE_LAYOUT=${PAGED_KV_CACHE_LAYOUT}"
4145
)
4246
set(CUDA_PTXAS_FLAGS "-warn-lmem-usage -warn-double-usage -warn-spills"
4347
)# -Werror -v

cpp/kernels/xqa/README.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ You need to install libgtest-dev and libeigen3-dev before building. To build, us
1616

1717
- ```mkdir build```
1818
- ```cd build```
19-
- ```cmake .. -DCMAKE_BUILD_TYPE=Release```
19+
- ```cmake .. -DCMAKE_BUILD_TYPE=Release -DBUILD_XQA_TESTS=ON```
2020
- ```cmake --build . -j```
2121

2222
To run unit tests, run `./unitTests`. There are a few runtime options that can be controlled with environment variables:
@@ -25,6 +25,16 @@ To run unit tests, run `./unitTests`. There are a few runtime options that can b
2525
- XQA_USE_QGMMA: On Hopper, we try to use TMA+QGMMA kernel (mha_sm90.cu) by default if possible. To force using mha.cu, set this to 0.
2626
- XQA_NB_SUB_SEQ: The number of CUDA thread blocks used to handle one K/V head. We have reasonable default but if you want to change it manually, use this variable.
2727

28+
## Support for VLLM Paged KV-Cache
29+
When `PAGED_KV_CACHE_LAYOUT=1` is enabled, XQA supports vLLM-style KV cache input: separate K and V cache pools, each with a sequence-first `(batch, seq_len, head, head_elem)` memory layout.
30+
To build and test with this feature enabled, run the following commands:
31+
32+
- ```mkdir build```
33+
- ```cd build```
34+
- ```cmake .. -DCMAKE_BUILD_TYPE=Release -DBUILD_XQA_TESTS=ON -DPAGED_KV_CACHE_LAYOUT=1```
35+
- ```cmake --build . -j```
36+
- ```./unitTests```
37+
2838
## Generation cubins used in TensorRT-LLM
2939

3040
Run `gen_cubin.py` in the repo workspace.

cpp/kernels/xqa/defines.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,15 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
9797
#define USE_PAGED_KV_CACHE (TOKENS_PER_PAGE > 0)
9898
#endif
9999

100+
// Paged KV Cache Format
101+
// 0 - XQA Original
102+
// 1 - separate K and V cache pools, each with layout (batch, seq_len, head, head_elem) for VLLM/SGLang
103+
#ifdef USE_PAGED_KV_CACHE
104+
#ifndef PAGED_KV_CACHE_LAYOUT
105+
#define PAGED_KV_CACHE_LAYOUT 0
106+
#endif
107+
#endif
108+
100109
// don't modify
101110
#define USE_BEAM_SEARCH (BEAM_WIDTH > 1)
102111

cpp/kernels/xqa/mha.cu

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,17 +1671,33 @@ CUBIN_EXPORT __global__
16711671
uint32_t const dstHeadOffset = 0;
16721672
uint32_t const seqOffset = ctaTile.x * seqIter + warpTile.x * warpIdx.x;
16731673
#if USE_PAGED_KV_CACHE
1674+
#if PAGED_KV_CACHE_LAYOUT == 1
1675+
uint32_t const idxHeadBeg = (seqOffset % tokensPerPage) * nbKHeads + idxHeadGrp;
1676+
1677+
#else
16741678
uint32_t const idxHeadBeg = tokensPerPage * idxHeadGrp + seqOffset % tokensPerPage;
1679+
#endif
16751680
#if BEAM_WIDTH == 1
1681+
#if PAGED_KV_CACHE_LAYOUT == 1
1682+
HeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerWarpTile> const src{
1683+
cacheList.kCacheVLLM, pageIdx, nbKHeads, idxHeadBeg};
1684+
#else
16761685
HeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerWarpTile> const src{
16771686
cacheList.pool, pageIdx, nbKHeads, idxHeadBeg};
1687+
#endif
16781688
#else
1679-
IndexedHeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerWarpTile> const src{
1689+
IndexedHeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerWarpTile> const src
1690+
{
16801691
/*indices=*/smem.gemm0CacheIndir[warpIdx.x].data,
1681-
/*pool=*/cacheList.pool,
1682-
/*pageIndices=*/smem.kCachePages[warpIdx.x].data,
1683-
/*nbKHeads=*/nbKHeads,
1684-
/*offset=*/idxHeadBeg};
1692+
#if PAGED_KV_CACHE_LAYOUT == 1
1693+
/*pool=*/cacheList.kCacheVLLM,
1694+
#else
1695+
/*pool=*/cacheList.pool,
1696+
#endif
1697+
/*pageIndices=*/smem.kCachePages[warpIdx.x].data,
1698+
/*nbKHeads=*/nbKHeads,
1699+
/*offset=*/idxHeadBeg
1700+
};
16851701
#endif
16861702
#else
16871703
uint32_t const idxHeadBeg = cacheKSeqBaseOffset + seqOffset;
@@ -1990,17 +2006,33 @@ CUBIN_EXPORT __global__
19902006
uint32_t const seqOffset = ctaTile.x * seqIter + warpTile.x * nbXTilesPerXIter * xIter
19912007
+ cacheVTileSeqStride * vIter + cacheVTileSeqLen * warpGrpIdx;
19922008
#if USE_PAGED_KV_CACHE
2009+
#if PAGED_KV_CACHE_LAYOUT == 1
2010+
uint32_t const idxHeadBeg = (seqOffset % tokensPerPage) * nbKHeads + idxHeadGrp;
2011+
2012+
#else
19932013
uint32_t const idxHeadBeg = tokensPerPage * idxHeadGrp + seqOffset % tokensPerPage;
2014+
#endif
19942015
#if BEAM_WIDTH == 1
2016+
#if PAGED_KV_CACHE_LAYOUT == 1
2017+
HeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerVTile> const src{
2018+
cacheList.vCacheVLLM, pageIdx, nbKHeads, idxHeadBeg};
2019+
#else
19952020
HeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerVTile> const src{
19962021
cacheList.pool, pageIdx, nbKHeads, idxHeadBeg};
2022+
#endif
19972023
#else
1998-
IndexedHeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerVTile> const src{
2024+
IndexedHeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerVTile> const src
2025+
{
19992026
/*indices=*/smem.gemm1CacheIndir[grpLoadV ? warpGrpIdx : warpIdx.x].data,
2000-
/*pool=*/cacheList.pool,
2001-
/*pageIndices=*/smem.vCachePages[grpLoadV ? warpGrpIdx : warpIdx.x].data,
2002-
/*nbKHeads=*/nbKHeads,
2003-
/*offset=*/idxHeadBeg};
2027+
#if PAGED_KV_CACHE_LAYOUT == 1
2028+
/*pool=*/cacheList.vCacheVLLM,
2029+
#else
2030+
/*pool=*/cacheList.pool,
2031+
#endif
2032+
/*pageIndices=*/smem.vCachePages[grpLoadV ? warpGrpIdx : warpIdx.x].data,
2033+
/*nbKHeads=*/nbKHeads,
2034+
/*offset=*/idxHeadBeg
2035+
};
20042036
#endif
20052037
#else
20062038
uint32_t const idxHeadBeg = cacheVSeqBaseOffset + seqOffset;
@@ -2636,7 +2668,11 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
26362668
InputHead const* q,
26372669
#endif
26382670
#if USE_PAGED_KV_CACHE
2671+
#if PAGED_KV_CACHE_LAYOUT == 1
2672+
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
2673+
#else
26392674
GMemCacheHead* pool, // global pool of pages
2675+
#endif
26402676
KVCachePageIndex const*
26412677
kvCachePageList, // device pointer. shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
26422678
#else
@@ -2702,7 +2738,11 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
27022738
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
27032739
#if USE_PAGED_KV_CACHE
27042740
uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
2741+
#if PAGED_KV_CACHE_LAYOUT == 1
2742+
KVCacheList<true> const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq};
2743+
#else
27052744
KVCacheList<true> const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq};
2745+
#endif
27062746
cudaLaunchKernelEx(&launchCfg, kernel_mha,
27072747
#if SPEC_DEC
27082748
qSeqLen, nbKHeads, headGrpSize, qCuSeqLens,

cpp/kernels/xqa/mha.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
102102
InputHead const* q,
103103
#endif
104104
#if USE_PAGED_KV_CACHE
105+
#if PAGED_KV_CACHE_LAYOUT == 1
106+
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
107+
#else
105108
GMemCacheHead* pool, // global pool of pages
109+
#endif
106110
KVCachePageIndex const*
107111
kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
108112
#else
@@ -137,7 +141,11 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
137141
InputHead const* q,
138142
#endif
139143
#if USE_PAGED_KV_CACHE
144+
#if PAGED_KV_CACHE_LAYOUT == 1
145+
GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
146+
#else
140147
GMemCacheHead* pool, // global pool of pages
148+
#endif
141149
KVCachePageIndex const*
142150
kvCachePageList, // device pointer. shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
143151
#else

cpp/kernels/xqa/mhaUtils.cuh

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,17 @@ struct HeadPtr
8080

8181
__device__ inline Head* operator+(uint32_t i) const
8282
{
83+
#if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE
84+
auto const pageIdx = pageIndices[nbPages == 1 ? 0U : i / tokensPerPage];
85+
return (pageIdx & (1U << 31))
86+
? nullptr
87+
: pool + (tokensPerPage * nbKHeads * pageIdx + offset + (i % tokensPerPage) * nbKHeads);
88+
#else
8389
assert(nbPages == 1 || offset % tokensPerPage == 0);
8490
auto const pageIdx = pageIndices[nbPages == 1 ? 0U : i / tokensPerPage];
8591
return (pageIdx & (1U << 31)) ? nullptr
8692
: pool + (tokensPerPage * nbKHeads * pageIdx + offset + i % tokensPerPage);
93+
#endif
8794
}
8895
};
8996

@@ -239,7 +246,12 @@ struct KVCacheList;
239246
template <>
240247
struct KVCacheList<true>
241248
{
249+
#if PAGED_KV_CACHE_LAYOUT == 1
250+
GMemCacheHead* kCacheVLLM;
251+
GMemCacheHead* vCacheVLLM;
252+
#else
242253
GMemKVCacheHead* pool;
254+
#endif
243255
KVCachePageIndex const* kvCachePageList; // shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
244256
SeqLenDataType const* seqLenList; // shape: [batchSize][beamWidth] (for compatibility)
245257
uint32_t maxNbPagesPerSeq;
@@ -289,9 +301,13 @@ __device__ inline Vec<KVCachePageIndex, nbLoadedPages> getPage(KVCacheList<true>
289301
for (uint32_t i = 0; i < nbLoadedPages; i++)
290302
{
291303
uint32_t const idxPage = idxPageBeg + i;
304+
#if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE
305+
ret[i] = (idxPage < nbPages ? cacheList.kvCachePageList[maxNbPagesPerSeq * idxReq + idxPage] : kBAD_PAGE_INDEX);
306+
#else
292307
ret[i] = (idxPage < nbPages ? cacheList.kvCachePageList[beamWidth * 2 * maxNbPagesPerSeq * idxReq
293308
+ 2 * maxNbPagesPerSeq * idxBeam + maxNbPagesPerSeq * (isK ? 0U : 1U) + idxPage]
294309
: kBAD_PAGE_INDEX);
310+
#endif
295311
}
296312
return ret;
297313
}

0 commit comments

Comments
 (0)