Skip to content

Commit 8adbd75

Browse files
committed
Adding serialization of KV cache events
Signed-off-by: Patrice Castonguay <[email protected]>
1 parent 3a56c1e commit 8adbd75

File tree

14 files changed

+347
-86
lines changed

14 files changed

+347
-86
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,8 @@ using BlockPtr = std::shared_ptr<KVCacheBlock>;
3636
class KVCacheEventManager
3737
{
3838
public:
39-
explicit KVCacheEventManager(size_t maxKVEventEntries, bool enableAttentionDp = false, std::optional<SizeType32>,
40-
attentionDpRank = std::nullopt, std::optional<SizeType32> attentionDpSize = std::nullopt,
41-
std::optional<SizeTyp32> ppSize);
39+
explicit KVCacheEventManager(size_t maxKVEventEntries, std::optional<SizeType32> attentionDpRank = std::nullopt,
40+
std::optional<SizeType32> attentionDpSize = std::nullopt, SizeType32 attentionDpEventsGatherPeriodMs = 5);
4241

4342
~KVCacheEventManager();
4443
KVCacheEventManager(KVCacheEventManager& other) = delete;
@@ -63,6 +62,9 @@ class KVCacheEventManager
6362
// Worker thread which adds events to mEvents.
6463
void worker();
6564

65+
// Thread which exchanges events if attention DP is enabled
66+
void exchangeAttentionDpThread();
67+
6668
private:
6769
// Add an event to mEventQueue
6870
void enqueueEvent(executor::KVCacheEvent&& event);
@@ -71,6 +73,8 @@ class KVCacheEventManager
7173
bool mRun;
7274
/// @brief Worker thread
7375
std::thread mWorkerThread;
76+
/// @brief Exchange thread for attention DP events
77+
std::thread mExchangeAttentionDpThread;
7478

7579
/// @brief The deque of events
7680
std::deque<executor::KVCacheEvent> mEvents;
@@ -93,11 +97,14 @@ class KVCacheEventManager
9397
size_t mMaxSize;
9498
/// @brief An auto-incrementing event id counter
9599
size_t mEventId;
96-
/// @bried Whether this model uses attention DP
97-
/// This is used to determine if we need to gather KV cache events
98-
bool mEnableAttentionDp{false};
99-
std::optional<mMaxSize> mAttentionDpRank;
100-
std::optional<mMaxSize> mAttentionDpSize;
100+
101+
/// @brief Attention DP rank and size
102+
/// If set, we will exchange KV cache events and accumulate on rank 0
103+
std::optional<SizeType32> mAttentionDpRank;
104+
std::optional<SizeType32> mAttentionDpSize;
105+
106+
/// @brief The period in milliseconds to gather attention DP events across ranks
107+
SizeType32 mAttentionDpEventsGatherPeriodMs;
101108
};
102109

103110
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,7 @@ class KvCacheConfig
10011001
std::optional<FloatType> const& crossKvCacheFraction = std::nullopt,
10021002
std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0,
10031003
bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false,
1004+
SizeType32 attentionDpEventsGatherPeriodMs = 5,
10041005
std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt);
10051006

10061007
[[nodiscard]] bool getEnableBlockReuse() const;
@@ -1016,6 +1017,7 @@ class KvCacheConfig
10161017
[[nodiscard]] std::optional<RetentionPriority> getSecondaryOffloadMinPriority() const;
10171018
[[nodiscard]] size_t getEventBufferMaxSize() const;
10181019
[[nodiscard]] bool getUseUvm() const;
1020+
[[nodiscard]] SizeType32 getAttentionDpEventsGatherPeriodMs() const;
10191021

10201022
void setEnableBlockReuse(bool enableBlockReuse);
10211023
void setEnablePartialReuse(bool enablePartialReuse);
@@ -1030,6 +1032,7 @@ class KvCacheConfig
10301032
void setSecondaryOffloadMinPriority(std::optional<RetentionPriority> secondaryOffloadMinPriority);
10311033
void setEventBufferMaxSize(size_t eventBufferMaxSize);
10321034
void setUseUvm(bool useUvm);
1035+
void setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventPollPeriodMs);
10331036

