Skip to content

Commit 5e5fe4a

Browse files
authored
Revert "[CUDA] Use GEMM with epilogue instead of AddMM (#2569)"
This reverts commit dde3682.
1 parent dde3682 commit 5e5fe4a

File tree

3 files changed

+51
-110
lines changed

3 files changed

+51
-110
lines changed

mlx/backend/cuda/gemms/cublas_gemm.cpp

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
8585
int32_t batch_count,
8686
int64_t batch_stride) {
8787
cublasLtMatrixLayout_t desc;
88-
if (transposed) {
89-
std::swap(rows, cols);
90-
}
9188
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
89+
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
90+
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
91+
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
9292
if (batch_count > 1) {
9393
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
9494
desc,
@@ -138,34 +138,25 @@ CublasGemm::CublasGemm(
138138
CUBLASLT_MATMUL_DESC_POINTER_MODE,
139139
&pointer_mode,
140140
sizeof(int32_t)));
141-
142-
// In cublasLt, matrices use column-major layout. While it is possible to use
143-
// the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
144-
// epilogue does not work with the option. So instead we swap A and B to make
145-
// cublasLt return the row-major result, which works because:
146-
// - the data of a matrix in row-major layout is identical to its transpose in
147-
// column-major layout
148-
// - C^T = (A @ B)^T = B^T @ A^T
149-
cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
141+
cublasOperation_t op = CUBLAS_OP_N;
150142
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
151143
matmul_desc_,
152144
CUBLASLT_MATMUL_DESC_TRANSA,
153-
&a_op,
145+
&op,
154146
sizeof(cublasOperation_t)));
155-
cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
156147
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
157148
matmul_desc_,
158149
CUBLASLT_MATMUL_DESC_TRANSB,
159-
&b_op,
150+
&op,
160151
sizeof(cublasOperation_t)));
161152

