Skip to content

Commit 9210aa1

Browse files
Merge remote-tracking branch 'public/main' into ytong/releasable_memory
2 parents ee6676a + 48768fd commit 9210aa1

File tree

188 files changed

+5620
-1443
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

188 files changed

+5620
-1443
lines changed

.coderabbit.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ reviews:
2828
related_prs: true
2929
suggested_labels: true
3030
suggested_reviewers: true
31-
auto_assign_reviewers: true
3231
poem: false
3332
auto_review:
3433
drafts: true

.github/pull_request_template.md

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,21 @@
33
<!--
44
Please write the PR title by following this template:
55
6-
[JIRA ticket/NVBugs ID/GitHub issue][fix/feat/doc/infra/...] \<summary of this PR\>
6+
**[JIRA ticket/NVBugs ID/GitHub issue/None][type] Summary**
77
8-
For example, assume I have a PR to support a new feature about cache manager for JIRA ticket TRTLLM-1000, it would be like:
8+
Valid ticket formats:
9+
- JIRA ticket: [TRTLLM-1234] or [FOOBAR-123] for other FOOBAR project
10+
- NVBugs ID: [https://nvbugs/1234567]
11+
- GitHub issue: [#1234]
12+
- No ticket: [None]
913
10-
[TRTLLM-1000][feat] Support a new feature about cache manager
14+
Valid types (lowercase): [fix], [feat], [doc], [infra], [chore], etc.
1115
12-
Or I have a PR to fix a Llama3 accuracy issue:
13-
14-
[https://nvbugs/1234567][fix] Fix Llama3 accuracy issue
16+
Examples:
17+
- [TRTLLM-1234][feat] Add new feature
18+
- [https://nvbugs/1234567][fix] Fix some bugs
19+
- [#1234][doc] Update documentation
20+
- [None][chore] Minor clean-up
1521
-->
1622

1723
## Description

.github/workflows/pr-check.yml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
name: PR Checks
17+
18+
on:
19+
pull_request:
20+
types: [opened, edited, synchronize, reopened]
21+
22+
jobs:
23+
check-pr-title:
24+
name: Check PR Title Format
25+
runs-on: ubuntu-latest
26+
steps:
27+
- name: Validate PR Title Format
28+
id: check-pr-title
29+
uses: agenthunt/[email protected]
30+
continue-on-error: true
31+
with:
32+
pr-title-regex: "^(\\[(None|[A-Z0-9]+-[0-9]+|#[0-9]+|https:\\/\\/nvbugs\\/[0-9]+)\\])(\\[[a-z0-9]+\\]) (([^ ].*)?[^ ])$"
33+
pr-body-regex: ""
34+
35+
- name: PR Title Format Guide
36+
if: steps.check-pr-title.outcome == 'failure'
37+
run: |
38+
echo "::error::PR title format check failed."
39+
echo "Expected PR title format:"
40+
echo " [JIRA ticket/NVBugs ID/GitHub issue/None][type] Summary"
41+
echo ""
42+
echo "Valid ticket formats:"
43+
echo " - JIRA ticket: [TRTLLM-1234] or [FOOBAR-123] for other FOOBAR project"
44+
echo " - NVBugs ID: [https://nvbugs/1234567]"
45+
echo " - GitHub issue: [#1234]"
46+
echo " - No ticket: [None]"
47+
echo ""
48+
echo "Valid types (lowercase): [fix], [feat], [doc], [infra], [chore], etc."
49+
echo ""
50+
echo "Examples:"
51+
echo " - [TRTLLM-1234][feat] Add new feature"
52+
echo " - [https://nvbugs/1234567][fix] Fix some bugs"
53+
echo " - [#1234][doc] Update documentation"
54+
echo " - [None][chore] Minor clean-up"
55+
exit 1

3rdparty/xgrammar

Submodule xgrammar updated 173 files

cpp/include/tensorrt_llm/runtime/gptDecoder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
#include "tensorrt_llm/runtime/bufferManager.h"
2121
#include "tensorrt_llm/runtime/decodingInput.h"
2222
#include "tensorrt_llm/runtime/decodingOutput.h"
23-
#include "tensorrt_llm/runtime/request.h"
2423
#include "tensorrt_llm/runtime/samplingConfig.h"
2524

2625
#include <NvInferRuntime.h>

cpp/tensorrt_llm/batch_manager/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ set(SRCS
5959

6060
file(GLOB_RECURSE XGRAMMAR_SRCS "${3RDPARTY_DIR}/xgrammar/cpp/*.cc")
6161
list(FILTER XGRAMMAR_SRCS EXCLUDE REGEX
62-
"${3RDPARTY_DIR}/xgrammar/cpp/pybind/.*\\.cc")
62+
"${3RDPARTY_DIR}/xgrammar/cpp/nanobind/.*\\.cc")
6363
list(APPEND SRCS ${XGRAMMAR_SRCS})
6464

6565
if(NOT WIN32)

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ void CacheFormatter::format(TransferSession& session)
166166
auto const numPools = blockManager.getNumPools();
167167
// TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...
168168

169+
auto lastTokenTime = llmRequest.getPerfMetrics().timingMetrics.lastTokenTime;
170+
bool recordDelay = lastTokenTime != std::chrono::steady_clock::time_point();
171+
169172
bool layerWise = common::getEnvDisaggLayerwise() && numPools == 1;
170173
if (layerWise)
171174
{
@@ -350,9 +353,14 @@ void CacheFormatter::format(TransferSession& session)
350353
}
351354

352355
auto endTime = std::chrono::steady_clock::now();
356+
double delay = 0.0;
357+
if (recordDelay)
358+
{
359+
delay = std::chrono::duration<double, std::milli>(startTime - lastTokenTime).count();
360+
}
353361
double cacheTransferTime
354362
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
355-
kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, cacheTransferTime, size);
363+
kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size);
356364
};
357365

358366
if (connections.size() > 1)
@@ -408,16 +416,19 @@ void CacheFormatter::unformat(TransferSession& session)
408416
{
409417
NVTX3_SCOPED_RANGE(CacheFormatter_unformat);
410418
auto const& llmRequest = session.getLlmRequest();
419+
auto const ctxReqId = llmRequest.getContextPhaseParams().value().getReqId();
411420
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
412-
"Start receiving KV cache for request ID: %ld, context request ID: %ld.", llmRequest.mRequestId,
413-
llmRequest.getContextPhaseParams().value().getReqId());
421+
"Start receiving KV cache for request ID: %ld, context request ID: %ld.", llmRequest.mRequestId, ctxReqId);
414422
auto const& connections = session.getConnections();
415423
auto const& selfConfig = session.getSelfState().getCacheState().value();
416424
auto const& destConfig = session.getOtherState().getCacheState().value();
417425
auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx();
418426
auto& bufferManager = session.getBufferManager();
419427
auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest);
420428