10341037
void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults);
10351038

@@ -1085,6 +1088,9 @@ class KvCacheConfig
10851088

10861089
/// @brief Whether to use UVM for the KV cache.
10871090
bool mUseUvm;
1091+
1092+
/// @brief The period in milliseconds to gather attention DP events across ranks
1093+
SizeType32 mAttentionDpEventsGatherPeriodMs;
10881094
};
10891095

10901096
/// @brief Configuration class for the runtime perf knobs
@@ -1732,13 +1738,8 @@ using KVCacheEventData = std::variant<KVCacheCreatedData, KVCacheStoredData, KVC
17321738

17331739
struct KVCacheEvent
17341740
{
1735-
17361741
KVCacheEvent(IdType eventId, KVCacheEventData data, SizeType32 windowSize,
1737-
std::optional<SizeType32> attentionDpRank = std::nullopt)
1738-
: eventId{eventId}
1739-
, data{std::move(data)}
1740-
, windowSize{windowSize}
1741-
, attentionDpRank{attentionDpRank} {};
1742+
std::optional<SizeType32> attentionDpRank = std::nullopt);
17421743

17431744
/// @brief The unique id of this event
17441745
IdType eventId;

cpp/include/tensorrt_llm/executor/serialization.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,15 +332,23 @@ class Serialization
332332
[[nodiscard]] static KVCacheRemovedData deserializeKVCacheRemovedData(std::istream& is);
333333

334334
// KVCacheEventDiff
335-
[[nodiscard]] static size_t serializedSize(KVCacheEventDiff<SizeType32> const& data);
336-
static void serialize(KVCacheEventDiff<SizeType32> const& data, std::ostream& os);
337-
[[nodiscard]] static KVCacheEventDiff<SizeType32> deserializeKVCacheEventDiff(std::istream& is);
335+
template <typename T>
336+
[[nodiscard]] static size_t serializedSize(KVCacheEventDiff<T> const& data);
337+
template <typename T>
338+
static void serialize(KVCacheEventDiff<T> const& data, std::ostream& os);
339+
template <typename T>
340+
[[nodiscard]] static KVCacheEventDiff<T> deserializeKVCacheEventDiff(std::istream& is);
338341

339342
// KVCacheUpdatedData
340343
[[nodiscard]] static size_t serializedSize(KVCacheUpdatedData const& data);
341344
static void serialize(KVCacheUpdatedData const& data, std::ostream& os);
342345
[[nodiscard]] static KVCacheUpdatedData deserializeKVCacheUpdatedData(std::istream& is);
343346

347+
// UniqueToken
348+
[[nodiscard]] static size_t serializedSize(tensorrt_llm::runtime::UniqueToken const& token);
349+
static void serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os);
350+
[[nodiscard]] static tensorrt_llm::runtime::UniqueToken deserializeUniqueToken(std::istream& is);
351+
344352
// String
345353
static std::string deserializeString(std::istream& is);
346354

cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp

Lines changed: 62 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -18,38 +18,48 @@
1818
#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
1919
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
2020
#include "tensorrt_llm/executor/executor.h"
21+
#include "tensorrt_llm/executor/serialization.h"
22+
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
2123

2224
namespace tle = tensorrt_llm::executor;
2325

