
Commit 3df9230

Add in-place transform support for inv()
Change the behavior of the inv() transform as follows:

- No longer unconditionally overwrite the input with factorized data. Previously, (Ainv = inv(A)).run() would write the inverse to Ainv and the LU factorization to A.
- Support in-place transforms like (A = inv(A)).run(). Previously, this would run, but the results would be incorrect because the underlying cuBLAS calls only support out-of-place solves.
- The above are achieved by always creating a temporary work buffer and copying the input into that work buffer.
- Add support for input operators (i.e., not just tensors). The operator runs when populating the temporary input work buffer.
1 parent 459cffb commit 3df9230
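
A minimal usage sketch of what this change enables, assuming 4x4 float tensors named A and Ainv, the default stream, and an arbitrary operator expression (A * 2.0f) for the operator-input case; these names and sizes are illustrative and not taken from the commit:

#include "matx.h"

using namespace matx;

int main() {
  cudaStream_t stream = 0;  // default stream for simplicity

  auto A    = make_tensor<float>({4, 4});
  auto Ainv = make_tensor<float>({4, 4});
  // ... fill A with an invertible matrix ...

  // Out-of-place: A is no longer overwritten with its LU factorization.
  (Ainv = inv(A)).run(stream);

  // In-place: previously produced incorrect results; now supported because
  // the input is first copied into a temporary work buffer.
  (A = inv(A)).run(stream);

  // Operator input (not just a tensor): the expression is evaluated while
  // the temporary input work buffer is populated.
  (Ainv = inv(A * 2.0f)).run(stream);

  cudaStreamSynchronize(stream);
  return 0;
}

The first two forms are the cases called out in the commit message; the operator-input form exercises the new path where the expression is evaluated into the temporary work buffer before the cuBLAS calls run.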

2 files changed: +186 −41 lines changed

include/matx/transforms/inverse.h

Lines changed: 60 additions & 35 deletions
@@ -97,7 +97,7 @@ class matxInversePlan_t {
    * Inverse of A (if it exists)
    *
    */
-  matxInversePlan_t(TensorTypeAInv &a_inv, const TensorTypeA &a)
+  matxInversePlan_t(TensorTypeAInv &a_inv, const TensorTypeA &a, cudaStream_t stream)
   {
     static_assert(RANK >= 2);
 
@@ -116,65 +116,79 @@ class matxInversePlan_t {
     ret = cublasCreate(&handle);
     MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxInverseError);
 
-    params = GetInverseParams(a_inv, a);
+    params = GetInverseParams(a_inv, a, stream);
+
+    // The cuBLAS getr*Batched LU decomposition functions overwrite the input, so
+    // we use a temporary buffer to store the inputs.
+    make_tensor(a_workbuf, a.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream);
 
     if constexpr (ALGO == MAT_INVERSE_ALGO_LU) {
       // cuBLAS requires a list of pointers to each matrix. Construct that list
       // here as our batch dims
       std::vector<const T1 *> in_pointers;
       std::vector<T1 *> out_pointers;
       if constexpr (RANK == 2) {
-        in_pointers.push_back(&a(0, 0));
+        in_pointers.push_back(&a_workbuf(0, 0));
         out_pointers.push_back(&a_inv(0, 0));
       }
       else {
-        using shape_type = typename TensorTypeA::desc_type::shape_type;
+        using ShapeTypeA = typename decltype(a_workbuf)::desc_type::shape_type;
+        using ShapeTypeAInv = typename TensorTypeAInv::desc_type::shape_type;
         int batch_offset = 2;
-        cuda::std::array<shape_type, TensorTypeA::Rank()> idx{0};
+        cuda::std::array<ShapeTypeA, TensorTypeA::Rank()> a_idx{0};
+        cuda::std::array<ShapeTypeAInv, TensorTypeAInv::Rank()> a_inv_idx{0};
         auto a_shape = a.Shape();
         // Get total number of batches
-        size_t total_iter = std::accumulate(a_shape.begin(), a_shape.begin() + TensorTypeA::Rank() - batch_offset, 1, std::multiplies<shape_type>());
+        size_t total_iter = std::accumulate(a_shape.begin(), a_shape.begin() + TensorTypeA::Rank() - batch_offset, 1, std::multiplies<ShapeTypeA>());
         for (size_t iter = 0; iter < total_iter; iter++) {
-          auto ip = cuda::std::apply([&a](auto... param) { return a.GetPointer(param...); }, idx);
-          auto op = cuda::std::apply([&a_inv](auto... param) { return a_inv.GetPointer(param...); }, idx);
-
+          auto ip = cuda::std::apply([&a_workbuf = a_workbuf](auto... param) { return a_workbuf.GetPointer(param...); }, a_idx);
           in_pointers.push_back(ip);
-          out_pointers.push_back(op);
-
           // Update all but the last 2 indices
-          UpdateIndices<TensorTypeA, shape_type, TensorTypeA::Rank()>(a, idx, batch_offset);
+          UpdateIndices<decltype(a_workbuf), ShapeTypeA, TensorTypeA::Rank()>(a_workbuf, a_idx, batch_offset);
+
+          auto op = cuda::std::apply([&a_inv](auto... param) { return a_inv.GetPointer(param...); }, a_inv_idx);
+          out_pointers.push_back(op);
+          UpdateIndices<TensorTypeAInv, ShapeTypeAInv, TensorTypeAInv::Rank()>(a_inv, a_inv_idx, batch_offset);
        }
      }
 
       // Allocate any workspace needed by inverse
       matxAlloc((void **)&d_A_array, in_pointers.size() * sizeof(T1 *),
-                MATX_DEVICE_MEMORY);
-      matxAlloc((void **)&d_A_inv_array, in_pointers.size() * sizeof(T1 *),
-                MATX_DEVICE_MEMORY);
+                MATX_ASYNC_DEVICE_MEMORY, stream);
+      matxAlloc((void **)&d_A_inv_array, out_pointers.size() * sizeof(T1 *),
+                MATX_ASYNC_DEVICE_MEMORY, stream);
       matxAlloc((void **)&d_pivot,
                 a.Size(RANK - 1) * in_pointers.size() * sizeof(*d_info),
-                MATX_DEVICE_MEMORY);
+                MATX_ASYNC_DEVICE_MEMORY, stream);
       matxAlloc((void **)&d_info, in_pointers.size() * sizeof(*d_info),
-                MATX_DEVICE_MEMORY);
-      cudaMemcpy(d_A_array, in_pointers.data(),
-                 in_pointers.size() * sizeof(T1 *), cudaMemcpyHostToDevice);
-      cudaMemcpy(d_A_inv_array, out_pointers.data(),
-                 out_pointers.size() * sizeof(T1 *), cudaMemcpyHostToDevice);
+                MATX_ASYNC_DEVICE_MEMORY, stream);
+      matxAlloc((void **)&h_info, in_pointers.size() * sizeof(*h_info),
+                MATX_HOST_MEMORY, stream);
+      cudaMemcpyAsync(d_A_array, in_pointers.data(),
+                      in_pointers.size() * sizeof(T1 *), cudaMemcpyHostToDevice, stream);
+      cudaMemcpyAsync(d_A_inv_array, out_pointers.data(),
+                      out_pointers.size() * sizeof(T1 *), cudaMemcpyHostToDevice, stream);
     }
     else {
       MATX_THROW(matxInvalidType, "Invalid inverse algorithm");
     }
   }
 
   static InverseParams_t GetInverseParams(TensorTypeAInv &a_inv,
-                                          const TensorTypeA &a)
+                                          const TensorTypeA &a,
+                                          cudaStream_t stream)
   {
     InverseParams_t params;
-    params.A = a.Data();
+    if constexpr (is_tensor_view_v<TensorTypeA>) {
+      params.A = a.Data();
+    } else {
+      params.A = nullptr;
+    }
     params.A_inv = a_inv.Data();
     params.algo = ALGO;
     params.n = a.Size(RANK - 1);
     params.dtype = TypeToInt<T1>();
+    params.stream = stream;
 
     if constexpr (ALGO == MAT_INVERSE_ALGO_LU) {
       if constexpr (RANK == 2) {
@@ -201,6 +215,7 @@ class matxInversePlan_t {
     matxFree(d_A_inv_array, cudaStreamDefault);
     matxFree(d_pivot, cudaStreamDefault);
     matxFree(d_info, cudaStreamDefault);
+    matxFree(h_info);
 
     cublasDestroy(handle);
   }
@@ -219,12 +234,14 @@ class matxInversePlan_t {
    * CUDA stream
    *
    */
-  inline void Exec(cudaStream_t stream)
+  inline void Exec([[maybe_unused]] TensorTypeAInv &a_inv, const TensorTypeA &a, cudaStream_t stream)
   {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
 
     cublasSetStream(handle, stream);
 
+    (a_workbuf = a).run(stream);
+
     if constexpr (ALGO == MAT_INVERSE_ALGO_LU) {
       if constexpr (std::is_same_v<T1, float>) {
         ret =
@@ -242,7 +259,6 @@ class matxInversePlan_t {
             reinterpret_cast<cuComplex *const *>(d_A_array),
             static_cast<int>(params.n), d_pivot, d_info,
             static_cast<int>(params.batch_size));
-        cudaDeviceSynchronize();
       }
       else if constexpr (std::is_same_v<T1, cuda::std::complex<double>>) {
         ret = cublasZgetrfBatched(
@@ -254,10 +270,13 @@ class matxInversePlan_t {
 
       MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxLUError);
 
-      int h_info = 0;
-      cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost);
-
-      MATX_ASSERT(h_info == 0, matxLUError);
+      cudaMemcpyAsync(h_info, d_info, sizeof(int) * params.batch_size, cudaMemcpyDeviceToHost, stream);
+      cudaStreamSynchronize(stream);
+      for (size_t i = 0; i < params.batch_size; i++) {
+        if (h_info[i] != 0) {
+          MATX_THROW(matxLUError, "inverse failed");
+        }
+      }
 
       if constexpr (std::is_same_v<T1, float>) {
         ret = cublasSgetriBatched(handle, static_cast<int>(params.n), d_A_array,
@@ -292,8 +311,13 @@ class matxInversePlan_t {
 
       MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxInverseError);
 
-      cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost);
-      MATX_ASSERT(h_info == 0, matxInverseError);
+      cudaMemcpyAsync(h_info, d_info, sizeof(int) * params.batch_size, cudaMemcpyDeviceToHost, stream);
+      cudaStreamSynchronize(stream);
+      for (size_t i = 0; i < params.batch_size; i++) {
+        if (h_info[i] != 0) {
+          MATX_THROW(matxLUError, "inverse failed");
+        }
+      }
     }
   }
 
@@ -303,8 +327,10 @@ class matxInversePlan_t {
 
   InverseParams_t params;
   cublasHandle_t handle;
+  matx::tensor_t<typename TensorTypeA::value_type, TensorTypeA::Rank()> a_workbuf;
   int *d_pivot;
   int *d_info;
+  int *h_info;
   T1 **d_A_array;
   T1 **d_A_inv_array;
 };
@@ -360,18 +386,17 @@ void inv_impl(TensorTypeAInv &a_inv, const TensorTypeA &a,
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   static_assert(TensorTypeAInv::Rank() == TensorTypeA::Rank(), "Input and output ranks must match");
   // Get parameters required by these tensors
-  auto params = detail::matxInversePlan_t<TensorTypeAInv, TensorTypeA, ALGO>::GetInverseParams(a_inv, a);
-  params.stream = stream;
+  auto params = detail::matxInversePlan_t<TensorTypeAInv, TensorTypeA, ALGO>::GetInverseParams(a_inv, a, stream);
 
   using cache_val_type = detail::matxInversePlan_t<TensorTypeAInv, TensorTypeA, ALGO>;
   detail::GetCache().LookupAndExec<detail::inv_cache_t>(
     detail::GetCacheIdFromType<detail::inv_cache_t>(),
     params,
     [&]() {
-      return std::make_shared<cache_val_type>(a_inv, a);
+      return std::make_shared<cache_val_type>(a_inv, a, stream);
     },
     [&](std::shared_ptr<cache_val_type> ctype) {
-      ctype->Exec(stream);
+      ctype->Exec(a_inv, a, stream);
     }
   );
 }