429+
auto arrivalTime = llmRequest.getPerfMetrics().timingMetrics.arrivalTime;
430+
bool recordDelay = arrivalTime != std::chrono::steady_clock::time_point();
431+
421432
auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);
422433

423434
TLLM_LOG_DEBUG("pickUpConnections size: %d connections size: %d", pickUpConnections.size(), connections.size());
@@ -546,7 +557,7 @@ void CacheFormatter::unformat(TransferSession& session)
546557
}
547558
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
548559
"End receiving KV cache for request ID: %ld, context request ID: %ld.", llmRequest.mRequestId,
549-
llmRequest.getContextPhaseParams().value().getReqId());
560+
ctxReqId);
550561
return;
551562
}
552563
// legacyPath: context executor rank only send data to one gen executor rank. it sends multiple cache
@@ -634,6 +645,8 @@ void CacheFormatter::unformat(TransferSession& session)
634645
TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
635646
TLLM_CHECK(pickUpConnections.size() > processIdx);
636647
TLLM_CHECK(recvSplitCaches.size() > processIdx);
648+
auto startTime = std::chrono::steady_clock::now();
649+
size_t size = 0;
637650
if (legacyPath)
638651
{
639652
size_t idx = processIdx * blockNum;
@@ -645,6 +658,7 @@ void CacheFormatter::unformat(TransferSession& session)
645658
size_t recvBufferIdx = blockIdx * pickUpConnections.size() + commIdx;
646659
llmRequest.updateKvCacheSize((*recvSplitCaches[recvBufferIdx]).getSizeInBytes());
647660
auto& buffer = recvSplitCaches.at(recvBufferIdx);
661+
size += buffer->getSizeInBytes();
648662
session.recv(pickUpConnections[processIdx], buffer->data(), buffer->getSizeInBytes());
649663
idx++;
650664
}
@@ -655,6 +669,7 @@ void CacheFormatter::unformat(TransferSession& session)
655669
{
656670
llmRequest.updateKvCacheSize((*recvSplitCaches.at(processIdx)).getSizeInBytes());
657671
auto& buffer = recvSplitCaches[processIdx];
672+
size = buffer->getSizeInBytes();
658673
session.recv(pickUpConnections[processIdx], buffer->data(), buffer->getSizeInBytes());
659674
}
660675
else if (bufferCoverTargetNum > 0)
@@ -663,6 +678,7 @@ void CacheFormatter::unformat(TransferSession& session)
663678
+ remainNoCoverTargetNum; // caches.at(recvBufferIdx) is allocated by cudaMalloc
664679
llmRequest.updateKvCacheSize((*recvSplitCaches.at(recvBufferIdx)).getSizeInBytes());
665680
auto& buffer = recvSplitCaches.at(recvBufferIdx);
681+
size = buffer->getSizeInBytes();
666682
session.recv(pickUpConnections[processIdx], buffer->data(), buffer->getSizeInBytes());
667683
bufferManager.copy(*recvSplitCaches.at(recvBufferIdx), *recvSplitCaches[processIdx]);
668684
bufferManager.getStream().synchronize();
@@ -679,6 +695,7 @@ void CacheFormatter::unformat(TransferSession& session)
679695
auto recvSlice = runtime::ITensor::slice(preAllocRecvBuffer, 0, recvSize);
680696
auto copySlice = runtime::ITensor::slice(
681697
recvSplitCaches[processIdx], targetBufferSize - remainRecvSize, recvSize);
698+
size += recvSlice->getSizeInBytes();
682699
llmRequest.updateKvCacheSize((*recvSlice).getSizeInBytes());
683700
session.recv(pickUpConnections[processIdx], recvSlice->data(), recvSlice->getSizeInBytes());
684701
bufferManager.copy(*recvSlice, *copySlice);
@@ -687,6 +704,15 @@ void CacheFormatter::unformat(TransferSession& session)
687704
}
688705
}
689706
}
707+
auto endTime = std::chrono::steady_clock::now();
708+
double delay = 0.0;
709+
if (recordDelay)
710+
{
711+
delay = std::chrono::duration<double, std::milli>(startTime - arrivalTime).count();
712+
}
713+
double cacheTransferTime
714+
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
715+
kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size);
690716
};
691717
if (pickUpConnections.size() > 1)
692718
{
@@ -814,6 +840,8 @@ void CacheFormatter::unformat(TransferSession& session)
814840
if (selfConfig.getModelConfig().mNbKvHeadsPerLayer.size() != destConfig.getModelConfig().mNbKvHeadsPerLayer.size())
815841
{
816842
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support same number of layers");
843+
TLLM_LOG_WARNING("self: %zu dest %zu", selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(),
844+
destConfig.getModelConfig().mNbKvHeadsPerLayer.size());
817845
return false;
818846
}
819847
int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size();

cpp/tensorrt_llm/batch_manager/cacheFormatter.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ class BaseCacheFormatter
7676

7777
/// @brief Destructor.
7878
virtual ~BaseCacheFormatter() = default;
79+
80+
// TODO: better way for context/generation tagging
81+
void markAsSender(bool isSender)
82+
{
83+
kvCacheMeasureHelper.markAsSender(isSender);
84+
}
85+
86+
protected:
87+
KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath()};
7988
};
8089