2426
namespace tensorrt_llm::batch_manager::kv_cache_manager
2527
{
2628

2729
KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional<SizeType32> attentionDpRank,
28-
std::optional<SizeType32> attentionDpSize, std::optional<attentionDpSize> ppSize)
30+
std::optional<SizeType32> attentionDpSize, SizeType32 attentionDpEventsGatherPeriodMs)
2931
: mRun{true}
3032
, mMaxSize{maxKVEventEntries}
3133
, mEventId{0}
3234
, mAttentionDpRank{attentionDpRank}
3335
, mAttentionDpSize{attentionDpSize}
36+
, mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs)
3437
{
38+
3539
TLLM_CHECK(mMaxSize > 0);
3640
if (mAttentionDpRank)
3741
{
3842
TLLM_CHECK_WITH_INFO(
3943
mAttentionDpSize.has_value(), "If attention DP rank is set, the attention DP size must also be set");
40-
TLLM_CHECK(ppSize.has_value());
41-
TLLM_CHECK_WITH_INFO(ppSize.value() == 1, "Events with attention DP are not supported with PP > 1");
4244
TLLM_CHECK_WITH_INFO(mAttentionDpRank.value() < mAttentionDpSize.value(),
4345
"Attention DP rank must be less than attention DP size");
46+
if (mAttentionDpRank.value() == 0)
47+
{
48+
// Rank 0 will gather events from all other ranks
49+
// Need to increase size
50+
mMaxSize *= mAttentionDpSize.value();
51+
}
4452
}
4553
else
4654
{
4755
TLLM_CHECK_WITH_INFO(
48-
!mAttentionDpSize.has_value(), "If attention DP size is set, the attention DP rank must also be set");
56+
!mAttentionDpSize.has_value(), "If attention DP rank is not set, the attention DP size must not be set");
4957
}
50-
// mWorkerThread = std::thread(std::bind(&KVCacheEventManager::worker, this));
5158
mWorkerThread = std::thread([this]() { this->worker(); });
52-
mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpEvents(); });
59+
if (mAttentionDpRank)
60+
{
61+
mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpThread(); });
62+
}
5363
};
5464

5565
KVCacheEventManager::~KVCacheEventManager()
@@ -58,7 +68,10 @@ KVCacheEventManager::~KVCacheEventManager()
5868
mPendingEmptyCV.notify_all();
5969
mEmptyCV.notify_all();
6070
mWorkerThread.join();
61-
mAttentionDpExchangeThread.join();
71+
if (mAttentionDpRank)
72+
{
73+
mExchangeAttentionDpThread.join();
74+
}
6275
}
6376

6477
void KVCacheEventManager::enqueueCreatedEvent(
@@ -84,7 +97,7 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks
8497
for (auto const& block : blocks)
8598
{
8699
data.blocks.emplace_back(block->getHash(), block->getUniqueTokens(), block->getBlockKey().loraTaskId,
87-
block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority(), mAttentionDpRank);
100+
block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority());
88101
}
89102

90103
enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank});
@@ -100,7 +113,7 @@ void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32
100113
}
101114
else
102115
{
103-
enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize});
116+
enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize, mAttentionDpRank});
104117
}
105118
}
106119

@@ -136,28 +149,27 @@ void KVCacheEventManager::flush()
136149
auto eventQueue = std::exchange(mEventQueue, {});
137150
std::unique_lock<std::mutex> lck(mPendingEventsMutex);
138151
mPendingEvents.push_back(std::move(eventQueue));
139-
// If we have events, we need to notify the worker thread to process them
140152
mPendingEmptyCV.notify_one();
141153
}
142154

