Skip to content

Commit b6a619e

Browse files
[SYCL] Don't use zstd context across threads. (#19747)
`ZSTDCompressor` holds `zstd` context as its only data members. The idea behind `GetSingletonInstance()` method was to re-use these contexts for subsequent compression and decompressions. Re-using context across (de)compression reduces system memory usage. However, `zstd` contexts are not meant to be used concurrently, therefore, this PR makes `ZSTDCompressor` object thread-local, instead of static. Relevant excerpt from zstd doc (https://facebook.github.io/zstd/zstd_manual.html): > When decompressing many times, > it is recommended to allocate a context only once, > and re-use it for each successive compression operation. > This will make workload friendlier for system's memory. > Use one context per thread for parallel execution.
1 parent 367d5c1 commit b6a619e

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

sycl/source/detail/compression.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ class ZSTDCompressor {
3333

3434
// Get the singleton instance of the ZSTDCompressor class.
3535
static ZSTDCompressor &GetSingletonInstance() {
36-
static ZSTDCompressor instance;
36+
// Use thread_local to ensure that each thread has its own instance.
37+
// This avoids issues with concurrent access to the ZSTD contexts.
38+
thread_local ZSTDCompressor instance;
3739
return instance;
3840
}
3941

sycl/unittests/compression/CompressionTests.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "../thread_safety/ThreadUtils.h"
910
#include <detail/compression.hpp>
1011
#include <sycl/sycl.hpp>
1112

@@ -79,3 +80,36 @@ TEST(CompressionTest, EmptyInputTest) {
7980
std::string decompressedStr((char *)decompressedData.get(), decompressedSize);
8081
ASSERT_EQ(input, decompressedStr);
8182
}
83+
84+
// Test to check for concurrent compression and decompression.
85+
TEST(CompressionTest, ConcurrentCompressionDecompression) {
86+
std::string data = "Concurrent compression and decompression test!";
87+
88+
constexpr size_t ThreadCount = 20;
89+
90+
Barrier b(ThreadCount);
91+
{
92+
auto testCompressDecompress = [&](size_t threadId) {
93+
b.wait();
94+
size_t compressedDataSize = 0;
95+
auto compressedData = ZSTDCompressor::CompressBlob(
96+
data.c_str(), data.size(), compressedDataSize, 3);
97+
98+
ASSERT_NE(compressedData, nullptr);
99+
ASSERT_GT(compressedDataSize, (size_t)0);
100+
101+
size_t decompressedSize = 0;
102+
auto decompressedData = ZSTDCompressor::DecompressBlob(
103+
compressedData.get(), compressedDataSize, decompressedSize);
104+
105+
ASSERT_NE(decompressedData, nullptr);
106+
ASSERT_GT(decompressedSize, (size_t)0);
107+
108+
std::string decompressedStr((char *)decompressedData.get(),
109+
decompressedSize);
110+
ASSERT_EQ(data, decompressedStr);
111+
};
112+
113+
::ThreadPool MPool(ThreadCount, testCompressDecompress);
114+
}
115+
}

0 commit comments

Comments
 (0)