Skip to content

Commit 4cf5b29

Browse files
nastya236rootrootroot
authored
qqmm (#2789)
Co-authored-by: root <root@bolt-t9a77vmteu-94s9t6ymth.bolt-pods.turi-bolt.svc.cluster.local> Co-authored-by: root <root@bolt-5azkyvd8ga-kgfzk84y6m.bolt-pods.turi-bolt.svc.cluster.local> Co-authored-by: root <root@bolt-y4nktpaecv-ssnx24rdha.bolt-pods.turi-bolt.svc.cluster.local>
1 parent 6b330eb commit 4cf5b29

File tree

22 files changed

+1502
-223
lines changed

22 files changed

+1502
-223
lines changed

examples/python/qqmm.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from itertools import product
2+
3+
import mlx.core as mx
4+
5+
6+
# In mxfp8 mode, the results do not match exactly:
7+
# fewer than 1% of output elements differ.
8+
# This does not appear to be a systematic error.
9+
# The error can exceed 1 ULP for very small values,
10+
# and is always below 1 ULP for larger values.
11+
# For nvfp4, the results match exactly.
12+
# Therefore, I suspect that the discrepancy comes from
13+
# the mxfp8 matmul implementation in cuBLASLt.
14+
def ulp_bf16_at(x):
    """Return the bfloat16 ULP (unit in the last place) at each element of x.

    Magnitudes below the smallest normal value (2**-126) are clamped up to
    it so the exponent extraction stays well defined; with bfloat16's 7-bit
    mantissa, the ULP at binary exponent e is 2**(e - 7).
    """
    magnitude = mx.abs(x)
    smallest_normal = mx.array(2.0**-126)
    magnitude = mx.where(magnitude < smallest_normal, smallest_normal, magnitude)
    exponent = mx.floor(mx.log2(magnitude))
    return mx.power(2.0, exponent - 7.0)
20+
21+
22+
def test_qqmm():
    """Compare mx.qqmm against a dequantize-then-matmul reference.

    For each quantization config, shape combination and input dtype, every
    output element must be within 1e-3 absolutely or within one bfloat16
    ULP of the reference value, otherwise an AssertionError is raised.
    """
    key = mx.random.key(0)
    k1, k2 = mx.random.split(key)
    dtypes = [mx.bfloat16, mx.float32, mx.float16]

    tests = (
        (16, "nvfp4", 4),
        (32, "mxfp8", 8),
    )
    shapes = (
        [64, 65, 33, 128, 256, 1024, 1024 * 8],  # M
        [64, 128, 256, 1024, 1024 * 8],  # N
        [64, 128, 256, 1024, 1024 * 8],  # K
    )
    # One flattened loop; the iteration order matches the original
    # tests -> shapes -> dtypes nesting.
    for (group_size, mode, bits), (M, N, K), dtype in product(
        tests, product(*shapes), dtypes
    ):
        x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype)
        w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype)
        # Quantize the weights once; the dequantized copy feeds the
        # reference matmul below.
        w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode)
        w_dq = mx.dequantize(
            w_q,
            scales_w,
            group_size=group_size,
            bits=bits,
            mode=mode,
            dtype=dtype,
        )
        y_q = mx.qqmm(
            x,
            w_q,
            scales_w,
            group_size=group_size,
            bits=bits,
            mode=mode,
        )
        # Reference path: quantize/dequantize the activations too, so the
        # reference sees the same quantization error as qqmm's inputs.
        x_q, scales_x = mx.quantize(
            x, group_size=group_size, bits=bits, mode=mode
        )
        x_dq = mx.dequantize(
            x_q,
            scales_x,
            group_size=group_size,
            bits=bits,
            mode=mode,
            dtype=dtype,
        )
        y_hat = mx.matmul(x_dq, mx.transpose(w_dq))
        ulp = ulp_bf16_at(y_hat)
        error = (y_q - y_hat).abs()
        within_tolerance = mx.logical_or(error < 1e-3, error <= ulp).all()
        if not within_tolerance:
            raise AssertionError(
                f"qqmm test failed for shape {(M, N, K)}, "
                f"group_size={group_size}, bits={bits}, "
                f"mode={mode}, dtype={dtype}"
            )
