Skip to content

Commit 1c15c44

Browse files
committed
Add mutex around cache lookup
1 parent f753513 commit 1c15c44

File tree

7 files changed

+36
-26
lines changed

7 files changed

+36
-26
lines changed

examples/fft_conv.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
7373
{
7474
MATX_ENTER_HANDLER();
7575
using complex = cuda::std::complex<float>;
76-
cudaExecutor exec{};
7776

7877
index_t signal_size = 1ULL << 16;
7978
index_t filter_size = 16;
@@ -87,6 +86,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
8786
cudaEvent_t start, stop;
8887
cudaEventCreate(&start);
8988
cudaEventCreate(&stop);
89+
cudaExecutor exec{stream};
9090

9191
// Create time domain buffers
9292
auto sig_time = make_tensor<complex>({batches, signal_size});

include/matx/core/cache.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include <functional>
3737
#include <optional>
3838
#include <any>
39+
#include <shared_mutex>
3940
#include <unordered_map>
4041
#include <cuda/atomic>
4142

@@ -50,6 +51,7 @@ using CacheId = uint64_t;
5051
__attribute__ ((visibility ("default")))
5152
#endif
5253
inline cuda::std::atomic<CacheId> CacheIdCounter{0};
54+
inline std::shared_mutex cache_mtx; ///< Mutex protecting updates to the cache map
5355

5456
template<typename CacheType>
5557
__attribute__ ((visibility ("default")))
@@ -83,6 +85,8 @@ class matxCache_t {
8385
*/
8486
template <typename CacheType>
8587
void Clear(const CacheId &id) {
88+
[[maybe_unused]] std::unique_lock lck(cache_mtx);
89+
8690
auto el = cache.find(id);
8791
MATX_ASSERT_STR(el != cache.end(), matxInvalidType, "Cache type not found");
8892

@@ -91,6 +95,8 @@ class matxCache_t {
9195

9296
template <typename CacheType, typename InParams, typename MakeFun, typename ExecFun>
9397
void LookupAndExec(const CacheId &id, const InParams &params, const MakeFun &mfun, const ExecFun &efun) {
98+
[[maybe_unused]] std::unique_lock lck(cache_mtx);
99+
94100
// Create named cache if it doesn't exist
95101
auto el = cache.find(id);
96102
if (el == cache.end()) {

include/matx/transforms/chol/chol_cuda.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ class matxDnCholCUDAPlan_t : matxDnCUDASolver_t {
103103
MATX_STATIC_ASSERT_STR(!is_half_v<T1>, matxInvalidType, "Cholesky solver does not support half precision");
104104
MATX_STATIC_ASSERT_STR((std::is_same_v<T1, typename OutTensor_t::value_type>), matxInavlidType, "Input and Output types must match");
105105

106-
params = GetCholParams(a, uplo);
107-
params.exec = exec;
106+
params = GetCholParams(a, uplo, exec);
107+
108108
this->GetWorkspaceSize();
109109
this->AllocateWorkspace(params.batch_size, false, exec);
110110
}
@@ -120,13 +120,15 @@ class matxDnCholCUDAPlan_t : matxDnCUDASolver_t {
120120
}
121121

122122
static DnCholCUDAParams_t GetCholParams(const ATensor &a,
123-
cublasFillMode_t uplo)
123+
cublasFillMode_t uplo,
124+
const cudaExecutor &exec)
124125
{
125126
DnCholCUDAParams_t params;
126127
params.batch_size = GetNumBatches(a);
127128
params.n = a.Size(RANK - 1);
128129
params.A = a.Data();
129130
params.uplo = uplo;
131+
params.exec = exec;
130132
params.dtype = TypeToInt<T1>();
131133

132134
return params;
@@ -298,7 +300,7 @@ void chol_impl(OutputTensor &&out, const ATensor &a,
298300
cublasFillMode_t uplo_cusolver = (uplo == SolverFillMode::UPPER)? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
299301

300302
// Get parameters required by these tensors
301-
auto params = detail::matxDnCholCUDAPlan_t<OutputTensor, decltype(tmp_out)>::GetCholParams(tmp_out, uplo_cusolver);
303+
auto params = detail::matxDnCholCUDAPlan_t<OutputTensor, decltype(tmp_out)>::GetCholParams(tmp_out, uplo_cusolver, exec);
302304

303305
using cache_val_type = detail::matxDnCholCUDAPlan_t<OutputTensor, decltype(tmp_out)>;
304306
detail::GetCache().LookupAndExec<detail::chol_cuda_cache_t>(

include/matx/transforms/eig/eig_cuda.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,7 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
115115
MATX_STATIC_ASSERT_STR(!is_complex_v<T2>, matxInvalidType, "W type must be real");
116116
MATX_STATIC_ASSERT_STR((std::is_same_v<typename inner_op_type_t<T1>::type, T2>), matxInvalidType, "Out and W inner types must match");
117117

118-
params = GetEigParams(w, a, jobz, uplo);
119-
params.exec = exec;
118+
params = GetEigParams(w, a, jobz, uplo, exec);
120119
this->GetWorkspaceSize();
121120
#if CUSOLVER_VERSION > 11701 || (CUSOLVER_VERSION == 11701 && CUSOLVER_VER_BUILD >= 2)
122121
this->AllocateWorkspace(params.batch_size, true, exec);
@@ -150,7 +149,8 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
150149
static DnEigCUDAParams_t GetEigParams(WTensor &w,
151150
const ATensor &a,
152151
cusolverEigMode_t jobz,
153-
cublasFillMode_t uplo)
152+
cublasFillMode_t uplo,
153+
const cudaExecutor &exec)
154154
{
155155
DnEigCUDAParams_t params;
156156
params.batch_size = GetNumBatches(a);
@@ -159,6 +159,8 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
159159
params.W = w.Data();
160160
params.jobz = jobz;
161161
params.uplo = uplo;
162+
params.exec = exec;
163+
162164
params.dtype = TypeToInt<T1>();
163165

164166
return params;
@@ -342,7 +344,7 @@ void eig_impl(OutputTensor &&out, WTensor &&w,
342344

343345
// Get parameters required by these tensors
344346
auto params = detail::matxDnEigCUDAPlan_t<OutputTensor, decltype(w_new), decltype(a_new)>::
345-
GetEigParams(w_new, tv, jobz_cusolver, uplo_cusolver);
347+
GetEigParams(w_new, tv, jobz_cusolver, uplo_cusolver, exec);
346348

347349
// Get cache or new eigen plan if it doesn't exist
348350
using cache_val_type = detail::matxDnEigCUDAPlan_t<OutputTensor, decltype(w_new), decltype(a_new)>;

include/matx/transforms/lu/lu_cuda.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ class matxDnLUCUDAPlan_t : matxDnCUDASolver_t {
106106
MATX_STATIC_ASSERT_STR((std::is_same_v<T1, typename OutTensor_t::value_type>), matxInavlidType, "Input and Output types must match");
107107
MATX_STATIC_ASSERT_STR((std::is_same_v<T2, int64_t>), matxInavlidType, "Pivot tensor type must be int64_t");
108108

109-
params = GetLUParams(piv, a);
110-
params.exec = exec;
109+
params = GetLUParams(piv, a, exec);
111110
this->GetWorkspaceSize();
112111
this->AllocateWorkspace(params.batch_size, false, exec);
113112
}
@@ -123,7 +122,8 @@ class matxDnLUCUDAPlan_t : matxDnCUDASolver_t {
123122
}
124123

125124
static DnLUCUDAParams_t GetLUParams(PivotTensor &piv,
126-
const ATensor &a) noexcept
125+
const ATensor &a,
126+
const cudaExecutor &exec) noexcept
127127
{
128128
DnLUCUDAParams_t params;
129129
params.batch_size = GetNumBatches(a);
@@ -132,7 +132,7 @@ class matxDnLUCUDAPlan_t : matxDnCUDASolver_t {
132132
params.A = a.Data();
133133
params.piv = piv.Data();
134134
params.dtype = TypeToInt<T1>();
135-
135+
params.exec = exec;
136136
return params;
137137
}
138138

@@ -287,7 +287,7 @@ void lu_impl(OutputTensor &&out, PivotTensor &&piv,
287287
auto tvt = tv.PermuteMatrix();
288288

289289
// Get parameters required by these tensors
290-
auto params = detail::matxDnLUCUDAPlan_t<OutputTensor, decltype(piv_new), decltype(a_new)>::GetLUParams(piv_new, tvt);
290+
auto params = detail::matxDnLUCUDAPlan_t<OutputTensor, decltype(piv_new), decltype(a_new)>::GetLUParams(piv_new, tvt, exec);
291291

292292
// Get cache or new LU plan if it doesn't exist
293293
using cache_val_type = detail::matxDnLUCUDAPlan_t<OutputTensor, decltype(piv_new), decltype(a_new)>;

include/matx/transforms/qr/qr_cuda.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,7 @@ class matxDnQRCUDAPlan_t : matxDnCUDASolver_t {
295295
MATX_STATIC_ASSERT_STR((std::is_same_v<T1, typename OutTensor_t::value_type>), matxInavlidType, "Input and Output types must match");
296296
MATX_STATIC_ASSERT_STR((std::is_same_v<T1, T2>), matxInavlidType, "A and Tau types must match");
297297

298-
params = GetQRParams(tau, a);
299-
params.exec = exec;
298+
params = GetQRParams(tau, a, exec);
300299
this->GetWorkspaceSize();
301300
this->AllocateWorkspace(params.batch_size, false, exec);
302301
}
@@ -311,7 +310,8 @@ class matxDnQRCUDAPlan_t : matxDnCUDASolver_t {
311310
}
312311

313312
static DnQRCUDAParams_t GetQRParams(TauTensor &tau,
314-
const ATensor &a)
313+
const ATensor &a,
314+
const cudaExecutor &exec)
315315
{
316316
DnQRCUDAParams_t params;
317317

@@ -321,7 +321,7 @@ class matxDnQRCUDAPlan_t : matxDnCUDASolver_t {
321321
params.A = a.Data();
322322
params.tau = tau.Data();
323323
params.dtype = TypeToInt<T1>();
324-
324+
params.exec = exec;
325325
return params;
326326
}
327327

@@ -468,7 +468,7 @@ void qr_solver_impl(OutTensor &&out, TauTensor &&tau,
468468
auto tvt = tv.PermuteMatrix();
469469

470470
// Get parameters required by these tensors
471-
auto params = detail::matxDnQRCUDAPlan_t<OutTensor, decltype(tau_new), decltype(a_new)>::GetQRParams(tau_new, tvt);
471+
auto params = detail::matxDnQRCUDAPlan_t<OutTensor, decltype(tau_new), decltype(a_new)>::GetQRParams(tau_new, tvt, exec);
472472

473473
// Get cache or new QR plan if it doesn't exist
474474
using cache_val_type = detail::matxDnQRCUDAPlan_t<OutTensor, decltype(tau_new), decltype(a_new)>;

include/matx/transforms/svd/svd_cuda.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -644,8 +644,7 @@ class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t {
644644
MATX_STATIC_ASSERT_STR(!is_complex_v<T3>, matxInvalidType, "S type must be real");
645645
MATX_STATIC_ASSERT_STR((std::is_same_v<typename inner_op_type_t<T1>::type, T3>), matxInvalidType, "A and S inner types must match");
646646

647-
params = GetSVDParams(u, s, vt, a, jobz);
648-
params.exec = exec;
647+
params = GetSVDParams(u, s, vt, a, jobz, exec);
649648
params.method = method;
650649

651650
if (params.method == SVDMethod::GESVDJ_BATCHED) {
@@ -725,8 +724,8 @@ class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t {
725724

726725
static DnSVDCUDAParams_t
727726
GetSVDParams(UTensor &u, STensor &s,
728-
VtTensor &vt, const ATensor &a,
729-
const char jobz = 'A')
727+
VtTensor &vt, const ATensor &a,
728+
const char jobz, const cudaExecutor &exec)
730729
{
731730
DnSVDCUDAParams_t params;
732731
params.batch_size = GetNumBatches(a);
@@ -738,6 +737,7 @@ class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t {
738737
params.S = s.Data();
739738
params.jobz = jobz;
740739
params.dtype = TypeToInt<T1>();
740+
params.exec = exec;
741741
return params;
742742
}
743743

@@ -997,7 +997,7 @@ void svd_impl(UTensor &&u, STensor &&s,
997997

998998
// Get parameters required by these tensors
999999
auto params = detail::matxDnSVDCUDAPlan_t<decltype(u_in), decltype(s_new), decltype(vt_in), decltype(at_col_maj)>::
1000-
GetSVDParams(u_in, s_new, vt_in, at_col_maj, job_cusolver);
1000+
GetSVDParams(u_in, s_new, vt_in, at_col_maj, job_cusolver, exec);
10011001

10021002
// Get cache or new SVD plan if it doesn't exist
10031003
using cache_val_type = detail::matxDnSVDCUDAPlan_t<decltype(u_in), decltype(s_new), decltype(vt_in), decltype(at_col_maj)>;
@@ -1034,15 +1034,15 @@ void svd_impl(UTensor &&u, STensor &&s,
10341034

10351035
// Get parameters required by these tensors
10361036
auto params = detail::matxDnSVDCUDAPlan_t<decltype(u_col_maj), decltype(s_new), decltype(vt_col_maj), decltype(tvt)>::
1037-
GetSVDParams(u_col_maj, s_new, vt_col_maj, tvt, job_cusolver);
1037+
GetSVDParams(u_col_maj, s_new, vt_col_maj, tvt, job_cusolver, exec);
10381038

10391039
// Get cache or new SVD plan if it doesn't exist
10401040
using cache_val_type = detail::matxDnSVDCUDAPlan_t<decltype(u_col_maj), decltype(s_new), decltype(vt_col_maj), decltype(tvt)>;
10411041
detail::GetCache().LookupAndExec<detail::svd_cuda_cache_t>(
10421042
detail::GetCacheIdFromType<detail::svd_cuda_cache_t>(),
10431043
params,
10441044
[&]() {
1045-
return std::make_shared<cache_val_type>(u_col_maj, s_new, vt_col_maj, tvt, method, stream, job_cusolver);
1045+
return std::make_shared<cache_val_type>(u_col_maj, s_new, vt_col_maj, tvt, method, exec, job_cusolver);
10461046
},
10471047
[&](std::shared_ptr<cache_val_type> ctype) {
10481048
ctype->Exec(u_col_maj, s_new, vt_col_maj, tvt, exec, job_cusolver);

0 commit comments

Comments (0)