Skip to content

Commit 262fdd1

Browse files
zcbenz
authored and faisalmemon committed
[CUDA] Use GEMM with epilogue instead of AddMM (ml-explore#2569)
1 parent 0f5b9a8 commit 262fdd1

File tree

3 files changed

+110
-51
lines changed

3 files changed

+110
-51
lines changed

mlx/backend/cuda/gemms/cublas_gemm.cpp

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
8585
int32_t batch_count,
8686
int64_t batch_stride) {
8787
cublasLtMatrixLayout_t desc;
88+
if (transposed) {
89+
std::swap(rows, cols);
90+
}
8891
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
89-
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
90-
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
91-
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
9292
if (batch_count > 1) {
9393
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
9494
desc,
@@ -138,25 +138,34 @@ CublasGemm::CublasGemm(
138138
CUBLASLT_MATMUL_DESC_POINTER_MODE,
139139
&pointer_mode,
140140
sizeof(int32_t)));
141-
cublasOperation_t op = CUBLAS_OP_N;
141+
142+
// In cublasLt matrices use column-major layout, while it is possible to use
143+
// the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
144+
// epilogue does not work with the option. So instead we swap A and B to make
145+
// cublasLt return the row-major result, which works because:
146+
// - the data of a matrix in row-major layout is identical to its transpose in
147+
// column-major layout
148+
// - C^T = (A @ B)^T = B^T @ A^T
149+
cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
142150
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
143151
matmul_desc_,
144152
CUBLASLT_MATMUL_DESC_TRANSA,
145-
&op,
153+
&a_op,
146154
sizeof(cublasOperation_t)));
155+
cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
147156
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
148157
matmul_desc_,
149158
CUBLASLT_MATMUL_DESC_TRANSB,
150-
&op,
159+
&b_op,
151160
sizeof(cublasOperation_t)));
152161