78+
79+
80+
def test_qqmm_vjp():
    """Check the VJP of mx.qqmm against an explicit transposed qqmm.

    The cotangent is all-ones, so d(out)/d(x) pulled back through qqmm
    should match qqmm(c, quantize(w^T)). Elements must agree within 1e-3
    absolutely or within one bfloat16 ULP, otherwise an AssertionError is
    raised.
    """
    key = mx.random.key(0)
    k1, k2 = mx.random.split(key)
    M = 64
    N = 1024
    K = 512
    tests = (
        (16, "nvfp4", 4),
        (32, "mxfp8", 8),
    )
    x = mx.random.normal(shape=(M, K), key=k1)
    c = mx.ones(shape=(M, N))
    # Fix: w was regenerated inside the loop with the same fixed key, which
    # produced an identical array on every iteration — hoist it out.
    w = mx.random.normal(shape=(N, K), key=k2)

    for group_size, mode, bits in tests:

        def fn(x):
            # qqmm with an unquantized weight; quantization parameters are
            # forwarded so the primitive quantizes internally.
            return mx.qqmm(x, w, group_size=group_size, bits=bits, mode=mode)

        _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,))
        # Expected gradient w.r.t. x: cotangent times the transposed,
        # quantized weight.
        w_tq, scales_wt = mx.quantize(
            mx.transpose(w), group_size=group_size, bits=bits, mode=mode
        )
        expected_out = mx.qqmm(
            c, w_tq, scales_wt, group_size=group_size, bits=bits, mode=mode
        )
        ulp = ulp_bf16_at(expected_out)
        error = (vjp_out[0] - expected_out).abs()
        if not (mx.logical_or(error < 1e-3, error <= ulp).all()):
            raise AssertionError(
                f"qqmm vjp test failed for shape {(M, N, K)}, "
                f"group_size={group_size}, bits={bits}, mode={mode}"
            )
113+
114+
115+
if __name__ == "__main__":
    # Run both test entry points when executed as a script.
    for test in (test_qqmm, test_qqmm_vjp):
        test()

mlx/backend/cpu/quantized.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,4 +1145,7 @@ void fast::ConvertFP8::eval_cpu(
11451145
});
11461146
}
11471147

1148+
// QQMatmul has no CPU kernel: the quantized-times-quantized matmul is only
// implemented on the CUDA backend, so evaluating this primitive on CPU is
// always an error.
void QQMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  throw std::runtime_error("QQMatmul not implemented on CPU.");
}
11481151
} // namespace mlx::core

