|
| 1 | +// Copyright © 2025 Apple Inc. |
| 2 | + |
| 3 | +#include "mlx/backend/cuda/cublas_utils.h" |
| 4 | +#include "mlx/backend/cuda/cuda.h" |
| 5 | +#include "mlx/utils.h" |
| 6 | + |
| 7 | +namespace mlx::core { |
| 8 | +namespace cublas_utils { |
| 9 | + |
| 10 | +namespace { |
| 11 | + |
| 12 | +struct CublasPreference { |
| 13 | + CublasPreference(cu::Device& device) { |
| 14 | + // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB |
| 15 | + // for Hopper+: |
| 16 | + // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace |
| 17 | + uint64_t MiB = 1024 * 1024; |
| 18 | + uint64_t workspace_size = |
| 19 | + device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB; |
| 20 | + |
| 21 | + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_)); |
| 22 | + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute( |
| 23 | + pref_, |
| 24 | + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, |
| 25 | + &workspace_size, |
| 26 | + sizeof(uint64_t))); |
| 27 | + } |
| 28 | + |
| 29 | + ~CublasPreference() { |
| 30 | + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_)); |
| 31 | + } |
| 32 | + |
| 33 | + cublasLtMatmulPreference_t pref_{nullptr}; |
| 34 | +}; |
| 35 | + |
| 36 | +} // namespace |
| 37 | + |
| 38 | +cublasLtMatmulPreference_t get_preference(cu::Device& device) { |
| 39 | + static CublasPreference pref(device); |
| 40 | + return pref.pref_; |
| 41 | +} |
| 42 | + |
| 43 | +void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size) { |
| 44 | + if (workspace_size == 0) { |
| 45 | + return nullptr; |
| 46 | + } |
| 47 | + |
| 48 | + // Ensure workspace is 256-byte aligned |
| 49 | + int nbytes = cuda::ceil_div(workspace_size, 256) * 256; |
| 50 | + array workspace( |
| 51 | + cu::malloc_async(nbytes, encoder), |
| 52 | + {static_cast<int>(workspace_size)}, |
| 53 | + int8); |
| 54 | + encoder.add_temporary(workspace); |
| 55 | + return gpu_ptr<void>(workspace); |
| 56 | +} |
| 57 | + |
| 58 | +cublasLtMatrixLayout_t create_matrix_layout( |
| 59 | + cudaDataType_t type, |
| 60 | + uint64_t rows, |
| 61 | + uint64_t cols, |
| 62 | + bool transposed, |
| 63 | + int64_t ld, |
| 64 | + int32_t batch_count, |
| 65 | + int64_t batch_stride) { |
| 66 | + cublasLtMatrixLayout_t desc; |
| 67 | + if (transposed) { |
| 68 | + std::swap(rows, cols); |
| 69 | + } |
| 70 | + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); |
| 71 | + if (batch_count > 1) { |
| 72 | + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( |
| 73 | + desc, |
| 74 | + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, |
| 75 | + &batch_count, |
| 76 | + sizeof(int32_t))); |
| 77 | + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( |
| 78 | + desc, |
| 79 | + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, |
| 80 | + &batch_stride, |
| 81 | + sizeof(int64_t))); |
| 82 | + } |
| 83 | + return desc; |
| 84 | +} |
| 85 | + |
| 86 | +} // namespace cublas_utils |
| 87 | + |
// Destroy the cublasLt descriptors owned by this object.
// a_desc_/b_desc_/out_desc_ and matmul_desc_ are created in init_base.
// NOTE(review): init_base does not create c_desc_ — presumably a derived
// class or later call fills it in; confirm it is always valid (or that
// destroying a default-constructed layout handle is safe) before this runs.
CublasMatmulBase::~CublasMatmulBase() {
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
| 95 | + |
// Set up the cublasLt matmul descriptor and the A/B/out matrix layouts for
// a (possibly strided-batched) GEMM of row-major tensors: out = A @ B with
// A of shape (a_rows, a_cols) and B of shape (b_rows, b_cols). Dimensions
// describe the logical (pre-transpose) matrices; lda/ldb are leading
// dimensions of the stored data. Note c_desc_ is NOT created here.
void CublasMatmulBase::init_base(
    cu::Device& device,
    cudaDataType_t scale_type,
    cublasComputeType_t compute_type,
    cudaDataType_t data_type,
    cudaDataType_t output_type,
    bool a_transposed,
    uint64_t a_rows,
    uint64_t a_cols,
    int64_t lda,
    bool b_transposed,
    uint64_t b_rows,
    uint64_t b_cols,
    int64_t ldb,
    int32_t batch_count,
    int64_t a_batch_stride,
    int64_t b_batch_stride) {
  // Cache the logical output shape and handles; mark the heuristic as not
  // yet computed so execute_matmul queries it on first use.
  M_ = a_rows;
  N_ = b_cols;
  scale_type_ = scale_type;
  handle_ = device.lt_handle();
  pref_ = cublas_utils::get_preference(device);
  heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;

  CHECK_CUBLAS_ERROR(
      cublasLtMatmulDescCreate(&matmul_desc_, compute_type, scale_type));

  // alpha/beta are read from host memory at enqueue time.
  int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_POINTER_MODE,
      &pointer_mode,
      sizeof(int32_t)));

  // In cublasLt matrices use column-major layout, while it is possible to use
  // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
  // epilogue does not work with the option. So instead we swap A and B to make
  // cublasLt return the row-major result, which works because:
  // - the data of a matrix in row-major layout is identical to its transpose in
  //   column-major layout
  // - C^T = (A @ B)^T = B^T @ A^T
  // Consequently, cublasLt's "A" operand is our B (and vice versa) below.
  cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_TRANSA,
      &a_op,
      sizeof(cublasOperation_t)));
  cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_TRANSB,
      &b_op,
      sizeof(cublasOperation_t)));

  // Layout for cublasLt's "A" operand: our B, with row/col extents swapped
  // to reinterpret the row-major data as column-major.
  a_desc_ = cublas_utils::create_matrix_layout(
      data_type,
      b_cols,
      b_rows,
      b_transposed,
      ldb,
      batch_count,
      b_batch_stride);
  // Layout for cublasLt's "B" operand: our A, likewise swapped.
  b_desc_ = cublas_utils::create_matrix_layout(
      data_type,
      a_cols,
      a_rows,
      a_transposed,
      lda,
      batch_count,
      a_batch_stride);
  // Output: (b_cols x a_rows) column-major == (a_rows x b_cols) row-major,
  // densely packed per batch.
  out_desc_ = cublas_utils::create_matrix_layout(
      output_type, b_cols, a_rows, false, b_cols, batch_count, b_cols * a_rows);
}
| 169 | + |
// Enqueue the configured matmul on the encoder's stream:
//   out = alpha * (a @ b) + beta * c   (c optional; out used in its place)
// alpha_ptr/beta_ptr are host pointers (pointer mode is set to HOST in
// init_base). Device pointers a/b/c/out must match the layouts built in
// init_base. Throws if cublasLt cannot find a suitable algorithm.
void CublasMatmulBase::execute_matmul(
    cu::CommandEncoder& encoder,
    void* out,
    const void* a,
    const void* b,
    const void* c,
    const void* alpha_ptr,
    const void* beta_ptr) {
  // Query the algorithm heuristic once and cache it in heuristic_ (state is
  // reset to NOT_INITIALIZED by init_base).
  if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
    int ret = 0;
    CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
        handle_,
        matmul_desc_,
        a_desc_,
        b_desc_,
        c ? c_desc_ : out_desc_,
        out_desc_,
        pref_,
        1, // request a single (best) candidate
        &heuristic_,
        &ret));
    if (ret == 0) {
      throw std::runtime_error("Can not find algorithm for matmul.");
    }
  }

  // Scratch buffer sized by the chosen algorithm (may be null when 0).
  void* workspace_ptr =
      cublas_utils::allocate_workspace(encoder, heuristic_.workspaceSize);

  // Execute matmul
  auto capture = encoder.capture_context();
  CHECK_CUBLAS_ERROR(cublasLtMatmul(
      handle_,
      matmul_desc_,
      alpha_ptr,
      b, // a and b are swapped for row-major layout
      a_desc_,
      a,
      b_desc_,
      beta_ptr,
      c ? c : out, // when no C matrix is given, cublasLt requires some input
      c ? c_desc_ : out_desc_,
      out,
      out_desc_,
      &heuristic_.algo,
      workspace_ptr,
      heuristic_.workspaceSize,
      encoder.stream()));
}
| 219 | + |
| 220 | +void CublasMatmulBase::set_bias( |
| 221 | + cu::CommandEncoder& encoder, |
| 222 | + const array& bias) { |
| 223 | + encoder.set_input_array(bias); |
| 224 | + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; |
| 225 | + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( |
| 226 | + matmul_desc_, |
| 227 | + CUBLASLT_MATMUL_DESC_EPILOGUE, |
| 228 | + &epilogue, |
| 229 | + sizeof(epilogue))); |
| 230 | + auto* bias_ptr = gpu_ptr<void>(bias); |
| 231 | + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( |
| 232 | + matmul_desc_, |
| 233 | + CUBLASLT_MATMUL_DESC_BIAS_POINTER, |
| 234 | + &bias_ptr, |
| 235 | + sizeof(bias_ptr))); |
| 236 | +} |
| 237 | + |
| 238 | +} // namespace mlx::core |
0 commit comments