Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions include/matx/core/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "matx/core/defines.h"
#include "matx/core/error.h"

#define HOPPER_CC 9
#define AMPERE_CC 8
#define VOLTA_CC 7
#define PASCAL_CC 6
Expand All @@ -47,8 +48,9 @@ namespace detail {
/**
 * @brief Query an attribute of the currently selected CUDA device.
 *
 * Looks up the current device with cudaGetDevice() and then reads the
 * requested attribute with cudaDeviceGetAttribute(). The status of both
 * runtime calls is checked with MATX_ASSERT against cudaSuccess.
 *
 * @param attr CUDA device attribute to query
 * @return Value written by cudaDeviceGetAttribute() for the current device
 */
__MATX_INLINE__ int GetDeviceAttr(cudaDeviceAttr attr) {
  // Zero-initialise the outputs so that a failed query can never hand back an
  // indeterminate value should the assertion below not stop execution.
  int val = 0;
  int dev = 0;

  // `err` is reused for both calls; [[maybe_unused]] keeps builds quiet if
  // MATX_ASSERT does not reference its condition.
  [[maybe_unused]] auto err = cudaGetDevice(&dev);
  MATX_ASSERT(err == cudaSuccess, matxCudaError);

  err = cudaDeviceGetAttribute(&val, attr, dev);
  MATX_ASSERT(err == cudaSuccess, matxCudaError);

  return val;
}
Expand All @@ -57,6 +59,10 @@ __MATX_INLINE__ int GetComputeCapabilityMajor() {
// Returns the cudaDevAttrComputeCapabilityMajor attribute as reported by GetDeviceAttr().
return GetDeviceAttr(cudaDevAttrComputeCapabilityMajor);
}

/**
 * @brief Check whether the compute capability major version is at least HOPPER_CC.
 *
 * @return true when GetComputeCapabilityMajor() is greater than or equal to
 *         HOPPER_CC, false otherwise
 */
__MATX_INLINE__ bool IsHopperOrAbove() {
  const int cc_major = GetComputeCapabilityMajor();
  return cc_major >= HOPPER_CC;
}

/**
 * @brief Check whether the compute capability major version is at least AMPERE_CC.
 *
 * @return true when GetComputeCapabilityMajor() is greater than or equal to
 *         AMPERE_CC, false otherwise
 */
__MATX_INLINE__ bool IsAmpereOrAbove() {
  const int cc_major = GetComputeCapabilityMajor();
  return cc_major >= AMPERE_CC;
}
Expand Down
15 changes: 11 additions & 4 deletions include/matx/transforms/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,17 @@ class matxMatMulHandle_t {
// This must come before the things below to properly set class parameters
params_ = GetGemmParams(c, a, b);

// // Workspace buffer
matxAlloc((void **)&workspace, workspaceSize, MATX_DEVICE_MEMORY);

if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
// Thus, try to detect if we are running on Hopper or newer and use a 32 MiB workspace
// if so. Otherwise, default to 4 MiB, which still works on Hopper+.
constexpr size_t MiB = 1024*1024;
workspaceSize = detail::IsHopperOrAbove() ? 32*MiB : 4*MiB;

// Workspace buffer
matxAlloc((void **)&workspace, workspaceSize, MATX_DEVICE_MEMORY);

ConfigureCublasLt();
}
}
Expand Down Expand Up @@ -510,7 +517,7 @@ class matxMatMulHandle_t {
void *c_hp = nullptr; // Make these void since they only work on complex types
void *a_hp = nullptr;
void *b_hp = nullptr;
size_t workspaceSize = 1 << 22UL; // 4MB buffer suggested by cuBLAS team
size_t workspaceSize = 0;
void *workspace = nullptr;
detail::MatMulParams_t params_;

Expand Down