mlx/backend/cuda/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ target_sources(
1818
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
1919
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
2020
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
21+
${CMAKE_CURRENT_SOURCE_DIR}/cublas_utils.cpp
2122
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
2223
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
2324
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
@@ -64,6 +65,12 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
6465
# fp4 is not available on < 12.8
6566
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
6667
target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
68+
else()
69+
target_sources(
70+
mlx
71+
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm.cpp
72+
${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm.cpp
73+
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu)
6774
endif()
6875

6976
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)

mlx/backend/cuda/cublas_utils.cpp

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#include "mlx/backend/cuda/cublas_utils.h"
4+
#include "mlx/backend/cuda/cuda.h"
5+
#include "mlx/utils.h"
6+
7+
namespace mlx::core {
8+
namespace cublas_utils {
9+
10+
namespace {
11+
12+
// RAII wrapper around a cublasLtMatmulPreference_t whose only configured
// attribute is the maximum workspace size, chosen by device generation.
struct CublasPreference {
  CublasPreference(cu::Device& device) {
    // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
    // for Hopper+:
    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
    uint64_t MiB = 1024 * 1024;
    uint64_t workspace_size =
        device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;

    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
        pref_,
        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
        &workspace_size,
        sizeof(uint64_t)));
  }

  ~CublasPreference() {
    // NOTE(review): if CHECK_CUBLAS_ERROR throws on failure, this is a
    // throwing destructor; for the function-local static in
    // get_preference() it runs at program exit, where an exception would
    // call std::terminate — confirm the macro's failure behavior.
    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
  }

  cublasLtMatmulPreference_t pref_{nullptr};
};
35+
36+
} // namespace
37+
38+
// Return a lazily-created, process-wide matmul preference handle.
// NOTE(review): the static is initialized from the first device this is
// called with; a later call with a device of a different compute
// capability still gets the first device's workspace limit — confirm
// that is intended (e.g. single-device assumption).
cublasLtMatmulPreference_t get_preference(cu::Device& device) {
  static CublasPreference pref(device);
  return pref.pref_;
}
42+
43+
// Allocate a temporary cuBLAS workspace buffer tracked by the encoder.
//
// The allocation size is rounded up to a multiple of 256 bytes so the
// returned pointer meets cuBLAS's alignment expectation. Returns nullptr
// when no workspace is requested. The buffer is registered as a temporary
// so it stays alive until the encoder's work completes.
void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size) {
  if (workspace_size == 0) {
    return nullptr;
  }

  // Ensure workspace is 256-byte aligned. Keep the byte count in size_t:
  // the original `int` narrowed the value and could overflow for large
  // requests.
  size_t nbytes = cuda::ceil_div(workspace_size, 256) * 256;
  array workspace(
      cu::malloc_async(nbytes, encoder),
      {static_cast<int>(workspace_size)},
      int8);
  encoder.add_temporary(workspace);
  return gpu_ptr<void>(workspace);
}
57+
58+
// Build a cuBLASLt matrix layout descriptor.
//
// `rows`/`cols` describe the logical matrix; when `transposed` is set the
// stored dimensions are swapped before the descriptor is created. Batch
// count and stride attributes are attached only when more than one batch
// is present. The caller owns the returned descriptor.
cublasLtMatrixLayout_t create_matrix_layout(
    cudaDataType_t type,
    uint64_t rows,
    uint64_t cols,
    bool transposed,
    int64_t ld,
    int32_t batch_count,
    int64_t batch_stride) {
  if (transposed) {
    std::swap(rows, cols);
  }
  cublasLtMatrixLayout_t layout;
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&layout, type, rows, cols, ld));
  if (batch_count > 1) {
    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
        layout,
        CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
        &batch_count,
        sizeof(int32_t)));
    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
        layout,
        CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
        &batch_stride,
        sizeof(int64_t)));
  }
  return layout;
}
85+
86+
} // namespace cublas_utils
87+
88+
// Destroy all cuBLASLt descriptors owned by this matmul object.
// NOTE(review): c_desc_ is destroyed here, but the visible init_base()
// only creates a_desc_, b_desc_ and out_desc_ — verify that c_desc_ is
// always created (or safely default-initialized) before destruction, and
// that the destroy functions accept the default value.
CublasMatmulBase::~CublasMatmulBase() {
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
95+
96+
// Configure the cuBLASLt matmul descriptor and operand layouts for
// out = A @ B with optional batching.
//
// Scale/compute/data/output types are passed straight to cuBLASLt. The
// caller supplies the logical (row-major) dimensions, transpose flags,
// leading dimensions and batch strides of A and B; see the comment below
// for why the descriptors for A and B are deliberately swapped.
void CublasMatmulBase::init_base(
    cu::Device& device,
    cudaDataType_t scale_type,
    cublasComputeType_t compute_type,
    cudaDataType_t data_type,
    cudaDataType_t output_type,
    bool a_transposed,
    uint64_t a_rows,
    uint64_t a_cols,
    int64_t lda,
    bool b_transposed,
    uint64_t b_rows,
    uint64_t b_cols,
    int64_t ldb,
    int32_t batch_count,
    int64_t a_batch_stride,
    int64_t b_batch_stride) {
  // Cache output dimensions and handles; invalidate any previously found
  // heuristic so execute_matmul() re-queries with the new descriptors.
  M_ = a_rows;
  N_ = b_cols;
  scale_type_ = scale_type;
  handle_ = device.lt_handle();
  pref_ = cublas_utils::get_preference(device);
  heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;

  CHECK_CUBLAS_ERROR(
      cublasLtMatmulDescCreate(&matmul_desc_, compute_type, scale_type));

  // alpha/beta are read from host memory in execute_matmul().
  int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_POINTER_MODE,
      &pointer_mode,
      sizeof(int32_t)));

  // In cublasLt matrices use column-major layout, while it is possible to use
  // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
  // epilogue does not work with the option. So instead we swap A and B to make
  // cublasLt return the row-major result, which works because:
  // - the data of a matrix in row-major layout is identical to its transpose in
  // column-major layout
  // - C^T = (A @ B)^T = B^T @ A^T
  cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_TRANSA,
      &a_op,
      sizeof(cublasOperation_t)));
  cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_TRANSB,
      &b_op,
      sizeof(cublasOperation_t)));

  // Because of the swap above, a_desc_ describes B's storage and b_desc_
  // describes A's storage.
  a_desc_ = cublas_utils::create_matrix_layout(
      data_type,
      b_cols,
      b_rows,
      b_transposed,
      ldb,
      batch_count,
      b_batch_stride);
  b_desc_ = cublas_utils::create_matrix_layout(
      data_type,
      a_cols,
      a_rows,
      a_transposed,
      lda,
      batch_count,
      a_batch_stride);
  // The output is N x M in column-major, i.e. the row-major M x N result;
  // batches are densely packed (stride M * N).
  out_desc_ = cublas_utils::create_matrix_layout(
      output_type, b_cols, a_rows, false, b_cols, batch_count, b_cols * a_rows);
}
169+
170+
// Run the configured matmul: out = alpha * (A @ B) + beta * C.
//
// `c` may be null, in which case `out` doubles as the C operand (with the
// caller expected to pass a beta of zero for a plain A @ B). alpha/beta
// point to host memory, matching the pointer mode set in init_base().
void CublasMatmulBase::execute_matmul(
    cu::CommandEncoder& encoder,
    void* out,
    const void* a,
    const void* b,
    const void* c,
    const void* alpha_ptr,
    const void* beta_ptr) {
  // Find and cache an algorithm on first use; init_base() resets the
  // state so new descriptors trigger a fresh heuristic query.
  if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
    int ret = 0;
    CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
        handle_,
        matmul_desc_,
        a_desc_,
        b_desc_,
        c ? c_desc_ : out_desc_,
        out_desc_,
        pref_,
        1,
        &heuristic_,
        &ret));
    if (ret == 0) {
      throw std::runtime_error("Can not find algorithm for matmul.");
    }
  }

  // Workspace size comes from the chosen heuristic; may be zero.
  void* workspace_ptr =
      cublas_utils::allocate_workspace(encoder, heuristic_.workspaceSize);

  // Execute matmul
  auto capture = encoder.capture_context();
  CHECK_CUBLAS_ERROR(cublasLtMatmul(
      handle_,
      matmul_desc_,
      alpha_ptr,
      b, // a and b are swapped for row-major layout
      a_desc_,
      a,
      b_desc_,
      beta_ptr,
      c ? c : out,
      c ? c_desc_ : out_desc_,
      out,
      out_desc_,
      &heuristic_.algo,
      workspace_ptr,
      heuristic_.workspaceSize,
      encoder.stream()));
}
219+
220+
// Attach a bias epilogue to the matmul descriptor so cuBLASLt applies the
// bias vector as part of the GEMM itself, and register the bias array as
// an input with the encoder.
void CublasMatmulBase::set_bias(
    cu::CommandEncoder& encoder,
    const array& bias) {
  encoder.set_input_array(bias);
  // Enable the bias epilogue on the descriptor.
  cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(epi)));
  // Point the epilogue at the device bias buffer.
  void* device_bias = gpu_ptr<void>(bias);
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_BIAS_POINTER,
      &device_bias,
      sizeof(device_bias)));
}
237+
238+
} // namespace mlx::core

0 commit comments

Comments
 (0)