153162
auto type = dtype_to_cublas_type(dtype);
154163
a_desc_ = create_matrix_layout(
155-
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
164+
type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
156165
b_desc_ = create_matrix_layout(
157-
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
166+
type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
158167
out_desc_ = create_matrix_layout(
159-
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
168+
type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
160169
}
161170

162171
CublasGemm::CublasGemm(
@@ -191,7 +200,7 @@ CublasGemm::CublasGemm(
191200
b_batch_stride) {
192201
auto type = dtype_to_cublas_type(dtype);
193202
c_desc_ = create_matrix_layout(
194-
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
203+
type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
195204
}
196205

197206
CublasGemm::~CublasGemm() {
@@ -213,14 +222,25 @@ void CublasGemm::set_out(
213222
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
214223
out_desc_ = create_matrix_layout(
215224
dtype_to_cublas_type(dtype),
216-
rows,
217225
cols,
226+
rows,
218227
transposed,
219228
ld,
220229
batch_count,
221230
batch_stride);
222231
}
223232

233+
void CublasGemm::set_bias(void* bias) {
234+
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
235+
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
236+
matmul_desc_,
237+
CUBLASLT_MATMUL_DESC_EPILOGUE,
238+
&epilogue,
239+
sizeof(epilogue)));
240+
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
241+
matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
242+
}
243+
224244
void CublasGemm::run(
225245
cu::CommandEncoder& encoder,
226246
array& out,
@@ -330,9 +350,9 @@ void CublasGemm::execute(
330350
handle_,
331351
matmul_desc_,
332352
&alpha,
333-
a,
353+
b, // a and b are swapped
334354
a_desc_,
335-
b,
355+
a,
336356
b_desc_,
337357
&beta,
338358
c ? c : out,

mlx/backend/cuda/gemms/cublas_gemm.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ class CublasGemm {
5555
int32_t batch_count,
5656
int64_t batch_stride);
5757

58+
void set_bias(void* bias);
59+
5860
void run(
5961
cu::CommandEncoder& encoder,
6062
array& out,

mlx/backend/cuda/matmul.cpp

Lines changed: 75 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <numeric>
1212

1313
namespace mlx::core {
14+
1415
namespace {
1516

1617
std::tuple<bool, int64_t, array>
@@ -28,41 +29,20 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
2829
}
2930
}
3031

31-
} // namespace
32-
33-
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
34-
nvtx3::scoped_range r("Matmul::eval_gpu");
35-
auto& s = stream();
36-
auto& encoder = cu::get_command_encoder(s);
37-
38-
assert(inputs.size() == 2);
39-
auto& a_pre = inputs[0];
40-
auto& b_pre = inputs[1];
41-
// Return 0s if either input is empty.
42-
if (a_pre.size() == 0 || b_pre.size() == 0) {
43-
array zero(0, a_pre.dtype());
44-
encoder.add_temporary(zero);
45-
fill_gpu(zero, out, s);
46-
return;
47-
}
48-
49-
out.set_data(allocator::malloc(out.nbytes()));
50-
51-
/////////////////////////////////////////////////////////////////////////////
52-
// Init checks and prep
53-
54-
int M = a_pre.shape(-2);
55-
int N = b_pre.shape(-1);
56-
int K = a_pre.shape(-1);
57-
58-
// Keep a vector with copies to be cleared in the completed buffer to release
59-
// the arrays
60-
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
61-
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
62-
63-
/////////////////////////////////////////////////////////////////////////////
32+
void gemm_and_bias(
33+
cu::CommandEncoder& encoder,
34+
int M,
35+
int N,
36+
int K,
37+
bool a_transposed,
38+
int64_t lda,
39+
bool b_transposed,
40+
int64_t ldb,
41+
array& out,
42+
const array& a,
43+
const array& b,
44+
void* bias = nullptr) {
6445
// Check and collapse batch dimensions
65-
6646
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
6747

6848
auto batch_count = out.size() / (M * N);
@@ -79,7 +59,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
7959
batch_shape = {1};
8060
}
8161

82-
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
62+
// Use gemv when possible
63+
if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
8364
cu::gemv(
8465
a,
8566
b,
@@ -95,10 +76,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
9576
return;
9677
}
9778

98-
/////////////////////////////////////////////////////////////////////////////
9979
// Invoke cublasLt
10080
CublasGemm gemm(
101-
cu::device(s.device),
81+
encoder.device(),
10282
a.dtype(),
10383
a_transposed,
10484
M,
@@ -111,9 +91,45 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
11191
batch_shape.back(),
11292
a_batch_strides.back(),
11393
b_batch_strides.back());
94+
if (bias) {
95+
gemm.set_bias(bias);
96+
}
11497
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
11598
}
11699

100+
} // namespace
101+
102+
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
103+
nvtx3::scoped_range r("Matmul::eval_gpu");
104+
auto& s = stream();
105+
auto& encoder = cu::get_command_encoder(s);
106+
107+
assert(inputs.size() == 2);
108+
auto& a_pre = inputs[0];
109+
auto& b_pre = inputs[1];
110+
// Return 0s if either input is empty.
111+
if (a_pre.size() == 0 || b_pre.size() == 0) {
112+
array zero(0, a_pre.dtype());
113+
encoder.add_temporary(zero);
114+
fill_gpu(zero, out, s);
115+
return;
116+
}
117+
118+
out.set_data(allocator::malloc(out.nbytes()));
119+
120+
int M = a_pre.shape(-2);
121+
int N = b_pre.shape(-1);
122+
int K = a_pre.shape(-1);
123+
124+
// Keep a vector with copies to be cleared in the completed buffer to release
125+
// the arrays
126+
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
127+
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
128+
129+
gemm_and_bias(
130+
encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
131+
}
132+
117133
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
118134
nvtx3::scoped_range r("AddMM::eval_gpu");
119135
auto& s = stream();
@@ -136,6 +152,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
136152
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
137153
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
138154

155+
/////////////////////////////////////////////////////////////////////////////
156+
// Dispatch to GEMM with epilogue or AddMM
157+
158+
if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
159+
out.set_data(allocator::malloc(out.nbytes()));
160+
gemm_and_bias(
161+
encoder,
162+
M,
163+
N,
164+
K,
165+
a_transposed,
166+
lda,
167+
b_transposed,
168+
ldb,
169+
out,
170+
a,
171+
b,
172+
c.data<void>());
173+
return;
174+
}
175+
139176
int64_t ldc;
140177
{
141178
auto stx = c.strides()[c.ndim() - 2];
@@ -177,7 +214,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
177214
}
178215

179216
/////////////////////////////////////////////////////////////////////////////
180-
// Invoke cublasLt
217+
// Invoke cublasLt with AddMM settings
181218

182219
CublasGemm gemm(
183220
cu::device(s.device),

0 commit comments

Comments
 (0)