143155
void KVCacheEventManager::exchangeAttentionDpThread()
144156
{
145-
int32_t pollPeriodMs = 5;
146157
while (true)
147158
{
148-
// If we are not rank 0, send events asynchronously
159+
TLLM_CHECK(mAttentionDpRank);
160+
// If we are not rank 0, send events to rank 0
149161
if (mAttentionDpRank.value() != 0)
150162
{
151163
std::vector<char> serializedEvents;
152164
{
153165
std::unique_lock<std::mutex> lck(mEventsMutex);
154-
serializedEvents = Serialization::serialize(mEvents);
166+
serializedEvents = executor::Serialization::serialize(mEvents);
155167
mEvents.clear();
156168
}
157169
uint64_t vecSize = serializedEvents.size();
158-
COMM_SESSION.send(&vecSize, 1, MpiType::kUINT64, 0, MpiTag::kKVCacheEventSize);
170+
COMM_SESSION.send(&vecSize, 1, mpi::MpiType::kUINT64, 0, mpi::MpiTag::kKvCacheEventSize);
159171
COMM_SESSION.send(
160-
serializedEvents.data(), serializedEvents.size(), MpiType::kCHAR, 0, MpiTag::kKVCacheEvent);
172+
serializedEvents.data(), serializedEvents.size(), mpi::MpiType::kCHAR, 0, mpi::MpiTag::kKvCacheEvent);
161173
}
162174
else
163175
{
@@ -167,18 +179,18 @@ void KVCacheEventManager::exchangeAttentionDpThread()
167179
while (numRecvs < mAttentionDpSize.value() - 1)
168180
{
169181
MPI_Status probeStatus;
170-
if (COMM_SESSION.iprobe(MPI_ANY_SOURCE, MpiTag::kKVCacheEvent, &status))
182+
if (COMM_SESSION.iprobe(MPI_ANY_SOURCE, mpi::MpiTag::kKvCacheEvent, &probeStatus))
171183
{
172-
uint64_t vecSize;
184+
uint64_t vecSize{0};
173185
COMM_SESSION.recv(
174-
&vecSize, 1, mpi::MpiType::kUINT64, probeStatus.MPI_SOURCE, mpi::MpiTag::kKVCacheEventSize);
186+
&vecSize, 1, mpi::MpiType::kUINT64, probeStatus.MPI_SOURCE, mpi::MpiTag::kKvCacheEventSize);
175187

176188
std::vector<char> serializedEvents(vecSize);
177-
COMM_SESSION.recv(&serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, probeStatus.MPI_SOURCE,
178-
mpi::MpiTag::kKVCacheEvent);
189+
COMM_SESSION.recv(serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, probeStatus.MPI_SOURCE,
190+
mpi::MpiTag::kKvCacheEvent);
179191

180192
// Deserialize the events and add them to the local queue
181-
auto rankEvents = Serialization::deserializeKVCacheEvents(serializedEvents);
193+
auto rankEvents = executor::Serialization::deserializeKVCacheEvents(serializedEvents);
182194
{
183195
std::unique_lock<std::mutex> lck(mEventsMutex);
184196
mEvents.insert(mEvents.end(), rankEvents.begin(), rankEvents.end());
@@ -187,47 +199,47 @@ void KVCacheEventManager::exchangeAttentionDpThread()
187199
numRecvs++;
188200
}
189201
}
190-
std::this_thread::sleep_for(std::chrono::milliseconds(pollPeriodMs));
202+
std::this_thread::sleep_for(std::chrono::milliseconds(mAttentionDpEventsGatherPeriodMs));
191203
}
192204
}
205+
}
193206

