
Commit e91f63b

save some more changes - domainCPSize to be put to use in split kernel
1 parent: 2663f0e

File tree

3 files changed: +23 -17 lines changed

  cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
  cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu
  cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 2 additions & 2 deletions
@@ -162,8 +162,8 @@ void MLACacheFormatter::format(TransferSession& session)
     auto& outputSplitCaches = std::get<0>(result);
     auto& bufferCoverTargetNum = std::get<1>(result);
     auto& onlyUseDynamicBuffer = std::get<2>(result);
-    auto* agentConnnecion = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[0]);
-    if (agentConnnecion != nullptr)
+    auto* agentConnnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[0]);
+    if (agentConnnection != nullptr)
     {
         TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == pPDomainSize * cPDomainSize, "Agent need all buffer pre-allocated");
         TLLM_CHECK(onlyUseDynamicBuffer == false);
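
For context on the checks in this hunk: the agent connection path requires every split target to be backed by a pre-allocated buffer and forbids the dynamic-buffer fallback, which is what the two TLLM_CHECKs encode. A minimal restatement of that precondition, as a sketch (the helper is illustrative and not code from this file; pPDomainSize and cPDomainSize are the names used above):

    // Sketch: an agent connection needs one pre-allocated buffer per split target,
    // i.e. per peer rank in the pipeline-parallel x context-parallel domain.
    inline bool agentBuffersReady(int bufferCoverTargetNum, int pPDomainSize, int cPDomainSize)
    {
        return bufferCoverTargetNum == pPDomainSize * cPDomainSize; // e.g. 2 x 2 -> 4 buffers
    }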

cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu

Lines changed: 15 additions & 9 deletions
@@ -501,7 +501,7 @@ nvinfer1::Dims makeShapeFromCacheState(kv_cache::CacheState const& cacheState)
 template <typename T, int subWarpSize, int vecSizeByte>
 __global__ void splitKVCacheForMLAKernel(T const** __restrict__ inputBlocks, T** __restrict__ outputCaches,
     int tokensPerBlock, int numLayers, int headNum, int dimsPerHead, int inputBlockNum, int DomainPPSize,
-    int DomainTPSize, int layerNumDomainPP, int kvFactor)
+    int DomainTPSize, int DomainCPSize, int layerNumDomainPP, int kvFactor)
 {
     int const subWarpId = threadIdx.x / subWarpSize;
     int const laneId = threadIdx.x % subWarpSize;
@@ -931,6 +931,11 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
     auto targetRankInfo = targetIRanks(destCacheState, selfCacheState, selfIdx);
     TLLM_CHECK(targetRankInfo.mIRanks.size()
         == (static_cast<size_t>(targetRankInfo.mDomainPPSize * targetRankInfo.mDomainTPSize * targetRankInfo.mDomainCPSize)));
+    TLLM_LOG_INFO("[splitKVCache] targetRankInfo.mIRanks.size(): %d", targetRankInfo.mIRanks.size());
+    for (auto rank : targetRankInfo.mIRanks)
+    {
+        TLLM_LOG_INFO("[splitKVCache] target rank: %d, ", rank);
+    }
     auto outputCacheNum = targetRankInfo.mIRanks.size();
     if (selfCacheState.getAttentionConfig().mAttentionType == CacheState::AttentionType::kMLA)
     {
@@ -954,6 +959,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
     {
         auto cacheBlockSize = blocks.front()->getSize();
         auto cacheDataType = blocks.front()->getDataType();
+        TLLM_LOG_DEBUG("[splitKVCache] cacheBlockSize: %zu, cacheDataType: %d", cacheBlockSize, cacheDataType);
         windowSizes.push_back(window);
         blockNumInwindow.push_back(blocks.size());
         TLLM_LOG_DEBUG("window: %d, blockNum: %d blockshape:[%d,%d]", window, blocks.size(),
@@ -997,7 +1003,6 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
 
     for (auto layerNum : layersInWindow)
     {
-
         TLLM_CHECK_WITH_INFO(
             layerNum % targetRankInfo.mDomainPPSize == 0, "layerNum in Window must be divisible by domainPPSize");
     }
@@ -1043,6 +1048,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
     int const dimsPerHead = selfModelConfig.mSizePerHead;
     int const DomainPPSize = targetRankInfo.mDomainPPSize;
     int const DomainTPSize = targetRankInfo.mDomainTPSize;
+    int const DomainCPSize = targetRankInfo.mDomainCPSize;
     int const layerNumDomainPP = numLayers / DomainPPSize;
     int const headNumDomainTP
         = headNum / (DomainTPSize / targetRankInfo.mPeerDupHeadFactor); // TODO: duplicate head factor
@@ -1051,9 +1057,9 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
     constexpr int mlaSubWarpSize = 16;
 
     TLLM_LOG_DEBUG(
-        "splitKVCache - numLayers: %d, headNum: %d, domainPPSize: %d, domainTPSize: %d, "
+        "splitKVCache - numLayers: %d, headNum: %d, domainPPSize: %d, domainTPSize: %d, domainCPSize: %d, "
         "layersPerDomainPP: %d, headsPerDomainTP: %d",
-        numLayers, headNum, DomainPPSize, DomainTPSize, layerNumDomainPP, headNumDomainTP);
+        numLayers, headNum, DomainPPSize, DomainTPSize, DomainCPSize, layerNumDomainPP, headNumDomainTP);
 
     int const remainder = sizePerHead * sizeof(T) % 16;
     switch (remainder)
@@ -1064,7 +1070,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
         {
             splitKVCacheForMLAKernel<T, mlaSubWarpSize, 16><<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(
                 inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead,
-                inputBlockNumSum, DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor);
+                inputBlockNumSum, DomainPPSize, DomainTPSize, DomainCPSize, layerNumDomainPP, kvFactor);
         }
         else if (isWindow)
         {
@@ -1088,7 +1094,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
         {
             splitKVCacheForMLAKernel<T, mlaSubWarpSize, 8><<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(
                 inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead,
-                inputBlockNumSum, DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor);
+                inputBlockNumSum, DomainPPSize, DomainTPSize, DomainCPSize, layerNumDomainPP, kvFactor);
         }
         else if (isWindow)
         {
@@ -1116,7 +1122,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
             splitKVCacheForMLAKernel<T, mlaSubWarpSize, 4>
                 <<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(inputBlockPtrsDev, outputCachePtrsDev,
                     tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize,
-                    layerNumDomainPP, kvFactor);
+                    DomainCPSize, layerNumDomainPP, kvFactor);
         }
         else if (isWindow)
         {
@@ -1149,7 +1155,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
             splitKVCacheForMLAKernel<T, mlaSubWarpSize, 2>
                 <<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(inputBlockPtrsDev, outputCachePtrsDev,
                     tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize,
-                    layerNumDomainPP, kvFactor);
+                    DomainCPSize, layerNumDomainPP, kvFactor);
         }
        else if (isWindow)
        {
@@ -1178,7 +1184,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
             splitKVCacheForMLAKernel<T, mlaSubWarpSize, 1>
                 <<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(inputBlockPtrsDev, outputCachePtrsDev,
                     tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize,
-                    layerNumDomainPP, kvFactor);
+                    DomainCPSize, layerNumDomainPP, kvFactor);
         }
         else if (isWindow)
         {
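
Per the commit message, DomainCPSize is only threaded through the kernel signature and launch sites here; it is not yet used inside splitKVCacheForMLAKernel. The host-side TLLM_CHECK above already requires one output cache per target rank in the PP x TP x CP domain, so a later change would have the kernel's flat output-cache index range over all three domains. One possible shape of that indexing, as a sketch only (ppIdx, tpIdx and cpIdx are illustrative names, not identifiers from this file, and the real layout may order the domains differently):

    // Sketch: decompose a flat output-cache index over the three split domains.
    // There is one output cache per target rank, so
    // outputCacheNum == DomainPPSize * DomainTPSize * DomainCPSize.
    __host__ __device__ inline int flatOutputCacheIndex(
        int ppIdx, int tpIdx, int cpIdx, int DomainTPSize, int DomainCPSize)
    {
        return (ppIdx * DomainTPSize + tpIdx) * DomainCPSize + cpIdx;
    }

With DomainCPSize == 1 this reduces to the existing PP x TP indexing, so threading the extra parameter through first keeps current behavior unchanged.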

cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp

Lines changed: 6 additions & 6 deletions
@@ -523,9 +523,9 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
 #if ENABLE_MULTI_DEVICE
     tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE);
 
-    if (tensorrt_llm::mpi::MpiComm::world().getSize() != 8)
+    if (tensorrt_llm::mpi::MpiComm::world().getSize() != 4)
     {
-        GTEST_SKIP() << "mpirun with procs=8 is required to run this test.";
+        GTEST_SKIP() << "mpirun with procs=4 is required to run this test.";
     }
     int worldSize = tensorrt_llm::mpi::MpiComm::world().getSize();
     int worldRank = tensorrt_llm::mpi::MpiComm::world().getRank();
@@ -998,7 +998,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
         // Debug print with rank information for MPI debugging (KEY values)
         if (TARGET_RANK == -1 || tensorrt_llm::mpi::MpiComm::world().getRank() == TARGET_RANK)
         {
-            TLLM_LOG_DEBUG(tensorrt_llm::mpi::MpiComm::world().getRank(),
+            TLLM_LOG_INFO(tensorrt_llm::mpi::MpiComm::world().getRank(),
                 "[RANK %d] [fillBlockData::key] blockId=%d, layer=%d->%d, head=%d->%d, token=%d->%d, hidden=%d, "
                 "keyIdx=%zu, value=%s, dataType=%d",
                 tensorrt_llm::mpi::MpiComm::world().getRank(),
@@ -1024,7 +1024,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
         // Debug print with rank information for MPI debugging (VALUE values)
         if (TARGET_RANK == -1 || tensorrt_llm::mpi::MpiComm::world().getRank() == TARGET_RANK)
         {
-            TLLM_LOG_DEBUG(tensorrt_llm::mpi::MpiComm::world().getRank(),
+            TLLM_LOG_INFO(tensorrt_llm::mpi::MpiComm::world().getRank(),
                 "[RANK %d] [fillBlockData::value] blockId=%d, layer=%d->%d, head=%d->%d, token=%d->%d, hidden=%d, "
                 "valueIdx=%zu, value=%s, dataType=%d",
                 tensorrt_llm::mpi::MpiComm::world().getRank(),
@@ -1097,7 +1097,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
         // Debug print with rank information for MPI debugging (KEY values)
         if (TARGET_RANK == -1 || tensorrt_llm::mpi::MpiComm::world().getRank() == TARGET_RANK)
         {
-            TLLM_LOG_DEBUG(tensorrt_llm::mpi::MpiComm::world().getRank(),
+            TLLM_LOG_INFO(tensorrt_llm::mpi::MpiComm::world().getRank(),
                 "[RANK %d] [verifyBlockData::key] blockId=%d, layer=%d->%d, head=%d->%d, token=%d->%d, hidden=%d, "
                 "keyIdx=%zu, value=%s, dataType=%d",
                 tensorrt_llm::mpi::MpiComm::world().getRank(),
@@ -1122,7 +1122,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
         // Debug print with rank information for MPI debugging (VALUE values)
         if (TARGET_RANK == -1 || tensorrt_llm::mpi::MpiComm::world().getRank() == TARGET_RANK)
        {
-            TLLM_LOG_DEBUG(tensorrt_llm::mpi::MpiComm::world().getRank(),
+            TLLM_LOG_INFO(tensorrt_llm::mpi::MpiComm::world().getRank(),
                 "[RANK %d] [verifyBlockData::value] blockId=%d, layer=%d->%d, head=%d->%d, token=%d->%d, hidden=%d, "
                 "valueIdx=%zu, value=%s, dataType=%d",
                 tensorrt_llm::mpi::MpiComm::world().getRank(),
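
Two changes run through this test file: the required MPI world size drops from 8 to 4 ranks (so the fixture now expects a 4-process mpirun launch), and the per-rank block-data prints in fillBlockData/verifyBlockData are promoted from TLLM_LOG_DEBUG to TLLM_LOG_INFO so they appear without raising the log level. The gating pattern those hunks share, reduced to a sketch (it assumes the test's existing TARGET_RANK constant and the TLLM logging and MPI helpers already in scope):

    // Sketch: rank-gated logging so only the selected MPI rank emits the
    // per-block diagnostics; TARGET_RANK == -1 means every rank prints.
    int const rank = tensorrt_llm::mpi::MpiComm::world().getRank();
    if (TARGET_RANK == -1 || rank == TARGET_RANK)
    {
        TLLM_LOG_INFO("[RANK %d] block diagnostics ...", rank);
    }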