8190
// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
@@ -115,7 +124,6 @@ class CacheFormatter final : public BaseCacheFormatter
115124
private:
116125
BaseKVCacheManager* mCacheManager;
117126
CacheTransBufferManager* mCacheTransBufferManager;
118-
KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath()};
119127
};
120128

121129
std::unique_ptr<BaseCacheFormatter> createCacheFormatter(

cpp/tensorrt_llm/batch_manager/dataTransceiver.h

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,24 @@ class DataRequester
269269
class KvCacheMeasureHelper
270270
{
271271
public:
272+
struct Measure
273+
{
274+
double delay; // from last token (ctx) or arrival time (gen), in ms
275+
double duration; // in ms
276+
double bandwidth; // in Gbps
277+
};
278+
272279
KvCacheMeasureHelper(std::string output_path)
273280
: mOutputPath(std::move(output_path))
274281
{
275282
}
276283

277-
void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double duration, size_t size)
284+
void markAsSender(bool isSender)
285+
{
286+
mIsSender = isSender;
287+
}
288+
289+
void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double delay, double duration, size_t size)
278290
{
279291
auto bandwidth = size * 8 / (duration / 1000) / 1e9;
280292
if (mOutputPath.empty())
@@ -283,15 +295,17 @@ class KvCacheMeasureHelper
283295
}
284296