194-
void KVCacheEventManager::worker()
195-
{
207+
void KVCacheEventManager::worker()
208+
{
196209

197-
while (true)
210+
while (true)
211+
{
212+
std::deque<tle::KVCacheEvent> events;
198213
{
199-
std::deque<tle::KVCacheEvent> events;
214+
std::unique_lock<std::mutex> pendingLock(mPendingEventsMutex);
215+
mPendingEmptyCV.wait(pendingLock, [this] { return !mPendingEvents.empty() || !mRun; });
216+
if (!mRun)
200217
{
201-
std::unique_lock<std::mutex> pendingLock(mPendingEventsMutex);
202-
mPendingEmptyCV.wait(pendingLock, [this] { return !mPendingEvents.empty() || !mRun; });
203-
if (!mRun)
204-
{
205-
return;
206-
}
207-
events = mPendingEvents.front();
208-
mPendingEvents.pop_front();
218+
return;
209219
}
220+
events = mPendingEvents.front();
221+
mPendingEvents.pop_front();
222+
}
210223

211-
std::unique_lock<std::mutex> lck(mEventsMutex);
224+
std::unique_lock<std::mutex> lck(mEventsMutex);
212225

213-
SizeType32 elementsToRemove = mEvents.size() + events.size() - mMaxSize;
226+
SizeType32 elementsToRemove = mEvents.size() + events.size() - mMaxSize;
214227

215-
// First, take elements from mEvents since they are the oldest.
216-
if (elementsToRemove > 0)
217-
{
218-
SizeType32 numRemoved = std::min(static_cast<SizeType32>(mEvents.size()), elementsToRemove);
219-
mEvents.erase(mEvents.begin(), mEvents.begin() + numRemoved);
220-
elementsToRemove -= numRemoved;
221-
TLLM_LOG_WARNING(
222-
"The event queue has reached the max size of %d. Events have been discarded.", mMaxSize);
223-
}
228+
// First, take elements from mEvents since they are the oldest.
229+
if (elementsToRemove > 0)
230+
{
231+
SizeType32 numRemoved = std::min(static_cast<SizeType32>(mEvents.size()), elementsToRemove);
232+
mEvents.erase(mEvents.begin(), mEvents.begin() + numRemoved);
233+
elementsToRemove -= numRemoved;
234+
TLLM_LOG_WARNING("The event queue has reached the max size of %d. Events have been discarded.", mMaxSize);
235+
}
224236

225-
// If there's still too many events, take from the front of the events queue.
226-
mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end());
237+
// If there are still too many events, take from the front of the events queue.
238+
mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end());
227239

228-
// Notify the empty condition variable to wake up any waiting threads
229-
mEmptyCV.notify_one();
230-
}
240+
// Notify the empty condition variable to wake up any waiting threads
241+
mEmptyCV.notify_one();
231242
}
243+
}
232244

233245
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

cpp/tensorrt_llm/executor/kvCacheConfig.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co
2727
std::optional<size_t> const& hostCacheSize, bool onboardBlocks,
2828
std::optional<FloatType> const& crossKvCacheFraction, std::optional<RetentionPriority> secondaryOffloadMinPriority,
2929
size_t eventBufferMaxSize, bool enablePartialReuse, bool copyOnPartialReuse, bool useUvm,
30+
SizeType32 attentionDpEventsGatherPeriodMs,
3031
std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults)
3132
: mEnableBlockReuse(enableBlockReuse)
3233
, mHostCacheSize(hostCacheSize)
@@ -36,6 +37,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co
3637
, mEnablePartialReuse{enablePartialReuse}
3738
, mCopyOnPartialReuse{copyOnPartialReuse}
3839
, mUseUvm{useUvm}
40+
, mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs)
3941
{
4042
if (maxTokens)
4143
{
@@ -61,6 +63,8 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co
6163
{
6264
fillEmptyFieldsFromRuntimeDefaults(runtimeDefaults.value());
6365
}
66+
TLLM_CHECK_WITH_INFO(
67+
mAttentionDpEventsGatherPeriodMs > 0, "Attention DP events gather period must be greater than 0");
6468
}
6569

6670
bool KvCacheConfig::getEnableBlockReuse() const
@@ -128,6 +132,11 @@ bool KvCacheConfig::getUseUvm() const
128132
return mUseUvm;
129133
}
130134

135+
SizeType32 KvCacheConfig::getAttentionDpEventsGatherPeriodMs() const
136+
{
137+
return mAttentionDpEventsGatherPeriodMs;
138+
}
139+
131140
void KvCacheConfig::setEnableBlockReuse(bool enableBlockReuse)
132141
{
133142
mEnableBlockReuse = enableBlockReuse;
@@ -204,6 +213,12 @@ void KvCacheConfig::setUseUvm(bool useUvm)
204213
mUseUvm = useUvm;
205214
}
206215

216+
void KvCacheConfig::setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventPollPeriodMs)
217+
{
218+
TLLM_CHECK(attentionDpEventPollPeriodMs > 0);
219+
mAttentionDpEventsGatherPeriodMs = attentionDpEventPollPeriodMs;
220+
}
221+
207222
void KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults)
208223
{
209224
if (!mMaxAttentionWindowVec && runtimeDefaults.maxAttentionWindowVec)

0 commit comments

Comments
 (0)