#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/executor/executor.h"
+ #include "tensorrt_llm/executor/serialization.h"
+ #include "tensorrt_llm/runtime/utils/mpiUtils.h"

namespace tle = tensorrt_llm::executor;

namespace tensorrt_llm::batch_manager::kv_cache_manager
{

KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional<SizeType32> attentionDpRank,
-     std::optional<SizeType32> attentionDpSize, std::optional<SizeType32> ppSize)
+     std::optional<SizeType32> attentionDpSize, SizeType32 attentionDpEventsGatherPeriodMs)
    : mRun{true}
    , mMaxSize{maxKVEventEntries}
    , mEventId{0}
    , mAttentionDpRank{attentionDpRank}
    , mAttentionDpSize{attentionDpSize}
+     , mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs)
{
+
    TLLM_CHECK(mMaxSize > 0);
    if (mAttentionDpRank)
    {
        TLLM_CHECK_WITH_INFO(
            mAttentionDpSize.has_value(), "If attention DP rank is set, the attention DP size must also be set");
-         TLLM_CHECK(ppSize.has_value());
-         TLLM_CHECK_WITH_INFO(ppSize.value() == 1, "Events with attention DP are not supported with PP > 1");
        TLLM_CHECK_WITH_INFO(mAttentionDpRank.value() < mAttentionDpSize.value(),
            "Attention DP rank must be less than attention DP size");
+         if (mAttentionDpRank.value() == 0)
+         {
+             // Rank 0 will gather events from all other ranks
+             // Need to increase size
+             mMaxSize *= mAttentionDpSize.value();
+         }
    }
    else
    {
        TLLM_CHECK_WITH_INFO(
-             !mAttentionDpSize.has_value(), "If attention DP size is set, the attention DP rank must also be set");
+             !mAttentionDpSize.has_value(), "If attention DP rank is not set, the attention DP size must not be set");
    }
-     // mWorkerThread = std::thread(std::bind(&KVCacheEventManager::worker, this));
    mWorkerThread = std::thread([this]() { this->worker(); });
-     mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpEvents(); });
+     if (mAttentionDpRank)
+     {
+         mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpThread(); });
+     }
};

KVCacheEventManager::~KVCacheEventManager()
@@ -58,7 +68,10 @@ KVCacheEventManager::~KVCacheEventManager()
    mPendingEmptyCV.notify_all();
    mEmptyCV.notify_all();
    mWorkerThread.join();
-     mAttentionDpExchangeThread.join();
+     if (mAttentionDpRank)
+     {
+         mExchangeAttentionDpThread.join();
+     }
}

void KVCacheEventManager::enqueueCreatedEvent(
@@ -84,7 +97,7 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks
    for (auto const& block : blocks)
    {
        data.blocks.emplace_back(block->getHash(), block->getUniqueTokens(), block->getBlockKey().loraTaskId,
-             block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority(), mAttentionDpRank);
+             block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority());
    }

    enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank});
@@ -100,7 +113,7 @@ void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32
    }
    else
    {
-         enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize});
+         enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize, mAttentionDpRank});
    }
}

@@ -136,28 +149,27 @@ void KVCacheEventManager::flush()
    auto eventQueue = std::exchange(mEventQueue, {});
    std::unique_lock<std::mutex> lck(mPendingEventsMutex);
    mPendingEvents.push_back(std::move(eventQueue));
-     // If we have events, we need to notify the worker thread to process them
    mPendingEmptyCV.notify_one();
}

void KVCacheEventManager::exchangeAttentionDpThread()
{
-     int32_t pollPeriodMs = 5;
    while (true)
    {
-         // If we are not rank 0, send events asynchronously
+         TLLM_CHECK(mAttentionDpRank);
+         // If we are not rank 0, send events to rank 0
        if (mAttentionDpRank.value() != 0)
        {
            std::vector<char> serializedEvents;
            {
                std::unique_lock<std::mutex> lck(mEventsMutex);
-                 serializedEvents = Serialization::serialize(mEvents);
+                 serializedEvents = executor::Serialization::serialize(mEvents);
                mEvents.clear();
            }
            uint64_t vecSize = serializedEvents.size();
-             COMM_SESSION.send(&vecSize, 1, MpiType::kUINT64, 0, MpiTag::kKVCacheEventSize);
+             COMM_SESSION.send(&vecSize, 1, mpi::MpiType::kUINT64, 0, mpi::MpiTag::kKvCacheEventSize);
            COMM_SESSION.send(
-                 serializedEvents.data(), serializedEvents.size(), MpiType::kCHAR, 0, MpiTag::kKVCacheEvent);
+                 serializedEvents.data(), serializedEvents.size(), mpi::MpiType::kCHAR, 0, mpi::MpiTag::kKvCacheEvent);
        }
        else
        {
@@ -167,18 +179,18 @@ void KVCacheEventManager::exchangeAttentionDpThread()
            while (numRecvs < mAttentionDpSize.value() - 1)
            {
                MPI_Status probeStatus;
-                 if (COMM_SESSION.iprobe(MPI_ANY_SOURCE, MpiTag::kKVCacheEvent, &status))
+                 if (COMM_SESSION.iprobe(MPI_ANY_SOURCE, mpi::MpiTag::kKvCacheEvent, &probeStatus))
                {
-                     uint64_t vecSize;
+                     uint64_t vecSize{0};
                    COMM_SESSION.recv(
-                         &vecSize, 1, mpi::MpiType::kUINT64, probeStatus.MPI_SOURCE, mpi::MpiTag::kKVCacheEventSize);
+                         &vecSize, 1, mpi::MpiType::kUINT64, probeStatus.MPI_SOURCE, mpi::MpiTag::kKvCacheEventSize);

                    std::vector<char> serializedEvents(vecSize);
-                     COMM_SESSION.recv(&serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, probeStatus.MPI_SOURCE,
-                         mpi::MpiTag::kKVCacheEvent);
+                     COMM_SESSION.recv(serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, probeStatus.MPI_SOURCE,
+                         mpi::MpiTag::kKvCacheEvent);

                    // Deserialize the events and add them to the local queue
-                     auto rankEvents = Serialization::deserializeKVCacheEvents(serializedEvents);
+                     auto rankEvents = executor::Serialization::deserializeKVCacheEvents(serializedEvents);
                    {
                        std::unique_lock<std::mutex> lck(mEventsMutex);
                        mEvents.insert(mEvents.end(), rankEvents.begin(), rankEvents.end());
@@ -187,47 +199,47 @@ void KVCacheEventManager::exchangeAttentionDpThread()
                    numRecvs++;
                }
            }
-             std::this_thread::sleep_for(std::chrono::milliseconds(pollPeriodMs));
+             std::this_thread::sleep_for(std::chrono::milliseconds(mAttentionDpEventsGatherPeriodMs));
        }
    }
+ }

-     void KVCacheEventManager::worker()
-     {
+ void KVCacheEventManager::worker()
+ {

-         while (true)
+     while (true)
+     {
+         std::deque<tle::KVCacheEvent> events;
        {
-             std::deque<tle::KVCacheEvent> events;
+             std::unique_lock<std::mutex> pendingLock(mPendingEventsMutex);
+             mPendingEmptyCV.wait(pendingLock, [this] { return !mPendingEvents.empty() || !mRun; });
+             if (!mRun)
            {
-                 std::unique_lock<std::mutex> pendingLock(mPendingEventsMutex);
-                 mPendingEmptyCV.wait(pendingLock, [this] { return !mPendingEvents.empty() || !mRun; });
-                 if (!mRun)
-                 {
-                     return;
-                 }
-                 events = mPendingEvents.front();
-                 mPendingEvents.pop_front();
+                 return;
            }
+             events = mPendingEvents.front();
+             mPendingEvents.pop_front();
+         }

-             std::unique_lock<std::mutex> lck(mEventsMutex);
+         std::unique_lock<std::mutex> lck(mEventsMutex);

-             SizeType32 elementsToRemove = mEvents.size() + events.size() - mMaxSize;
+         SizeType32 elementsToRemove = mEvents.size() + events.size() - mMaxSize;

-             // First, take elements from mEvents since they are the oldest.
-             if (elementsToRemove > 0)
-             {
-                 SizeType32 numRemoved = std::min(static_cast<SizeType32>(mEvents.size()), elementsToRemove);
-                 mEvents.erase(mEvents.begin(), mEvents.begin() + numRemoved);
-                 elementsToRemove -= numRemoved;
-                 TLLM_LOG_WARNING(
-                     "The event queue has reached the max size of %d. Events have been discarded.", mMaxSize);
-             }
+         // First, take elements from mEvents since they are the oldest.
+         if (elementsToRemove > 0)
+         {
+             SizeType32 numRemoved = std::min(static_cast<SizeType32>(mEvents.size()), elementsToRemove);
+             mEvents.erase(mEvents.begin(), mEvents.begin() + numRemoved);
+             elementsToRemove -= numRemoved;
+             TLLM_LOG_WARNING("The event queue has reached the max size of %d. Events have been discarded.", mMaxSize);
+         }

-             // If there's still too many events, take from the front of the events queue.
-             mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end());
+         // If there's still too many events, take from the front of the events queue.
+         mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end());

-             // Notify the empty condition variable to wake up any waiting threads
-             mEmptyCV.notify_one();
-         }
+         // Notify the empty condition variable to wake up any waiting threads
+         mEmptyCV.notify_one();
    }
+ }

} // namespace tensorrt_llm::batch_manager::kv_cache_manager
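
Note: the attention-DP path in exchangeAttentionDpThread above is a size-then-payload gather. Every rank other than 0 serializes its pending events and sends two messages to rank 0 (the payload size, then the bytes), while rank 0 polls with iprobe and appends whatever it receives to its own queue under mEventsMutex. The snippet below is a minimal standalone sketch of one gather round written against raw MPI rather than TensorRT-LLM's COMM_SESSION wrapper; the tag constants and the function name are illustrative only, not part of the library.

// Standalone illustration (not TRT-LLM code): size-then-payload gather over raw MPI.
#include <mpi.h>
#include <cstdint>
#include <vector>

constexpr int kSizeTag = 100;    // stands in for MpiTag::kKvCacheEventSize (hypothetical value)
constexpr int kPayloadTag = 101; // stands in for MpiTag::kKvCacheEvent (hypothetical value)

void gatherSerializedEvents(int rank, int worldSize, std::vector<char> const& localBytes)
{
    if (rank != 0)
    {
        // Non-root ranks announce the payload size, then send the payload itself.
        uint64_t size = localBytes.size();
        MPI_Send(&size, 1, MPI_UINT64_T, 0, kSizeTag, MPI_COMM_WORLD);
        MPI_Send(localBytes.data(), static_cast<int>(size), MPI_CHAR, 0, kPayloadTag, MPI_COMM_WORLD);
        return;
    }
    // Rank 0 polls until it has heard from every other rank once.
    int numRecvs = 0;
    while (numRecvs < worldSize - 1)
    {
        int flag = 0;
        MPI_Status status;
        MPI_Iprobe(MPI_ANY_SOURCE, kSizeTag, MPI_COMM_WORLD, &flag, &status);
        if (flag)
        {
            uint64_t size = 0;
            MPI_Recv(&size, 1, MPI_UINT64_T, status.MPI_SOURCE, kSizeTag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
            std::vector<char> buffer(size);
            MPI_Recv(buffer.data(), static_cast<int>(size), MPI_CHAR, status.MPI_SOURCE, kPayloadTag,
                MPI_COMM_WORLD, MPI_STATUS_IGNORE);
            // In the real manager, this buffer would be deserialized into KVCacheEvents
            // and appended to the rank-0 event queue.
            ++numRecvs;
        }
    }
}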