285297
std::lock_guard<std::mutex> lock(mMutex);
286-
mRequestKVCacheTranfserMeasure[requestId].emplace_back(duration, bandwidth);
298+
mRequestKVCacheTranfserMeasure[requestId].emplace_back(Measure{delay, duration, bandwidth});
287299
}
288300

289301
~KvCacheMeasureHelper()
290302
{
291303
if (!mRequestKVCacheTranfserMeasure.empty() && !mOutputPath.empty())
292304
{
305+
TLLM_CHECK(mIsSender.has_value());
293306
auto rank = mpi::MpiComm::world().getRank();
294-
std::string outFilePath = mOutputPath + "rank_" + std::to_string(rank) + ".txt";
307+
std::string outFilePath
308+
= mOutputPath + "rank_" + std::to_string(rank) + "_" + (mIsSender.value() ? "send" : "recv") + ".csv";
295309
std::ofstream outFile(outFilePath);
296310

297311
TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath);
@@ -301,17 +315,17 @@ class KvCacheMeasureHelper
301315
outFile << "RequestID";
302316
for (size_t i = 0; i < numTransferMeasure; i++)
303317
{
304-
outFile << ",TimeDuration,Bandwidth";
318+
outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
305319
}
306320
outFile << '\n';
307321

308322
for (auto const& [requestID, measures] : mRequestKVCacheTranfserMeasure)
309323
{
310324
outFile << requestID;
311325

312-
for (auto const& [time, bandwidth] : measures)
326+
for (auto const& measure : measures)
313327
{
314-
outFile << "," << time << "," << bandwidth;
328+
outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
315329
}
316330
outFile << '\n';
317331
}
@@ -321,9 +335,10 @@ class KvCacheMeasureHelper
321335
}
322336

323337
private:
324-
std::map<LlmRequest::RequestIdType, std::vector<std::pair<double, double>>> mRequestKVCacheTranfserMeasure;
338+
std::map<LlmRequest::RequestIdType, std::vector<Measure>> mRequestKVCacheTranfserMeasure;
325339
std::string mOutputPath;
326340
std::mutex mMutex;
341+
std::optional<bool> mIsSender;
327342
};
328343

329344
} // namespace tensorrt_llm::batch_manager

cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
3939
{
4040
TLLM_CHECK(mManager);
4141
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
42+
mFormatter->markAsSender(true);
4243
}
4344

4445
[[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo()
@@ -136,6 +137,7 @@ DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manage
136137
TLLM_CHECK(mManager);
137138
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
138139
TLLM_CHECK(mFormatter);
140+
mFormatter->markAsSender(false);
139141
}
140142

141143
TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)

0 commit comments

Comments
 (0)