Skip to content

Commit 3a56c1e

Browse files
committed
WIP: Adding separate thread for kv cache events exchange
Signed-off-by: Patrice Castonguay <[email protected]>
1 parent 832bb10 commit 3a56c1e

File tree

6 files changed

+342
-56
lines changed

6 files changed

+342
-56
lines changed

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,6 +1702,12 @@ struct KVCacheUpdatedData
17021702
explicit KVCacheUpdatedData(IdType blockHash)
17031703
: blockHash{blockHash} {};
17041704

1705+
explicit KVCacheUpdatedData(IdType blockHash, std::optional<KVCacheEventDiff<SizeType32>> cacheLevel,
1706+
std::optional<KVCacheEventDiff<SizeType32>> priority)
1707+
: blockHash{blockHash}
1708+
, cacheLevel{cacheLevel}
1709+
, priority{priority} {};
1710+
17051711
KVCacheUpdatedData& cacheLevelUpdated(SizeType32 oldValue, SizeType32 newValue)
17061712
{
17071713
cacheLevel = KVCacheEventDiff<SizeType32>{oldValue, newValue};

cpp/include/tensorrt_llm/executor/serialization.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,45 @@ class Serialization
302302
[[nodiscard]] static std::vector<RequestStatsPerIteration> deserializeRequestStatsPerIterationVec(
303303
std::vector<char>& buffer);
304304

305+
// KVCacheEvent deque
306+
[[nodiscard]] static std::vector<char> serialize(std::deque<KVCacheEvent> const& kvCacheEvents);
307+
[[nodiscard]] static std::deque<KVCacheEvent> deserializeKVCacheEvents(std::vector<char>& buffer);
308+
309+
// KVCacheEvent
310+
[[nodiscard]] static size_t serializedSize(KVCacheEvent const& event);
311+
static void serialize(KVCacheEvent const& event, std::ostream& os);
312+
[[nodiscard]] static KVCacheEvent deserializeKVCacheEvent(std::istream& is);
313+
314+
// KVCacheCreatedData
315+
[[nodiscard]] static size_t serializedSize(KVCacheCreatedData const& data);
316+
static void serialize(KVCacheCreatedData const& data, std::ostream& os);
317+
[[nodiscard]] static KVCacheCreatedData deserializeKVCacheCreatedData(std::istream& is);
318+
319+
// KVCacheStoredData
320+
[[nodiscard]] static size_t serializedSize(KVCacheStoredData const& data);
321+
static void serialize(KVCacheStoredData const& data, std::ostream& os);
322+
[[nodiscard]] static KVCacheStoredData deserializeKVCacheStoredData(std::istream& is);
323+
324+
// KVCacheStoredBlockData
325+
[[nodiscard]] static size_t serializedSize(KVCacheStoredBlockData const& data);
326+
static void serialize(KVCacheStoredBlockData const& data, std::ostream& os);
327+
[[nodiscard]] static KVCacheStoredBlockData deserializeKVCacheStoredBlockData(std::istream& is);
328+
329+
// KVCacheRemovedData
330+
[[nodiscard]] static size_t serializedSize(KVCacheRemovedData const& data);
331+
static void serialize(KVCacheRemovedData const& data, std::ostream& os);
332+
[[nodiscard]] static KVCacheRemovedData deserializeKVCacheRemovedData(std::istream& is);
333+
334+
// KVCacheEventDiff
335+
[[nodiscard]] static size_t serializedSize(KVCacheEventDiff<SizeType32> const& data);
336+
static void serialize(KVCacheEventDiff<SizeType32> const& data, std::ostream& os);
337+
[[nodiscard]] static KVCacheEventDiff<SizeType32> deserializeKVCacheEventDiff(std::istream& is);
338+
339+
// KVCacheUpdateData
340+
[[nodiscard]] static size_t serializedSize(KVCacheUpdatedData const& data);
341+
static void serialize(KVCacheUpdatedData const& data, std::ostream& os);
342+
[[nodiscard]] static KVCacheUpdatedData deserializeKVCacheUpdatedData(std::istream& is);
343+
305344
// String
306345
static std::string deserializeString(std::istream& is);
307346

cpp/include/tensorrt_llm/runtime/utils/mpiTags.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ enum class MpiTag : int
6868
// LogitsThread
6969
kSpecDecLogitsId = 129,
7070
kSpecDecLogitsData = 1025,
71+
72+
// KvCacheEventManager
73+
kKvCacheEventSize = 1026,
74+
kKvCacheEvent = 1027
7175
};
7276

