@@ -501,7 +501,7 @@ nvinfer1::Dims makeShapeFromCacheState(kv_cache::CacheState const& cacheState)
501
501
template <typename T, int subWarpSize, int vecSizeByte>
502
502
__global__ void splitKVCacheForMLAKernel (T const ** __restrict__ inputBlocks, T** __restrict__ outputCaches,
503
503
int tokensPerBlock, int numLayers, int headNum, int dimsPerHead, int inputBlockNum, int DomainPPSize,
504
- int DomainTPSize, int layerNumDomainPP, int kvFactor)
504
+ int DomainTPSize, int DomainCPSize, int layerNumDomainPP, int kvFactor)
505
505
{
506
506
int const subWarpId = threadIdx .x / subWarpSize;
507
507
int const laneId = threadIdx .x % subWarpSize;
@@ -931,6 +931,11 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
931
931
auto targetRankInfo = targetIRanks (destCacheState, selfCacheState, selfIdx);
932
932
TLLM_CHECK (targetRankInfo.mIRanks .size ()
933
933
== (static_cast <size_t >(targetRankInfo.mDomainPPSize * targetRankInfo.mDomainTPSize * targetRankInfo.mDomainCPSize )));
934
+ TLLM_LOG_INFO (" [splitKVCache] targetRankInfo.mIRanks.size(): %zu" , targetRankInfo.mIRanks .size ()); // size() is compared against a size_t in the check above, so print it with %zu rather than %d
935
+ for (auto rank : targetRankInfo.mIRanks )
936
+ {
937
+ TLLM_LOG_INFO (" [splitKVCache] target rank: %d, " , rank);
938
+ }
934
939
auto outputCacheNum = targetRankInfo.mIRanks .size ();
935
940
if (selfCacheState.getAttentionConfig ().mAttentionType == CacheState::AttentionType::kMLA )
936
941
{
@@ -954,6 +959,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
954
959
{
955
960
auto cacheBlockSize = blocks.front ()->getSize ();
956
961
auto cacheDataType = blocks.front ()->getDataType ();
962
+ TLLM_LOG_DEBUG (" [splitKVCache] cacheBlockSize: %zu, cacheDataType: %d" , cacheBlockSize, cacheDataType);
957
963
windowSizes.push_back (window);
958
964
blockNumInwindow.push_back (blocks.size ());
959
965
TLLM_LOG_DEBUG (" window: %d, blockNum: %d blockshape:[%d,%d]" , window, blocks.size (),
@@ -997,7 +1003,6 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
997
1003
998
1004
for (auto layerNum : layersInWindow)
999
1005
{
1000
-
1001
1006
TLLM_CHECK_WITH_INFO (
1002
1007
layerNum % targetRankInfo.mDomainPPSize == 0 , " layerNum in Window must be divisible by domainPPSize" );
1003
1008
}
@@ -1043,6 +1048,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
1043
1048
int const dimsPerHead = selfModelConfig.mSizePerHead ;
1044
1049
int const DomainPPSize = targetRankInfo.mDomainPPSize ;
1045
1050
int const DomainTPSize = targetRankInfo.mDomainTPSize ;
1051
+ int const DomainCPSize = targetRankInfo.mDomainCPSize ;
1046
1052
int const layerNumDomainPP = numLayers / DomainPPSize;
1047
1053
int const headNumDomainTP
1048
1054
= headNum / (DomainTPSize / targetRankInfo.mPeerDupHeadFactor ); // TODO: duplicate head factor
@@ -1051,9 +1057,9 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
1051
1057
constexpr int mlaSubWarpSize = 16 ;
1052
1058
1053
1059
TLLM_LOG_DEBUG (
1054
- " splitKVCache - numLayers: %d, headNum: %d, domainPPSize: %d, domainTPSize: %d, "
1060
+ " splitKVCache - numLayers: %d, headNum: %d, domainPPSize: %d, domainTPSize: %d, domainCPSize: %d, "
1055
1061
" layersPerDomainPP: %d, headsPerDomainTP: %d" ,
1056
- numLayers, headNum, DomainPPSize, DomainTPSize, layerNumDomainPP, headNumDomainTP);
1062
+ numLayers, headNum, DomainPPSize, DomainTPSize, DomainCPSize, layerNumDomainPP, headNumDomainTP);
1057
1063
1058
1064
int const remainder = sizePerHead * sizeof (T) % 16 ;
1059
1065
switch (remainder)
@@ -1064,7 +1070,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
1064
1070
{
1065
1071
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 16 ><<<gridDim , blockDimx, 0 , bufferManager.getStream().get()>>> (
1066
1072
inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead,
1067
- inputBlockNumSum, DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor);
1073
+ inputBlockNumSum, DomainPPSize, DomainTPSize, DomainCPSize, layerNumDomainPP, kvFactor);
1068
1074
}
1069
1075
else if (isWindow)
1070
1076
{
@@ -1088,7 +1094,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
1088
1094
{
1089
1095
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 8 ><<<gridDim , blockDimx, 0 , bufferManager.getStream().get()>>> (
1090
1096
inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead,
1091
- inputBlockNumSum, DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor);
1097
+ inputBlockNumSum, DomainPPSize, DomainTPSize, DomainCPSize, layerNumDomainPP, kvFactor);
1092
1098
}
1093
1099
else if (isWindow)
1094
1100
{
@@ -1116,7 +1122,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
1116
1122
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 4 >
1117
1123
<<<gridDim , blockDimx, 0 , bufferManager.getStream().get()>>> (inputBlockPtrsDev, outputCachePtrsDev,
1118
1124
tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize,
1119
- layerNumDomainPP, kvFactor);
1125
+ DomainCPSize, layerNumDomainPP, kvFactor);
1120
1126
}
1121
1127
else if (isWindow)
1122
1128
{
@@ -1149,7 +1155,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
1149
1155
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 2 >
1150
1156
<<<gridDim , blockDimx, 0 , bufferManager.getStream().get()>>> (inputBlockPtrsDev, outputCachePtrsDev,
1151
1157
tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize,
1152
- layerNumDomainPP, kvFactor);
1158
+ DomainCPSize, layerNumDomainPP, kvFactor);
1153
1159
}
1154
1160
else if (isWindow)
1155
1161
{
@@ -1178,7 +1184,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
1178
1184
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 1 >
1179
1185
<<<gridDim , blockDimx, 0 , bufferManager.getStream().get()>>> (inputBlockPtrsDev, outputCachePtrsDev,
1180
1186
tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize,
1181
- layerNumDomainPP, kvFactor);
1187
+ DomainCPSize, layerNumDomainPP, kvFactor);
1182
1188
}
1183
1189
else if (isWindow)
1184
1190
{
0 commit comments