File tree Expand file tree Collapse file tree 2 files changed +18
-17
lines changed
cpp/tensorrt_llm/executor Expand file tree Collapse file tree 2 files changed +18
-17
lines changed Original file line number Diff line number Diff line change @@ -1205,6 +1205,7 @@ size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig)
1205
1205
totalSize += su::serializedSize (kvCacheConfig.getSecondaryOffloadMinPriority ());
1206
1206
totalSize += su::serializedSize (kvCacheConfig.getEventBufferMaxSize ());
1207
1207
totalSize += su::serializedSize (kvCacheConfig.getUseUvm ());
1208
+ totalSize += su::serializedSize (kvCacheConfig.getAttentionDpEventsGatherPeriodMs ());
1208
1209
return totalSize;
1209
1210
}
1210
1211
Original file line number Diff line number Diff line change @@ -307,6 +307,22 @@ struct get_variant_alternative_type
307
307
}
308
308
};
309
309
310
+ template <typename T>
311
+ T deserialize (std::istream& is);
312
+
313
+ // Helper function to deserialize variant by index using template recursion
314
+ template <typename T, std::size_t ... Is>
315
+ T deserializeVariantByIndex (std::istream& is, std::size_t index, std::index_sequence<Is...> /* indices*/ )
316
+ {
317
+ T result;
318
+ bool found = ((Is == index ? (result = deserialize<std::variant_alternative_t <Is, T>>(is), true ) : false ) || ...);
319
+ if (!found)
320
+ {
321
+ TLLM_THROW (" Invalid variant index during deserialization: " + std::to_string (index));
322
+ }
323
+ return result;
324
+ }
325
+
310
326
// Deserialize
311
327
template <typename T>
312
328
T deserialize (std::istream& is)
@@ -595,23 +611,7 @@ T deserialize(std::istream& is)
595
611
std::size_t index = 0 ;
596
612
is.read (reinterpret_cast <char *>(&index), sizeof (index));
597
613
598
- // TODO: Is there a better way to implement this?
599
- T data;
600
- if (index == 0 )
601
- {
602
- using U = std::variant_alternative_t <0 , T>;
603
- data = deserialize<U>(is);
604
- }
605
- else if (index == 1 )
606
- {
607
- using U = std::variant_alternative_t <1 , T>;
608
- data = deserialize<U>(is);
609
- }
610
- else
611
- {
612
- TLLM_THROW (" Serialization of variant of size > 2 is not supported." );
613
- }
614
- return data;
614
+ return deserializeVariantByIndex<T>(is, index, std::make_index_sequence<std::variant_size_v<T>>{});
615
615
}
616
616
else
617
617
{
You can’t perform that action at this time.
0 commit comments