7377
} // namespace tensorrt_llm::mpi

cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp

Lines changed: 79 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional
4949
}
5050
// mWorkerThread = std::thread(std::bind(&KVCacheEventManager::worker, this));
5151
mWorkerThread = std::thread([this]() { this->worker(); });
52+
mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpEvents(); });
5253
};
5354

5455
KVCacheEventManager::~KVCacheEventManager()
@@ -57,6 +58,7 @@ KVCacheEventManager::~KVCacheEventManager()
5758
mPendingEmptyCV.notify_all();
5859
mEmptyCV.notify_all();
5960
mWorkerThread.join();
61+
mAttentionDpExchangeThread.join();
6062
}
6163

6264
void KVCacheEventManager::enqueueCreatedEvent(
@@ -129,82 +131,103 @@ std::deque<tle::KVCacheEvent> KVCacheEventManager::getEvents(std::optional<std::
129131
return std::exchange(mEvents, {});
130132
}
131133

132-
std::vector<char> KVCacheEventManager::serializeEventQueue(std::deque<tle::KVCacheEvent> const& eventQueue)
133-
{
134-
std::vector<char> buffer;
135-
for (auto const& event : eventQueue)
136-
{
137-
auto serialized = event.serialize();
138-
buffer.insert(buffer.end(), serialized.begin(), serialized.end());
139-
}
140-
return buffer;
141-
}
142-
143134
void KVCacheEventManager::flush()
144135
{
145136
auto eventQueue = std::exchange(mEventQueue, {});
146-
147-
// In case of attention DP, we need to gather the events on rank 0
148-
if (mAttentionDpSize && mAttentionDpSize.value() > 1)
149-
{
150-
auto packed = serializeEventQueue(eventQueue);
151-
152-
std::vector<std::vector<char>> rankEventQueues(mAttentionDpSize.value());
153-
serializedRankEventQueues[mAttentionDpRank.value()] = std::move(packed);
154-
155-
// Use COMM_SESSION to fill serializedRankEventQueues on rank 0
156-
157-
// Deserialize the events
158-
eventQueue.clear();
159-
if (mAttentionDpRank == 0)
160-
{
161-
for (auto const& serializedRankEventQueue : serializedRankEventQueues)
162-
{
163-
auto rankEventQueue = deserializeEventQueue(serializedRankEventQueue);
164-
eventQueue.insert(eventQueue.end(), rankEventQueue.begin(), rankEventQueue.end());
165-
}
166-
}
167-
}
168-
169137
std::unique_lock<std::mutex> lck(mPendingEventsMutex);
170138
mPendingEvents.push_back(std::move(eventQueue));
171139
// If we have events, we need to notify the worker thread to process them
172140
mPendingEmptyCV.notify_one();
173141
}
174142

175-
void KVCacheEventManager::worker()
143+
void KVCacheEventManager::exchangeAttentionDpThread()
176144
{
145+
int32_t pollPeriodMs = 5;
177146
while (true)
178147
{
179-
std::deque<tle::KVCacheEvent> events;
148+
// If we are not rank 0, send events asynchronously
149+
if (mAttentionDpRank.value() != 0)
180150
{
181-
std::unique_lock<std::mutex> pendingLock(mPendingEventsMutex);
182-
mPendingEmptyCV.wait(pendingLock, [this] { return !mPendingEvents.empty() || !mRun; });
183-
if (!mRun)
151+
std::vector<char> serializedEvents;
184152
{
185-
return;
153+
std::unique_lock<std::mutex> lck(mEventsMutex);
154+
serializedEvents = Serialization::serialize(mEvents);
155+
mEvents.clear();
186156
}
187-
events = mPendingEvents.front();
188-
mPendingEvents.pop_front();
157+
uint64_t vecSize = serializedEvents.size();
158+
COMM_SESSION.send(&vecSize, 1, MpiType::kUINT64, 0, MpiTag::kKVCacheEventSize);
159+
COMM_SESSION.send(
160+
serializedEvents.data(), serializedEvents.size(), MpiType::kCHAR, 0, MpiTag::kKVCacheEvent);
189161
}
162+
else
163+
{
164+
TLLM_CHECK(mAttentionDpSize.has_value());
165+
// Loop until have received events from all ranks
166+
int32_t numRecvs = 0;
167+
while (numRecvs < mAttentionDpSize.value() - 1)
168+
{
169+
MPI_Status probeStatus;
170+
if (COMM_SESSION.iprobe(MPI_ANY_SOURCE, MpiTag::kKVCacheEvent, &status))
171+
{
172+
uint64_t vecSize;
173+
COMM_SESSION.recv(
174+
&vecSize, 1, mpi::MpiType::kUINT64, probeStatus.MPI_SOURCE, mpi::MpiTag::kKVCacheEventSize);
175+
176+
std::vector<char> serializedEvents(vecSize);
177+
COMM_SESSION.recv(&serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, probeStatus.MPI_SOURCE,
178+
mpi::MpiTag::kKVCacheEvent);
179+
180+
// Deserialize the events and add them to the local queue
181+
auto rankEvents = Serialization::deserializeKVCacheEvents(serializedEvents);
182+
{
183+
std::unique_lock<std::mutex> lck(mEventsMutex);
184+
mEvents.insert(mEvents.end(), rankEvents.begin(), rankEvents.end());
185+
mEmptyCV.notify_one();
186+
}
187+
numRecvs++;
188+
}
189+
}
190+
std::this_thread::sleep_for(std::chrono::milliseconds(pollPeriodMs));
191+
}
192+
}
190193

191-
std::unique_lock<std::mutex> lck(mEventsMutex);
192-
193-
SizeType32 elementsToRemove = mEvents.size() + events.size() - mMaxSize;
194+
void KVCacheEventManager::worker()
195+
{
194196

195-
// First, take elements from mEvents since they are the oldest.
196-
if (elementsToRemove > 0)
197+
while (true)
197198
{
198-
SizeType32 numRemoved = std::min(static_cast<SizeType32>(mEvents.size()), elementsToRemove);
199-
mEvents.erase(mEvents.begin(), mEvents.begin() + numRemoved);
200-
elementsToRemove -= numRemoved;
201-
TLLM_LOG_WARNING("The event queue has reached the max size of %d. Events have been discarded.", mMaxSize);
202-
}
199+
std::deque<tle::KVCacheEvent> events;
200+
{
201+
std::unique_lock<std::mutex> pendingLock(mPendingEventsMutex);
202+
mPendingEmptyCV.wait(pendingLock, [this] { return !mPendingEvents.empty() || !mRun; });
203+
if (!mRun)
204+
{
205+
return;
206+
}
207+
events = mPendingEvents.front();
208+
mPendingEvents.pop_front();
209+
}
203210

204-
// If there's still too many events, take from the front of the events queue.
205-
mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end());
206-
mEmptyCV.notify_one();
211+
std::unique_lock<std::mutex> lck(mEventsMutex);
212+
213+
SizeType32 elementsToRemove = mEvents.size() + events.size() - mMaxSize;
214+
215+
// First, take elements from mEvents since they are the oldest.
216+
if (elementsToRemove > 0)
217+
{
218+
SizeType32 numRemoved = std::min(static_cast<SizeType32>(mEvents.size()), elementsToRemove);
219+
mEvents.erase(mEvents.begin(), mEvents.begin() + numRemoved);
220+
elementsToRemove -= numRemoved;
221+
TLLM_LOG_WARNING(
222+
"The event queue has reached the max size of %d. Events have been discarded.", mMaxSize);
223+
}
224+
225+
// If there's still too many events, take from the front of the events queue.
226+
mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end());
227+
228+
// Notify the empty condition variable to wake up any waiting threads
229+
mEmptyCV.notify_one();
230+
}
207231
}
208-
}
209232

210233
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

0 commit comments

Comments
 (0)