Skip to content

Commit 431ebff

Browse files
committed
Fixing variant deserialization
Signed-off-by: Patrice Castonguay <[email protected]>
1 parent 8adbd75 commit 431ebff

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

cpp/tensorrt_llm/executor/serialization.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,6 +1205,7 @@ size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig)
12051205
totalSize += su::serializedSize(kvCacheConfig.getSecondaryOffloadMinPriority());
12061206
totalSize += su::serializedSize(kvCacheConfig.getEventBufferMaxSize());
12071207
totalSize += su::serializedSize(kvCacheConfig.getUseUvm());
1208+
totalSize += su::serializedSize(kvCacheConfig.getAttentionDpEventsGatherPeriodMs());
12081209
return totalSize;
12091210
}
12101211

cpp/tensorrt_llm/executor/serializeUtils.h

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,22 @@ struct get_variant_alternative_type
307307
}
308308
};
309309

310+
template <typename T>
311+
T deserialize(std::istream& is);
312+
313+
// Helper function to deserialize variant by index using template recursion
314+
template <typename T, std::size_t... Is>
315+
T deserializeVariantByIndex(std::istream& is, std::size_t index, std::index_sequence<Is...> /*indices*/)
316+
{
317+
T result;
318+
bool found = ((Is == index ? (result = deserialize<std::variant_alternative_t<Is, T>>(is), true) : false) || ...);
319+
if (!found)
320+
{
321+
TLLM_THROW("Invalid variant index during deserialization: " + std::to_string(index));
322+
}
323+
return result;
324+
}
325+
310326
// Deserialize
311327
template <typename T>
312328
T deserialize(std::istream& is)
@@ -595,23 +611,7 @@ T deserialize(std::istream& is)
595611
std::size_t index = 0;
596612
is.read(reinterpret_cast<char*>(&index), sizeof(index));
597613

598-
// TODO: Is there a better way to implement this?
599-
T data;
600-
if (index == 0)
601-
{
602-
using U = std::variant_alternative_t<0, T>;
603-
data = deserialize<U>(is);
604-
}
605-
else if (index == 1)
606-
{
607-
using U = std::variant_alternative_t<1, T>;
608-
data = deserialize<U>(is);
609-
}
610-
else
611-
{
612-
TLLM_THROW("Serialization of variant of size > 2 is not supported.");
613-
}
614-
return data;
614+
return deserializeVariantByIndex<T>(is, index, std::make_index_sequence<std::variant_size_v<T>>{});
615615
}
616616
else
617617
{

0 commit comments

Comments
 (0)