@@ -49,6 +49,7 @@ KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional
     }
     // mWorkerThread = std::thread(std::bind(&KVCacheEventManager::worker, this));
     mWorkerThread = std::thread([this]() { this->worker(); });
+    mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpThread(); });
 };
 
@@ -57,6 +58,7 @@ KVCacheEventManager::~KVCacheEventManager()
     mPendingEmptyCV.notify_all();
     mEmptyCV.notify_all();
     mWorkerThread.join();
+    mExchangeAttentionDpThread.join();
 }
 
 void KVCacheEventManager::enqueueCreatedEvent(
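
Note on shutdown: the destructor now joins `mExchangeAttentionDpThread`, but the `exchangeAttentionDpThread()` loop added below runs `while (true)` and never re-checks a stop flag, so the join would block. Below is a minimal standalone sketch of the pattern the destructor depends on (an atomic stop flag observed by the background loop); the `EventExchanger` type and its members are illustrative only, not part of this change.

```cpp
#include <atomic>
#include <chrono>
#include <thread>

// Standalone sketch: a background loop must observe a stop flag,
// otherwise the destructor's join() never returns.
struct EventExchanger
{
    EventExchanger() { mThread = std::thread([this] { run(); }); }

    ~EventExchanger()
    {
        mRun = false;   // signal the loop to exit
        mThread.join(); // returns once run() observes mRun == false
    }

    void run()
    {
        while (mRun)
        {
            // ... one round of send/gather work would go here ...
            std::this_thread::sleep_for(std::chrono::milliseconds(5));
        }
    }

    std::atomic<bool> mRun{true};
    std::thread mThread;
};

int main()
{
    EventExchanger exchanger; // destructor exercises the clean shutdown path
}
```
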
@@ -129,82 +131,103 @@ std::deque<tle::KVCacheEvent> KVCacheEventManager::getEvents(std::optional<std::
     return std::exchange(mEvents, {});
 }
 
-std::vector<char> KVCacheEventManager::serializeEventQueue(std::deque<tle::KVCacheEvent> const& eventQueue)
-{
-    std::vector<char> buffer;
-    for (auto const& event : eventQueue)
-    {
-        auto serialized = event.serialize();
-        buffer.insert(buffer.end(), serialized.begin(), serialized.end());
-    }
-    return buffer;
-}
-
 void KVCacheEventManager::flush()
 {
     auto eventQueue = std::exchange(mEventQueue, {});
-
-    // In case of attention DP, we need to gather the events on rank 0
-    if (mAttentionDpSize && mAttentionDpSize.value() > 1)
-    {
-        auto packed = serializeEventQueue(eventQueue);
-
-        std::vector<std::vector<char>> rankEventQueues(mAttentionDpSize.value());
-        serializedRankEventQueues[mAttentionDpRank.value()] = std::move(packed);
-
-        // Use COMM_SESSION to fill serializedRankEventQueues on rank 0
-
-        // Deserialize the events
-        eventQueue.clear();
-        if (mAttentionDpRank == 0)
-        {
-            for (auto const& serializedRankEventQueue : serializedRankEventQueues)
-            {
-                auto rankEventQueue = deserializeEventQueue(serializedRankEventQueue);
-                eventQueue.insert(eventQueue.end(), rankEventQueue.begin(), rankEventQueue.end());
-            }
-        }
-    }
-
     std::unique_lock<std::mutex> lck(mPendingEventsMutex);
     mPendingEvents.push_back(std::move(eventQueue));
     // If we have events, we need to notify the worker thread to process them
     mPendingEmptyCV.notify_one();
 }
 
-void KVCacheEventManager::worker()
+void KVCacheEventManager::exchangeAttentionDpThread()
 {
+    int32_t pollPeriodMs = 5;
     while (true)
     {
-        std::deque<tle::KVCacheEvent> events;
+        // If we are not rank 0, send events asynchronously
+        if (mAttentionDpRank.value() != 0)
         {
-            std::unique_lock<std::mutex> pendingLock(mPendingEventsMutex);
-            mPendingEmptyCV.wait(pendingLock, [this] { return !mPendingEvents.empty() || !mRun; });
-            if (!mRun)
+            std::vector<char> serializedEvents;
             {
-                return;
+                std::unique_lock<std::mutex> lck(mEventsMutex);
+                serializedEvents = Serialization::serialize(mEvents);
+                mEvents.clear();
             }
-            events = mPendingEvents.front();
-            mPendingEvents.pop_front();
+            uint64_t vecSize = serializedEvents.size();
+            COMM_SESSION.send(&vecSize, 1, mpi::MpiType::kUINT64, 0, mpi::MpiTag::kKVCacheEventSize);
+            COMM_SESSION.send(
+                serializedEvents.data(), serializedEvents.size(), mpi::MpiType::kCHAR, 0, mpi::MpiTag::kKVCacheEvent);
         }
+        else
+        {
+            TLLM_CHECK(mAttentionDpSize.has_value());
+            // Loop until we have received events from all other ranks
+            int32_t numRecvs = 0;
+            while (numRecvs < mAttentionDpSize.value() - 1)
+            {
+                MPI_Status probeStatus;
+                if (COMM_SESSION.iprobe(MPI_ANY_SOURCE, mpi::MpiTag::kKVCacheEvent, &probeStatus))
+                {
+                    uint64_t vecSize;
+                    COMM_SESSION.recv(
+                        &vecSize, 1, mpi::MpiType::kUINT64, probeStatus.MPI_SOURCE, mpi::MpiTag::kKVCacheEventSize);
+
+                    std::vector<char> serializedEvents(vecSize);
+                    COMM_SESSION.recv(serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, probeStatus.MPI_SOURCE,
+                        mpi::MpiTag::kKVCacheEvent);
+
+                    // Deserialize the events and add them to the local queue
+                    auto rankEvents = Serialization::deserializeKVCacheEvents(serializedEvents);
+                    {
+                        std::unique_lock<std::mutex> lck(mEventsMutex);
+                        mEvents.insert(mEvents.end(), rankEvents.begin(), rankEvents.end());
+                        mEmptyCV.notify_one();
+                    }
+                    numRecvs++;
+                }
+            }
+        }
+        std::this_thread::sleep_for(std::chrono::milliseconds(pollPeriodMs));
+    }
+}
+
-        std::unique_lock<std::mutex> lck(mEventsMutex);
-
-        SizeType32 elementsToRemove = mEvents.size() + events.size() - mMaxSize;
+void KVCacheEventManager::worker()
+{
-        // First, take elements from mEvents since they are the oldest.
-        if (elementsToRemove > 0)
+    while (true)
     {
-            SizeType32 numRemoved = std::min(static_cast<SizeType32>(mEvents.size()), elementsToRemove);
-            mEvents.erase(mEvents.begin(), mEvents.begin() + numRemoved);
-            elementsToRemove -= numRemoved;
-            TLLM_LOG_WARNING("The event queue has reached the max size of %d. Events have been discarded.", mMaxSize);
-        }
+        std::deque<tle::KVCacheEvent> events;
+        {
+            std::unique_lock<std::mutex> pendingLock(mPendingEventsMutex);
+            mPendingEmptyCV.wait(pendingLock, [this] { return !mPendingEvents.empty() || !mRun; });
+            if (!mRun)
+            {
+                return;
+            }
+            events = mPendingEvents.front();
+            mPendingEvents.pop_front();
+        }
 
-        // If there's still too many events, take from the front of the events queue.
-        mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end());
-        mEmptyCV.notify_one();
+        std::unique_lock<std::mutex> lck(mEventsMutex);
+
+        SizeType32 elementsToRemove = mEvents.size() + events.size() - mMaxSize;
+
+        // First, take elements from mEvents since they are the oldest.
+        if (elementsToRemove > 0)
+        {
+            SizeType32 numRemoved = std::min(static_cast<SizeType32>(mEvents.size()), elementsToRemove);
+            mEvents.erase(mEvents.begin(), mEvents.begin() + numRemoved);
+            elementsToRemove -= numRemoved;
+            TLLM_LOG_WARNING(
+                "The event queue has reached the max size of %d. Events have been discarded.", mMaxSize);
+        }
+
+        // If there's still too many events, take from the front of the events queue.
+        mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end());
+
+        // Notify the empty condition variable to wake up any waiting threads
+        mEmptyCV.notify_one();
+    }
 }
-}
 
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
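
For reference, the gather protocol introduced above is a two-message exchange per rank: each non-zero attention-DP rank sends a `uint64_t` size followed by the serialized event payload, and rank 0 probes for pending messages and receives both parts from whichever source is ready. The standalone sketch below shows the same pattern with plain MPI calls rather than TensorRT-LLM's `COMM_SESSION` wrapper; the tag values, payload contents, and the blocking `MPI_Probe` (instead of the diff's `iprobe` polling loop with a sleep) are simplifications assumed for illustration.

```cpp
#include <mpi.h>
#include <cstdint>
#include <string>
#include <vector>

constexpr int kTagSize = 100;    // stands in for MpiTag::kKVCacheEventSize
constexpr int kTagPayload = 101; // stands in for MpiTag::kKVCacheEvent

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank = 0, worldSize = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &worldSize);

    if (rank != 0)
    {
        // "Serialize" some per-rank events; here just a string for illustration.
        std::string payload = "events from rank " + std::to_string(rank);
        uint64_t size = payload.size();
        MPI_Send(&size, 1, MPI_UINT64_T, 0, kTagSize, MPI_COMM_WORLD);
        MPI_Send(payload.data(), static_cast<int>(size), MPI_CHAR, 0, kTagPayload, MPI_COMM_WORLD);
    }
    else
    {
        // Gather one round of events from every other rank, in arrival order.
        for (int received = 0; received < worldSize - 1; ++received)
        {
            MPI_Status status;
            MPI_Probe(MPI_ANY_SOURCE, kTagSize, MPI_COMM_WORLD, &status);

            uint64_t size = 0;
            MPI_Recv(&size, 1, MPI_UINT64_T, status.MPI_SOURCE, kTagSize, MPI_COMM_WORLD, MPI_STATUS_IGNORE);

            std::vector<char> buffer(size);
            MPI_Recv(buffer.data(), static_cast<int>(size), MPI_CHAR, status.MPI_SOURCE, kTagPayload,
                MPI_COMM_WORLD, MPI_STATUS_IGNORE);
            // A real implementation would deserialize buffer into KVCacheEvent objects here.
        }
    }

    MPI_Finalize();
    return 0;
}
```

Sending the size first lets the receiver allocate an exactly sized buffer before posting the payload receive, which is why the exchange uses two tags per rank.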