Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/matx/operators/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ namespace matx
}

template <typename ShapeType, typename Executor>
__MATX_INLINE__ void PostRun(ShapeType &&shape, Executor &&ex) noexcept
__MATX_INLINE__ void PostRun(ShapeType &&shape, Executor &&ex) const noexcept
{
if constexpr (is_matx_op<OpA>()) {
a_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
Expand Down
2 changes: 1 addition & 1 deletion include/matx/transforms/chol/chol_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class matxDnCholCUDAPlan_t : matxDnCUDASolver_t {

params = GetCholParams(a, uplo);
this->GetWorkspaceSize();
this->AllocateWorkspace(params.batch_size);
this->AllocateWorkspace(params.batch_size, false);
}

void GetWorkspaceSize() override
Expand Down
45 changes: 37 additions & 8 deletions include/matx/transforms/eig/eig_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,31 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {

params = GetEigParams(w, a, jobz, uplo);
this->GetWorkspaceSize();
this->AllocateWorkspace(params.batch_size);
#if CUSOLVER_VERSION > 11701 || (CUSOLVER_VERSION == 11701 && CUSOLVER_VER_BUILD >=2)
this->AllocateWorkspace(params.batch_size, true);
#else
this->AllocateWorkspace(params.batch_size, false);
#endif
}

// Queries cuSolver for the device (dspace) and host (hspace) workspace byte
// counts required by the symmetric/Hermitian eigensolver and stores them in
// the base-class members via the output pointers.
void GetWorkspaceSize() override
{
#if CUSOLVER_VERSION > 11701 || (CUSOLVER_VERSION == 11701 && CUSOLVER_VER_BUILD >=2)
// Use vector mode for a larger workspace size that works for both modes
cusolverStatus_t ret = cusolverDnXsyevBatched_bufferSize(
this->handle, this->dn_params, CUSOLVER_EIG_MODE_VECTOR,
params.uplo, params.n, MatXTypeToCudaType<T1>(), params.A,
params.n, MatXTypeToCudaType<T2>(), params.W,
MatXTypeToCudaType<T1>(), &this->dspace,
&this->hspace, params.batch_size);
#else
// Older cuSolver has no batched syev interface: size the workspace for a
// single matrix (AllocateWorkspace then multiplies by the batch count for
// the per-entry loop over cusolverDnXsyevd).
cusolverStatus_t ret = cusolverDnXsyevd_bufferSize(
this->handle, this->dn_params, CUSOLVER_EIG_MODE_VECTOR,
params.uplo, params.n, MatXTypeToCudaType<T1>(), params.A,
params.n, MatXTypeToCudaType<T2>(), params.W,
MatXTypeToCudaType<T1>(), &this->dspace,
&this->hspace);
#endif

MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError);
}
Expand Down Expand Up @@ -166,18 +179,32 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
}
}

SetBatchPointers<BatchType::MATRIX>(out, this->batch_a_ptrs);
SetBatchPointers<BatchType::VECTOR>(w, this->batch_w_ptrs);

if (out.Data() != a.Data()) {
(out = a).run(exec);
}

const auto stream = exec.getStream();
cusolverDnSetStream(this->handle, stream);

// Older cuSolver versions do not have a batched 64-bit syev interface. Use
// cusolverDnXsyevBatched when available; otherwise fall back to looping
// cusolverDnXsyevd over each batch entry below.
#if CUSOLVER_VERSION > 11701 || ( CUSOLVER_VERSION == 11701 && CUSOLVER_VER_BUILD >=2)
auto ret = cusolverDnXsyevBatched(
this->handle, this->dn_params, jobz, uplo, params.n, MatXTypeToCudaType<T1>(),
out.Data(), params.n, MatXTypeToCudaType<T2>(), w.Data(),
MatXTypeToCudaType<T1>(),
reinterpret_cast<uint8_t *>(this->d_workspace), this->dspace,
reinterpret_cast<uint8_t *>(this->h_workspace), this->hspace,
this->d_info, params.batch_size);

MATX_ASSERT_STR_EXP(ret, CUSOLVER_STATUS_SUCCESS, matxSolverError,
("cusolverDnXsyevBatched failed with error " + std::to_string(ret)).c_str());

std::vector<int> h_info(params.batch_size);
cudaMemcpyAsync(h_info.data(), this->d_info, sizeof(int) * params.batch_size, cudaMemcpyDeviceToHost, stream);
#else
SetBatchPointers<BatchType::MATRIX>(out, this->batch_a_ptrs);
SetBatchPointers<BatchType::VECTOR>(w, this->batch_w_ptrs);

// Older cuSolver versions do not support batching with cusolverDnXsyevd
for (size_t i = 0; i < this->batch_a_ptrs.size(); i++) {
auto ret = cusolverDnXsyevd(
this->handle, this->dn_params, jobz, uplo, params.n, MatXTypeToCudaType<T1>(),
Expand All @@ -187,11 +214,13 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
reinterpret_cast<uint8_t *>(this->h_workspace) + i * this->hspace, this->hspace,
this->d_info + i);

MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError);
MATX_ASSERT_STR_EXP(ret, CUSOLVER_STATUS_SUCCESS, matxSolverError,
("cusolverDnXsyevBatched failed with error " + std::to_string(ret)).c_str());
}

std::vector<int> h_info(this->batch_a_ptrs.size());
cudaMemcpyAsync(h_info.data(), this->d_info, sizeof(int) * this->batch_a_ptrs.size(), cudaMemcpyDeviceToHost, stream);
#endif