162153
auto type = dtype_to_cublas_type(dtype);
163154
a_desc_ = create_matrix_layout(
164-
type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
155+
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
165156
b_desc_ = create_matrix_layout(
166-
type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
157+
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
167158
out_desc_ = create_matrix_layout(
168-
type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
159+
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
169160
}
170161

171162
CublasGemm::CublasGemm(
@@ -200,7 +191,7 @@ CublasGemm::CublasGemm(
200191
b_batch_stride) {
201192
auto type = dtype_to_cublas_type(dtype);
202193
c_desc_ = create_matrix_layout(
203-
type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
194+
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
204195
}
205196

206197
CublasGemm::~CublasGemm() {
@@ -222,25 +213,14 @@ void CublasGemm::set_out(
222213
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
223214
out_desc_ = create_matrix_layout(
224215
dtype_to_cublas_type(dtype),
225-
cols,
226216
rows,
217+
cols,
227218
transposed,
228219
ld,
229220
batch_count,
230221
batch_stride);
231222
}
232223

233-
void CublasGemm::set_bias(void* bias) {
234-
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
235-
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
236-
matmul_desc_,
237-
CUBLASLT_MATMUL_DESC_EPILOGUE,
238-
&epilogue,
239-
sizeof(epilogue)));
240-
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
241-
matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
242-
}
243-
244224
void CublasGemm::run(
245225
cu::CommandEncoder& encoder,
246226
array& out,
@@ -350,9 +330,9 @@ void CublasGemm::execute(
350330
handle_,
351331
matmul_desc_,
352332
&alpha,
353-
b, // a and b are swapped
354-
a_desc_,
355333
a,
334+
a_desc_,
335+
b,
356336
b_desc_,
357337
&beta,
358338
c ? c : out,

mlx/backend/cuda/gemms/cublas_gemm.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ class CublasGemm {
5555
int32_t batch_count,
5656
int64_t batch_stride);
5757

58-
void set_bias(void* bias);
59-
6058
void run(
6159
cu::CommandEncoder& encoder,
6260
array& out,

mlx/backend/cuda/matmul.cpp

Lines changed: 38 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include <numeric>
1212

1313
namespace mlx::core {
14-
1514
namespace {
1615

1716
std::tuple<bool, int64_t, array>
@@ -29,20 +28,41 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
2928
}
3029
}
3130

32-
void gemm_and_bias(
33-
cu::CommandEncoder& encoder,
34-
int M,
35-
int N,
36-
int K,
37-
bool a_transposed,
38-
int64_t lda,
39-
bool b_transposed,
40-
int64_t ldb,
41-
array& out,
42-
const array& a,
43-
const array& b,
44-
void* bias = nullptr) {
31+
} // namespace
32+
33+
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
34+
nvtx3::scoped_range r("Matmul::eval_gpu");
35+
auto& s = stream();
36+
auto& encoder = cu::get_command_encoder(s);
37+
38+
assert(inputs.size() == 2);
39+
auto& a_pre = inputs[0];
40+
auto& b_pre = inputs[1];
41+
// Return 0s if either input is empty.
42+
if (a_pre.size() == 0 || b_pre.size() == 0) {
43+
array zero(0, a_pre.dtype());
44+
encoder.add_temporary(zero);
45+
fill_gpu(zero, out, s);
46+
return;
47+
}
48+
49+
out.set_data(allocator::malloc(out.nbytes()));
50+
51+
/////////////////////////////////////////////////////////////////////////////
52+
// Init checks and prep
53+
54+
int M = a_pre.shape(-2);
55+
int N = b_pre.shape(-1);
56+
int K = a_pre.shape(-1);
57+
58+
// Keep a vector with copies to be cleared in the completed buffer to release
59+
// the arrays
60+
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
61+
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
62+
63+
/////////////////////////////////////////////////////////////////////////////
4564
// Check and collapse batch dimensions
65+
4666
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
4767

4868
auto batch_count = out.size() / (M * N);
@@ -59,8 +79,7 @@ void gemm_and_bias(
5979
batch_shape = {1};
6080
}
6181

62-
// Use gemmv when possible
63-
if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
82+
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
6483
cu::gemv(
6584
a,
6685
b,
@@ -76,9 +95,10 @@ void gemm_and_bias(
7695
return;
7796
}
7897

98+
/////////////////////////////////////////////////////////////////////////////
7999
// Invoke cublasLt
80100
CublasGemm gemm(
81-
encoder.device(),
101+
cu::device(s.device),
82102
a.dtype(),
83103
a_transposed,
84104
M,
@@ -91,45 +111,9 @@ void gemm_and_bias(
91111
batch_shape.back(),
92112
a_batch_strides.back(),
93113
b_batch_strides.back());
94-
if (bias) {
95-
gemm.set_bias(bias);
96-
}
97114
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
98115
}
99116

100-
} // namespace
101-
102-
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
103-
nvtx3::scoped_range r("Matmul::eval_gpu");
104-
auto& s = stream();
105-
auto& encoder = cu::get_command_encoder(s);
106-
107-
assert(inputs.size() == 2);
108-
auto& a_pre = inputs[0];
109-
auto& b_pre = inputs[1];
110-
// Return 0s if either input is empty.
111-
if (a_pre.size() == 0 || b_pre.size() == 0) {
112-
array zero(0, a_pre.dtype());
113-
encoder.add_temporary(zero);
114-
fill_gpu(zero, out, s);
115-
return;
116-
}
117-
118-
out.set_data(allocator::malloc(out.nbytes()));
119-
120-
int M = a_pre.shape(-2);
121-
int N = b_pre.shape(-1);
122-
int K = a_pre.shape(-1);
123-
124-
// Keep a vector with copies to be cleared in the completed buffer to release
125-
// the arrays
126-
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
127-
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
128-
129-
gemm_and_bias(
130-
encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
131-
}
132-
133117
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
134118
nvtx3::scoped_range r("AddMM::eval_gpu");
135119
auto& s = stream();
@@ -152,27 +136,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
152136
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
153137
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
154138

155-
/////////////////////////////////////////////////////////////////////////////
156-
// Dispatch to GEMM with epilogue or AddMM
157-
158-
if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
159-
out.set_data(allocator::malloc(out.nbytes()));
160-
gemm_and_bias(
161-
encoder,
162-
M,
163-
N,
164-
K,
165-
a_transposed,
166-
lda,
167-
b_transposed,
168-
ldb,
169-
out,
170-
a,
171-
b,
172-
c.data<void>());
173-
return;
174-
}
175-
176139
int64_t ldc;
177140
{
178141
auto stx = c.strides()[c.ndim() - 2];
@@ -214,7 +177,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
214177
}
215178

216179
/////////////////////////////////////////////////////////////////////////////
217-
// Invoke cublasLt with AddMM settings
180+
// Invoke cublasLt
218181

219182
CublasGemm gemm(
220183
cu::device(s.device),

0 commit comments

Comments
 (0)