diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index 2edfd5f77a3..d95ca1b412b 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -75,7 +75,6 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques bool CacheFormatter::needSendCache( CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx) { - // int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism; auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx); if (targetInfo.mDupHeadFactor <= 1) { @@ -90,8 +89,9 @@ bool CacheFormatter::needSendCache( = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize; selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup; } + int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0; - return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0; + return (destDPRank % targetInfo.mDupHeadFactor) == (selfTpRankInDpGroup % targetInfo.mDupHeadFactor); } void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig, @@ -128,11 +128,12 @@ std::vector CacheFormatter::pickRecvConnections( return ret; } TLLM_CHECK(numConnections == targetInfo.mIRanks.size()); + int selfDPRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0; std::vector ret; for (int i = 0; i < targetInfo.mDomainTPSize; i++) { - if (i % targetInfo.mPeerDupHeadFactor == 0) + if ((i % targetInfo.mPeerDupHeadFactor) == (selfDPRank % targetInfo.mPeerDupHeadFactor)) { for (int j = 0; j < targetInfo.mDomainPPSize; j++) { diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index 810edd6f451..824a31129f8 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -45,10 +45,12 @@ std::vector MLACacheFormatter::pickRecvConnections( auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx); TLLM_CHECK(numConnections == targetInfo.mIRanks.size()); std::vector ret; - // targetInfo , mRanks [tpranks, dpranks] + // targetInfo , mRanks [tpranks, ppranks] + int dpRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0; + for (int i = 0; i < targetInfo.mDomainPPSize; i++) { - ret.push_back(i); + ret.push_back(i + (dpRank % (targetInfo.mDomainTPSize)) * targetInfo.mDomainPPSize); } return ret; } @@ -58,19 +60,24 @@ bool MLACacheFormatter::needSendCache( { int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism; + int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP + ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize + : destConfig.getParallelConfig().mTensorParallelism; + int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0; + if (selfConfig.getParallelConfig().mEnableAttentionDP) { int selfTPNumInDPGroup = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize; - int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP - ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize - : destConfig.getParallelConfig().mTensorParallelism; + int selfTPrankINDPGroup = selfTpRank % selfTPNumInDPGroup; if (selfTPNumInDPGroup <= destTPNumInDPGroup) { return true; } - return selfTPrankINDPGroup % (selfTPNumInDPGroup / destTPNumInDPGroup) == 0; + + int dupHeadFactor = selfTPNumInDPGroup / destTPNumInDPGroup; + return selfTPrankINDPGroup % dupHeadFactor == destDPRank; } int destTPNum = destConfig.getParallelConfig().mEnableAttentionDP @@ -81,7 +88,8 @@ bool MLACacheFormatter::needSendCache( { return true; } - return selfTpRank % (selfTPNum / destTPNum) == 0; + int dupHeadFactor = selfTPNum / destTPNum; + return selfTpRank % dupHeadFactor == destDPRank; } void MLACacheFormatter::format(TransferSession& session) diff --git a/cpp/tests/batch_manager/cacheTransceiverTest.cpp b/cpp/tests/batch_manager/cacheTransceiverTest.cpp index 99c40f810f6..af916359d0d 100644 --- a/cpp/tests/batch_manager/cacheTransceiverTest.cpp +++ b/cpp/tests/batch_manager/cacheTransceiverTest.cpp @@ -1457,12 +1457,15 @@ TEST(targetTest, CacheStateNODP) verifyContext( /*contextRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true); + verifyContext( /*contextRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false); + verifyContext( /*contextRank*/ 2, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true); verifyContext( /*contextRank*/ 3, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false); + verifyContext( /*contextRank*/ 4, /*expectRanks*/ {2}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true); verifyContext( @@ -1474,7 +1477,6 @@ TEST(targetTest, CacheStateNODP) contextTP = 2; genTP = 4; - verifyContext( /*contextRank*/ 0, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectNeedSend*/ true); verifyContext(/*contextRank*/ 1, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, @@ -1564,13 +1566,13 @@ TEST(targetTest, CacheStateContextDP) /*expectNeedSend*/ true); verifyContext( /*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, - /*expectNeedSend*/ true); + /*expectNeedSend*/ false); verifyContext( /*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false); verifyContext( /*contextRank*/ 1, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, - /*expectNeedSend*/ false); + /*expectNeedSend*/ true); verifyContext( /*contextRank*/ 2, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);