// This will block. Figure this out later
cudaStreamSynchronize(stream);
Expand Down Expand Up @@ -330,4 +359,4 @@ void eig_impl(OutputTensor &&out, WTensor &&w,
matxFree(tp);
}

} // end namespace matx
} // end namespace matx
2 changes: 1 addition & 1 deletion include/matx/transforms/eig/eig_lapack.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class matxDnEigHostPlan_t : matxDnHostSolver_t<typename ATensor::value_type> {

params = GetEigParams(w, a, jobz, uplo);
this->GetWorkspaceSize();
this->AllocateWorkspace(params.batch_size);
this->AllocateWorkspace(params.batch_size, false);
}

void GetWorkspaceSize() override
Expand Down
2 changes: 1 addition & 1 deletion include/matx/transforms/lu/lu_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class matxDnLUCUDAPlan_t : matxDnCUDASolver_t {

params = GetLUParams(piv, a);
this->GetWorkspaceSize();
this->AllocateWorkspace(params.batch_size);
this->AllocateWorkspace(params.batch_size, false);
}

void GetWorkspaceSize() override
Expand Down
2 changes: 1 addition & 1 deletion include/matx/transforms/qr/qr_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ class matxDnQRCUDAPlan_t : matxDnCUDASolver_t {

params = GetQRParams(tau, a);
this->GetWorkspaceSize();
this->AllocateWorkspace(params.batch_size);
this->AllocateWorkspace(params.batch_size, false);
}

void GetWorkspaceSize() override
Expand Down
2 changes: 1 addition & 1 deletion include/matx/transforms/qr/qr_lapack.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class matxDnQRHostPlan_t : matxDnHostSolver_t<typename ATensor::value_type> {

params = GetQRParams(tau, a);
this->GetWorkspaceSize();
this->AllocateWorkspace(params.batch_size);
this->AllocateWorkspace(params.batch_size, false);
}

void GetWorkspaceSize() override
Expand Down
37 changes: 29 additions & 8 deletions include/matx/transforms/solver_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,17 +255,38 @@ class matxDnCUDASolver_t {
cusolverDnDestroy(handle);
}

void AllocateWorkspace(size_t batches)
void AllocateWorkspace([[maybe_unused]] size_t batches, [[maybe_unused]] bool batched_api)
{
if (dspace > 0) {
matxAlloc(&d_workspace, batches * dspace, MATX_DEVICE_MEMORY);
#if CUSOLVER_VERSION > 11701 || ( CUSOLVER_VERSION == 11701 && CUSOLVER_VER_BUILD >=2)
if (batched_api) {
// Newer cuSolver
if (dspace > 0) {
matxAlloc(&d_workspace, dspace, MATX_DEVICE_MEMORY);
}

// cuSolver has a bug where the workspace needs to be zeroed before using it when the type is complex.
// Zero it out for all types for now.
cudaMemset(d_workspace, 0, dspace);
matxAlloc((void **)&d_info, sizeof(*d_info) * batches, MATX_DEVICE_MEMORY);

if (hspace > 0) {
matxAlloc(&h_workspace, hspace, MATX_HOST_MEMORY);
}
}
else {
#endif
if (dspace > 0) {
matxAlloc(&d_workspace, batches * dspace, MATX_DEVICE_MEMORY);
}

matxAlloc((void **)&d_info, batches * sizeof(*d_info), MATX_DEVICE_MEMORY);
matxAlloc((void **)&d_info, batches * sizeof(*d_info), MATX_DEVICE_MEMORY);

if (hspace > 0) {
matxAlloc(&h_workspace, batches * hspace, MATX_HOST_MEMORY);
if (hspace > 0) {
matxAlloc(&h_workspace, batches * hspace, MATX_HOST_MEMORY);
}
#if CUSOLVER_VERSION > 11701 || ( CUSOLVER_VERSION == 11701 && CUSOLVER_VER_BUILD >=2)
}
#endif
}

virtual void GetWorkspaceSize() = 0;
Expand Down Expand Up @@ -307,7 +328,7 @@ class matxDnHostSolver_t {
matxFree(iwork);
}

void AllocateWorkspace([[maybe_unused]] size_t batches)
void AllocateWorkspace([[maybe_unused]] size_t batches, [[maybe_unused]] bool batched_api)
{
if (lwork > 0) {
matxAlloc(&work, lwork * sizeof(ValueType), MATX_HOST_MALLOC_MEMORY);
Expand Down Expand Up @@ -339,4 +360,4 @@ class matxDnHostSolver_t {

} // end namespace detail

} // end namespace matx
} // end namespace matx
2 changes: 1 addition & 1 deletion include/matx/transforms/svd/svd_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t {

params = GetSVDParams(u, s, vt, a, jobz);
this->GetWorkspaceSize();
this->AllocateWorkspace(params.batch_size);
this->AllocateWorkspace(params.batch_size, false);
}

void GetWorkspaceSize() override
Expand Down
2 changes: 1 addition & 1 deletion include/matx/transforms/svd/svd_lapack.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class matxDnSVDHostPlan_t : matxDnHostSolver_t<typename ATensor::value_type> {

params = GetSVDParams(u, s, vt, a, jobz, algo);
this->GetWorkspaceSize();
this->AllocateWorkspace(params.batch_size);
this->AllocateWorkspace(params.batch_size, false);
}

void GetWorkspaceSize() override
Expand Down