@@ -97,7 +97,7 @@ class matxInversePlan_t {
   * Inverse of A (if it exists)
   *
   */
-  matxInversePlan_t(TensorTypeAInv &a_inv, const TensorTypeA &a)
+  matxInversePlan_t(TensorTypeAInv &a_inv, const TensorTypeA &a, cudaStream_t stream)
   {
     static_assert(RANK >= 2);

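Note on the signature change above: taking a cudaStream_t in the constructor lets the plan make all of its allocations and copies stream-ordered rather than device-blocking. A rough sketch of the two patterns using plain CUDA runtime calls (matxAlloc with MATX_ASYNC_DEVICE_MEMORY is assumed to follow the second):

    #include <cuda_runtime.h>

    // Sketch only: blocking vs. stream-ordered allocation.
    void alloc_patterns(size_t bytes, cudaStream_t stream) {
      void *p1, *p2;
      cudaMalloc(&p1, bytes);               // may synchronize the whole device
      cudaMallocAsync(&p2, bytes, stream);  // ordered only within `stream`
      cudaFree(p1);
      cudaFreeAsync(p2, stream);            // freed once prior stream work is done
    }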
@@ -116,65 +116,79 @@ class matxInversePlan_t {
     ret = cublasCreate(&handle);
     MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxInverseError);

-    params = GetInverseParams(a_inv, a);
+    params = GetInverseParams(a_inv, a, stream);
+
+    // The cuBLAS getr*Batched LU decomposition functions overwrite the input, so
+    // we use a temporary buffer to store the inputs.
+    make_tensor(a_workbuf, a.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream);

     if constexpr (ALGO == MAT_INVERSE_ALGO_LU) {
       // cuBLAS requires a list of pointers to each matrix. Construct that list
       // here as our batch dims
       std::vector<const T1 *> in_pointers;
       std::vector<T1 *> out_pointers;
       if constexpr (RANK == 2) {
-        in_pointers.push_back(&a(0, 0));
+        in_pointers.push_back(&a_workbuf(0, 0));
         out_pointers.push_back(&a_inv(0, 0));
       }
       else {
-        using shape_type = typename TensorTypeA::desc_type::shape_type;
+        using ShapeTypeA = typename decltype(a_workbuf)::desc_type::shape_type;
+        using ShapeTypeAInv = typename TensorTypeAInv::desc_type::shape_type;
         int batch_offset = 2;
-        cuda::std::array<shape_type, TensorTypeA::Rank()> idx{0};
+        cuda::std::array<ShapeTypeA, TensorTypeA::Rank()> a_idx{0};
+        cuda::std::array<ShapeTypeAInv, TensorTypeAInv::Rank()> a_inv_idx{0};
         auto a_shape = a.Shape();
         // Get total number of batches
-        size_t total_iter = std::accumulate(a_shape.begin(), a_shape.begin() + TensorTypeA::Rank() - batch_offset, 1, std::multiplies<shape_type>());
+        size_t total_iter = std::accumulate(a_shape.begin(), a_shape.begin() + TensorTypeA::Rank() - batch_offset, 1, std::multiplies<ShapeTypeA>());
         for (size_t iter = 0; iter < total_iter; iter++) {
-          auto ip = cuda::std::apply([&a](auto... param) { return a.GetPointer(param...); }, idx);
-          auto op = cuda::std::apply([&a_inv](auto... param) { return a_inv.GetPointer(param...); }, idx);
-
+          auto ip = cuda::std::apply([&a_workbuf = a_workbuf](auto... param) { return a_workbuf.GetPointer(param...); }, a_idx);
           in_pointers.push_back(ip);
-          out_pointers.push_back(op);
-
           // Update all but the last 2 indices
-          UpdateIndices<TensorTypeA, shape_type, TensorTypeA::Rank()>(a, idx, batch_offset);
+          UpdateIndices<decltype(a_workbuf), ShapeTypeA, TensorTypeA::Rank()>(a_workbuf, a_idx, batch_offset);
+
+          auto op = cuda::std::apply([&a_inv](auto... param) { return a_inv.GetPointer(param...); }, a_inv_idx);
+          out_pointers.push_back(op);
+          UpdateIndices<TensorTypeAInv, ShapeTypeAInv, TensorTypeAInv::Rank()>(a_inv, a_inv_idx, batch_offset);
         }
       }

       // Allocate any workspace needed by inverse
       matxAlloc((void **)&d_A_array, in_pointers.size() * sizeof(T1 *),
-                MATX_DEVICE_MEMORY);
-      matxAlloc((void **)&d_A_inv_array, in_pointers.size() * sizeof(T1 *),
-                MATX_DEVICE_MEMORY);
+                MATX_ASYNC_DEVICE_MEMORY, stream);
+      matxAlloc((void **)&d_A_inv_array, out_pointers.size() * sizeof(T1 *),
+                MATX_ASYNC_DEVICE_MEMORY, stream);
       matxAlloc((void **)&d_pivot,
                 a.Size(RANK - 1) * in_pointers.size() * sizeof(*d_info),
-                MATX_DEVICE_MEMORY);
+                MATX_ASYNC_DEVICE_MEMORY, stream);
       matxAlloc((void **)&d_info, in_pointers.size() * sizeof(*d_info),
-                MATX_DEVICE_MEMORY);
-      cudaMemcpy(d_A_array, in_pointers.data(),
-                 in_pointers.size() * sizeof(T1 *), cudaMemcpyHostToDevice);
-      cudaMemcpy(d_A_inv_array, out_pointers.data(),
-                 out_pointers.size() * sizeof(T1 *), cudaMemcpyHostToDevice);
+                MATX_ASYNC_DEVICE_MEMORY, stream);
+      matxAlloc((void **)&h_info, in_pointers.size() * sizeof(*h_info),
+                MATX_HOST_MEMORY, stream);
+      cudaMemcpyAsync(d_A_array, in_pointers.data(),
+                      in_pointers.size() * sizeof(T1 *), cudaMemcpyHostToDevice, stream);
+      cudaMemcpyAsync(d_A_inv_array, out_pointers.data(),
+                      out_pointers.size() * sizeof(T1 *), cudaMemcpyHostToDevice, stream);
     }
     else {
       MATX_THROW(matxInvalidType, "Invalid inverse algorithm");
     }
   }

   static InverseParams_t GetInverseParams(TensorTypeAInv &a_inv,
-                                          const TensorTypeA &a)
+                                          const TensorTypeA &a,
+                                          cudaStream_t stream)
   {
     InverseParams_t params;
-    params.A = a.Data();
+    if constexpr (is_tensor_view_v<TensorTypeA>) {
+      params.A = a.Data();
+    } else {
+      params.A = nullptr;
+    }
     params.A_inv = a_inv.Data();
     params.algo = ALGO;
     params.n = a.Size(RANK - 1);
     params.dtype = TypeToInt<T1>();
+    params.stream = stream;

     if constexpr (ALGO == MAT_INVERSE_ALGO_LU) {
       if constexpr (RANK == 2) {
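The pointer-list construction above exists because the cuBLAS getrfBatched/getriBatched entry points take a device array of per-matrix pointers rather than a single base pointer. A minimal sketch of that layout for the common contiguous case, assuming a row-major [batch, n, n] buffer (the plan itself walks arbitrary batch dims via UpdateIndices):

    #include <cstddef>
    #include <vector>

    // Sketch: one pointer per n x n matrix in a contiguous batch.
    std::vector<float *> build_batch_pointers(float *base, int batch, int n) {
      std::vector<float *> ptrs;
      ptrs.reserve(batch);
      for (int b = 0; b < batch; b++) {
        ptrs.push_back(base + static_cast<size_t>(b) * n * n);  // start of matrix b
      }
      return ptrs;  // copied to the device with cudaMemcpyAsync before the cuBLAS call
    }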
@@ -201,6 +215,7 @@ class matxInversePlan_t {
     matxFree(d_A_inv_array, cudaStreamDefault);
     matxFree(d_pivot, cudaStreamDefault);
     matxFree(d_info, cudaStreamDefault);
+    matxFree(h_info);

     cublasDestroy(handle);
   }
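The destructor mirrors the allocation kinds: the device buffers keep their stream-ordered frees, while the new h_info buffer is host memory and is released without a stream. A tiny sketch of the pairing rule in plain CUDA terms, assuming MATX_HOST_MEMORY maps to a pinned host allocation:

    #include <cuda_runtime.h>

    // Sketch: frees must match how the memory was allocated.
    void release(int *d_info, int *h_info, cudaStream_t stream) {
      cudaFreeAsync(d_info, stream);  // device buffer from cudaMallocAsync
      cudaFreeHost(h_info);           // pinned host buffer from cudaMallocHost
    }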
@@ -219,12 +234,14 @@ class matxInversePlan_t {
    * CUDA stream
    *
    */
-  inline void Exec(cudaStream_t stream)
+  inline void Exec([[maybe_unused]] TensorTypeAInv &a_inv, const TensorTypeA &a, cudaStream_t stream)
   {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

     cublasSetStream(handle, stream);

+    // Refresh the work buffer so the in-place LU factorization below does not
+    // clobber the caller's input tensor
+    (a_workbuf = a).run(stream);
+
     if constexpr (ALGO == MAT_INVERSE_ALGO_LU) {
       if constexpr (std::is_same_v<T1, float>) {
         ret =
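`(a_workbuf = a).run(stream)` uses MatX's deferred-execution assignment: the expression builds a copy operation that is only launched on the given stream when run() is called, which is what refreshes the work buffer on every execution of the cached plan. The same idiom in isolation (tensor names illustrative; assumes a valid cudaStream_t `stream`):

    // Sketch: lazy assignment in MatX; nothing executes until run().
    auto src = matx::make_tensor<float>({16, 16});
    auto dst = matx::make_tensor<float>({16, 16});
    (dst = src).run(stream);  // device-side copy enqueued on `stream`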
@@ -242,7 +259,6 @@ class matxInversePlan_t {
             reinterpret_cast<cuComplex *const *>(d_A_array),
             static_cast<int>(params.n), d_pivot, d_info,
             static_cast<int>(params.batch_size));
-        cudaDeviceSynchronize();
       }
       else if constexpr (std::is_same_v<T1, cuda::std::complex<double>>) {
         ret = cublasZgetrfBatched(
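Dropping the cudaDeviceSynchronize() after the factorization is safe because everything here is issued to the same stream: the getri call below cannot start until getrf finishes, and a device-wide sync would needlessly stall unrelated streams. A sketch of that ordering guarantee, with cudaMemsetAsync standing in for the two cuBLAS calls:

    #include <cuda_runtime.h>

    // Sketch: work queued on one stream executes in issue order.
    void ordered_on_one_stream(int *buf, size_t n, cudaStream_t stream) {
      cudaMemsetAsync(buf, 0, n, stream);  // stand-in for the getrf launch
      cudaMemsetAsync(buf, 1, n, stream);  // stand-in for getri; runs after the first
      cudaStreamSynchronize(stream);       // only needed when the host reads results
    }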
@@ -254,10 +270,13 @@ class matxInversePlan_t {

       MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxLUError);

-      int h_info = 0;
-      cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost);
-
-      MATX_ASSERT(h_info == 0, matxLUError);
+      cudaMemcpyAsync(h_info, d_info, sizeof(int) * params.batch_size, cudaMemcpyDeviceToHost, stream);
+      cudaStreamSynchronize(stream);
+      for (size_t i = 0; i < params.batch_size; i++) {
+        if (h_info[i] != 0) {
+          MATX_THROW(matxLUError, "LU factorization failed");
+        }
+      }

       if constexpr (std::is_same_v<T1, float>) {
         ret = cublasSgetriBatched(handle, static_cast<int>(params.n), d_A_array,
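Checking every entry of d_info, instead of only the first, reports a singular matrix anywhere in the batch; and because the copy back is now asynchronous, the stream must be synchronized before the host may read h_info. The same pattern in a standalone sketch (names and the exception type are illustrative):

    #include <cuda_runtime.h>
    #include <stdexcept>
    #include <string>

    // Sketch: async D2H copy of per-batch status, sync, then host-side check.
    void check_batch_info(const int *d_info, int *h_info, int batch, cudaStream_t stream) {
      cudaMemcpyAsync(h_info, d_info, sizeof(int) * batch, cudaMemcpyDeviceToHost, stream);
      cudaStreamSynchronize(stream);  // h_info is not valid until the copy lands
      for (int i = 0; i < batch; i++) {
        if (h_info[i] != 0) {
          throw std::runtime_error("matrix " + std::to_string(i) + " is singular");
        }
      }
    }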
@@ -292,8 +311,13 @@ class matxInversePlan_t {

       MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxInverseError);

-      cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost);
-      MATX_ASSERT(h_info == 0, matxInverseError);
+      cudaMemcpyAsync(h_info, d_info, sizeof(int) * params.batch_size, cudaMemcpyDeviceToHost, stream);
+      cudaStreamSynchronize(stream);
+      for (size_t i = 0; i < params.batch_size; i++) {
+        if (h_info[i] != 0) {
+          MATX_THROW(matxInverseError, "inverse failed");
+        }
+      }
     }
   }

@@ -303,8 +327,10 @@ class matxInversePlan_t {

   InverseParams_t params;
   cublasHandle_t handle;
+  matx::tensor_t<typename TensorTypeA::value_type, TensorTypeA::Rank()> a_workbuf;
   int *d_pivot;
   int *d_info;
+  int *h_info;
   T1 **d_A_array;
   T1 **d_A_inv_array;
 };
@@ -360,18 +386,17 @@ void inv_impl(TensorTypeAInv &a_inv, const TensorTypeA &a,
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   static_assert(TensorTypeAInv::Rank() == TensorTypeA::Rank(), "Input and output ranks must match");
   // Get parameters required by these tensors
-  auto params = detail::matxInversePlan_t<TensorTypeAInv, TensorTypeA, ALGO>::GetInverseParams(a_inv, a);
-  params.stream = stream;
+  auto params = detail::matxInversePlan_t<TensorTypeAInv, TensorTypeA, ALGO>::GetInverseParams(a_inv, a, stream);

   using cache_val_type = detail::matxInversePlan_t<TensorTypeAInv, TensorTypeA, ALGO>;
   detail::GetCache().LookupAndExec<detail::inv_cache_t>(
     detail::GetCacheIdFromType<detail::inv_cache_t>(),
     params,
     [&]() {
-      return std::make_shared<cache_val_type>(a_inv, a);
+      return std::make_shared<cache_val_type>(a_inv, a, stream);
     },
     [&](std::shared_ptr<cache_val_type> ctype) {
-      ctype->Exec(stream);
+      ctype->Exec(a_inv, a, stream);
     }
   );
 }
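End-to-end, a caller now passes the stream once and both the cache parameters and the plan pick it up. A hypothetical call-site sketch (tensor shapes and names are illustrative; inv_impl is used as declared in the hunk above):

    // Sketch: invert a batch of eight 4x4 matrices on a user stream.
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    auto A    = matx::make_tensor<float>({8, 4, 4});
    auto Ainv = matx::make_tensor<float>({8, 4, 4});
    // ... fill A ...
    matx::inv_impl(Ainv, A, stream);
    cudaStreamSynchronize(stream);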