Skip to content
Merged

qqmm #2789

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
3ed3b11
qqmm
nastya236 Nov 18, 2025
56a21b9
merge main
nastya236 Nov 18, 2025
54f9958
merge main
nastya236 Nov 18, 2025
61e30ea
refactoring
nastya236 Nov 19, 2025
9e4879b
refactoring
nastya236 Nov 19, 2025
df45b39
quantize activations on the fly
nastya236 Nov 19, 2025
7a012a7
quantize in eval
nastya236 Nov 20, 2025
e341572
Revert "quantize in eval"
nastya236 Nov 20, 2025
de49f80
Revert "Revert "quantize in eval""
nastya236 Nov 20, 2025
545dded
on the fly activation quantization
nastya236 Nov 20, 2025
5318b38
pre-commit
nastya236 Nov 20, 2025
2184744
qqmm inputs bf16 second arg
nastya236 Nov 24, 2025
61d09ea
fix
nastya236 Nov 24, 2025
6af6eb3
bf16 weights are optional
nastya236 Nov 24, 2025
1e37ef6
op in python, typos
nastya236 Nov 24, 2025
9c584f8
typo
nastya236 Nov 24, 2025
9a83d3c
batched qqmm
nastya236 Nov 25, 2025
88361a3
delete batching
Nov 26, 2025
b9e73ab
string instead of stringz-view
Nov 26, 2025
95c275b
add 2D input condition
nastya236 Nov 26, 2025
64b8cbe
force transpose
nastya236 Nov 26, 2025
4b68595
fix transpose
Nov 26, 2025
ee0ea9f
add pythong tests
Nov 27, 2025
cc0333e
added qq linear
Nov 28, 2025
34f42fb
added tests
nastya236 Nov 29, 2025
9184f9a
docs correctlion
nastya236 Nov 29, 2025
110848f
small fixes
nastya236 Nov 29, 2025
6633c4b
deleted qqlinear for now
nastya236 Nov 29, 2025
a71e436
deleted unused header
nastya236 Nov 29, 2025
4030486
Merge branch 'main' into qq-matmul
nastya236 Nov 29, 2025
b36c6d7
delete debuging print
nastya236 Nov 29, 2025
3b5ef03
Merge branch 'main' into qq-matmul
nastya236 Dec 11, 2025
ca910c8
Merge branch 'main' into qq-matmul
nastya236 Dec 11, 2025
8508ce9
Renamed to QQMatmul, input only w (q/nq), optionaly scales
nastya236 Dec 11, 2025
49a0ba8
validate qqmm types&shapes, pad scales in fp_quantize, swizzled layout
nastya236 Dec 12, 2025
5576929
revert packing scales in tiled layout inside fp_quantize
nastya236 Dec 12, 2025
3cbb18f
adjusted py binding, revert deleted spaces
nastya236 Dec 12, 2025
a3cb081
modified docs, add skip if in tests
nastya236 Dec 13, 2025
32fb891
output type based = the first arg type
nastya236 Dec 14, 2025
5e51e0f
Merge remote-tracking branch 'upstream/main' into qq-matmul
nastya236 Dec 15, 2025
53bde90
compile qqmm only on sm_100
nastya236 Dec 15, 2025
2ed34c7
fix typos
nastya236 Dec 15, 2025
abf44ea
dtype_to_cublas_type double def
nastya236 Dec 15, 2025
463e8ad
refactoring, test -> example
nastya236 Dec 15, 2025
e752e96
both inputs non quatized in vjp
nastya236 Dec 15, 2025
a814438
typo
nastya236 Dec 15, 2025
f541eaa
Merge remote-tracking branch 'upstream/main' into qq-matmul
nastya236 Dec 15, 2025
0087b78
address comments, fix merge conflict
nastya236 Dec 16, 2025
18ef7b9
fix typo
nastya236 Dec 16, 2025
c54d0ef
fixed error on cuda 12.6
nastya236 Dec 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions examples/python/qqmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from itertools import product

import mlx.core as mx


# In mxfp8 mode, the results do not match exactly:
# fewer than 1% of output elements differ.
# This does not appear to be a systematic error.
# The error can exceed 1 ULP for very small values,
# and is always below 1 ULP for larger values.
# For nvfp4, the results match exactly.
# Therefore I suspect that the discrepancy comes from
# the mxfp8 matmul implementation in cuBLASLt.
def ulp_bf16_at(x):
    """Return the bfloat16 unit-in-the-last-place at each element of ``x``.

    bfloat16 carries 7 explicit mantissa bits, so a value with binary
    exponent ``e`` has neighbors spaced ``2**(e - 7)`` apart. Magnitudes
    below the smallest normal (``2**-126``) are clamped up to it so the
    subnormal range shares the minimum-normal ULP.
    """
    magnitude = mx.maximum(mx.abs(x), mx.array(2.0**-126))
    exponent = mx.floor(mx.log2(magnitude))
    return mx.power(2.0, exponent - 7.0)


def test_qqmm():
    """Check mx.qqmm against a dequantize-then-matmul reference.

    For each quantization configuration and every (M, N, K) shape
    combination, the qqmm output must agree with the reference within an
    absolute tolerance of 1e-3 or within one bfloat16 ULP of the
    reference value.
    """
    key = mx.random.key(0)
    k1, k2 = mx.random.split(key)
    dtypes = [mx.bfloat16, mx.float32, mx.float16]

    configs = (
        (16, "nvfp4", 4),
        (32, "mxfp8", 8),
    )
    m_sizes = [64, 65, 33, 128, 256, 1024, 1024 * 8]
    n_sizes = [64, 128, 256, 1024, 1024 * 8]
    k_sizes = [64, 128, 256, 1024, 1024 * 8]
    for group_size, mode, bits in configs:
        for M, N, K in product(m_sizes, n_sizes, k_sizes):
            for dtype in dtypes:
                x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype)
                w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype)

                # Quantized weights plus their dequantized reference copy.
                w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode)
                w_dq = mx.dequantize(
                    w_q,
                    scales_w,
                    group_size=group_size,
                    bits=bits,
                    mode=mode,
                    dtype=dtype,
                )

                # qqmm takes the unquantized activations directly.
                y_q = mx.qqmm(
                    x,
                    w_q,
                    scales_w,
                    group_size=group_size,
                    bits=bits,
                    mode=mode,
                )

                # Reference path: round-trip x through quantization, then
                # do a plain matmul against the dequantized weights.
                x_q, scales_x = mx.quantize(
                    x, group_size=group_size, bits=bits, mode=mode
                )
                x_dq = mx.dequantize(
                    x_q,
                    scales_x,
                    group_size=group_size,
                    bits=bits,
                    mode=mode,
                    dtype=dtype,
                )
                y_hat = mx.matmul(x_dq, mx.transpose(w_dq))

                tolerance = ulp_bf16_at(y_hat)
                error = (y_q - y_hat).abs()
                within = mx.logical_or(error < 1e-3, error <= tolerance)
                if not within.all():
                    raise AssertionError(
                        f"qqmm test failed for shape {(M, N, K)}, "
                        f"group_size={group_size}, bits={bits}, "
                        f"mode={mode}, dtype={dtype}"
                    )


def test_qqmm_vjp():
    """Check the vjp of mx.qqmm with respect to its activation argument.

    The pullback of a cotangent ``c`` should equal qqmm of ``c`` with the
    quantized, transposed weights, within 1e-3 absolute error or one
    bfloat16 ULP.
    """
    key = mx.random.key(0)
    k1, k2 = mx.random.split(key)
    M, N, K = 64, 1024, 512
    configs = (
        (16, "nvfp4", 4),
        (32, "mxfp8", 8),
    )
    x = mx.random.normal(shape=(M, K), key=k1)
    c = mx.ones(shape=(M, N))

    for group_size, mode, bits in configs:
        w = mx.random.normal(shape=(N, K), key=k2)

        def fn(x):
            return mx.qqmm(x, w, group_size=group_size, bits=bits, mode=mode)

        _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,))

        # Expected pullback: c times w, computed via qqmm with the
        # quantized transpose of the weights.
        w_tq, scales_wt = mx.quantize(
            mx.transpose(w), group_size=group_size, bits=bits, mode=mode
        )
        expected_out = mx.qqmm(
            c, w_tq, scales_wt, group_size=group_size, bits=bits, mode=mode
        )

        tolerance = ulp_bf16_at(expected_out)
        error = (vjp_out[0] - expected_out).abs()
        if not mx.logical_or(error < 1e-3, error <= tolerance).all():
            raise AssertionError(
                f"qqmm vjp test failed for shape {(M, N, K)}, "
                f"group_size={group_size}, bits={bits}, mode={mode}"
            )


# Run both the forward and the vjp checks when executed as a script.
if __name__ == "__main__":
    test_qqmm()
    test_qqmm_vjp()
3 changes: 3 additions & 0 deletions mlx/backend/cpu/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1145,4 +1145,7 @@ void fast::ConvertFP8::eval_cpu(
});
}

// QQMatmul has no CPU kernel (it is backed by cuBLASLt block-scaled
// matmuls on CUDA), so evaluating it on the CPU is an explicit error.
void QQMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  throw std::runtime_error("QQMatmul not implemented on CPU.");
}
} // namespace mlx::core
7 changes: 7 additions & 0 deletions mlx/backend/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/cublas_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
Expand Down Expand Up @@ -64,6 +65,12 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
# fp4 is not available on < 12.8
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
  target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
else()
  # The qqmm sources use CUDA-12.8+ features, so they are only compiled
  # when the toolkit supports them; older toolkits fall back to the
  # include-only branch above.
  target_sources(
    mlx
    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm.cpp
            ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm.cpp
            ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu)
endif()

if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
Expand Down
238 changes: 238 additions & 0 deletions mlx/backend/cuda/cublas_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/cublas_utils.h"

#include <mutex>
#include <unordered_map>

#include "mlx/backend/cuda/cuda.h"
#include "mlx/utils.h"

namespace mlx::core {
namespace cublas_utils {

namespace {

// RAII owner of a cublasLt matmul preference handle, configured with the
// workspace budget NVIDIA recommends for the device's architecture.
struct CublasPreference {
  CublasPreference(cu::Device& device) {
    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace recommends a
    // 32 MiB workspace for Hopper+ and 4 MiB for earlier architectures.
    constexpr uint64_t kMiB = uint64_t(1024) * 1024;
    uint64_t workspace_size =
        (device.compute_capability_major() >= 9 ? 32 : 4) * kMiB;

    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
        pref_,
        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
        &workspace_size,
        sizeof(uint64_t)));
  }

  ~CublasPreference() {
    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
  }

  cublasLtMatmulPreference_t pref_{nullptr};
};

} // namespace

// Returns a cached cublasLt matmul preference for `device`.
//
// The workspace budget baked into CublasPreference depends on the device's
// compute capability, so cache one preference per architecture major
// version. A single function-local `static CublasPreference` would
// permanently adopt the setting of whichever device reached here first,
// which is wrong in a process driving GPUs of mixed architectures.
cublasLtMatmulPreference_t get_preference(cu::Device& device) {
  static std::mutex mutex;
  // unordered_map is node-based: element addresses (and thus the handles
  // returned below) stay valid across rehashes.
  static std::unordered_map<int, CublasPreference> prefs;
  std::lock_guard<std::mutex> lock(mutex);
  auto it =
      prefs.try_emplace(device.compute_capability_major(), device).first;
  return it->second.pref_;
}

// Allocates a temporary device buffer of at least `workspace_size` bytes for
// the cublasLt matmul workspace, registering it with the encoder so it lives
// until the kernel completes. Returns nullptr when no workspace is needed.
void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size) {
  if (workspace_size == 0) {
    return nullptr;
  }

  // Pad the size up to a multiple of 256 bytes, the granularity cublasLt
  // expects for its workspace. (Padding the size does not by itself align
  // the pointer; that comes from the allocator.) Use size_t rather than int
  // so the intermediate cannot narrow or overflow for large requests, and
  // report the padded size in the array so allocation accounting matches
  // what was actually allocated.
  size_t nbytes = cuda::ceil_div(workspace_size, 256) * 256;
  array workspace(
      cu::malloc_async(nbytes, encoder), {static_cast<int>(nbytes)}, int8);
  encoder.add_temporary(workspace);
  return gpu_ptr<void>(workspace);
}

// Builds a cublasLt matrix layout descriptor. For transposed operands the
// row/column extents are exchanged so the descriptor always describes the
// matrix as stored; batch attributes are attached only for batched inputs.
// The caller owns the returned descriptor and must destroy it.
cublasLtMatrixLayout_t create_matrix_layout(
    cudaDataType_t type,
    uint64_t rows,
    uint64_t cols,
    bool transposed,
    int64_t ld,
    int32_t batch_count,
    int64_t batch_stride) {
  uint64_t n_rows = transposed ? cols : rows;
  uint64_t n_cols = transposed ? rows : cols;
  cublasLtMatrixLayout_t desc;
  CHECK_CUBLAS_ERROR(
      cublasLtMatrixLayoutCreate(&desc, type, n_rows, n_cols, ld));
  if (batch_count > 1) {
    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
        desc,
        CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
        &batch_count,
        sizeof(int32_t)));
    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
        desc,
        CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
        &batch_stride,
        sizeof(int64_t)));
  }
  return desc;
}

} // namespace cublas_utils

// Destroys all cublasLt descriptors owned by this object.
// NOTE(review): init_base creates a_desc_, b_desc_ and out_desc_ but never
// c_desc_, yet c_desc_ is destroyed unconditionally here — this assumes a
// derived class always creates it before destruction. Confirm c_desc_
// cannot still be null/uninitialized when the destructor runs.
CublasMatmulBase::~CublasMatmulBase() {
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}

// Configures the cublasLt matmul descriptor and the A/B/output matrix
// layouts for an (a_rows x a_cols) @ (b_rows x b_cols) product, caching the
// logical output shape in M_/N_ and leaving the algorithm heuristic
// uninitialized so execute_matmul queries it on first use.
// NOTE(review): c_desc_ is not created here; a derived class must set it
// before passing a C operand to execute_matmul.
void CublasMatmulBase::init_base(
    cu::Device& device,
    cudaDataType_t scale_type,
    cublasComputeType_t compute_type,
    cudaDataType_t data_type,
    cudaDataType_t output_type,
    bool a_transposed,
    uint64_t a_rows,
    uint64_t a_cols,
    int64_t lda,
    bool b_transposed,
    uint64_t b_rows,
    uint64_t b_cols,
    int64_t ldb,
    int32_t batch_count,
    int64_t a_batch_stride,
    int64_t b_batch_stride) {
  M_ = a_rows;
  N_ = b_cols;
  scale_type_ = scale_type;
  handle_ = device.lt_handle();
  pref_ = cublas_utils::get_preference(device);
  // Force execute_matmul to run the heuristic query on its first call.
  heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;

  CHECK_CUBLAS_ERROR(
      cublasLtMatmulDescCreate(&matmul_desc_, compute_type, scale_type));

  // alpha/beta are passed as host pointers in execute_matmul.
  int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_POINTER_MODE,
      &pointer_mode,
      sizeof(int32_t)));

  // In cublasLt matrices use column-major layout, while it is possible to use
  // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
  // epilogue does not work with the option. So instead we swap A and B to make
  // cublasLt return the row-major result, which works because:
  // - the data of a matrix in row-major layout is identical to its transpose in
  // column-major layout
  // - C^T = (A @ B)^T = B^T @ A^T
  cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_TRANSA,
      &a_op,
      sizeof(cublasOperation_t)));
  cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,
      CUBLASLT_MATMUL_DESC_TRANSB,
      &b_op,
      sizeof(cublasOperation_t)));

  // Because of the swap above, the descriptor named a_desc_ is built from
  // B's extents and vice versa; execute_matmul passes the pointers swapped
  // to match.
  a_desc_ = cublas_utils::create_matrix_layout(
      data_type,
      b_cols,
      b_rows,
      b_transposed,
      ldb,
      batch_count,
      b_batch_stride);
  b_desc_ = cublas_utils::create_matrix_layout(
      data_type,
      a_cols,
      a_rows,
      a_transposed,
      lda,
      batch_count,
      a_batch_stride);
  // Output is N x M in cublasLt's column-major view (i.e. M x N row-major),
  // densely packed, with one M*N tile per batch entry.
  out_desc_ = cublas_utils::create_matrix_layout(
      output_type, b_cols, a_rows, false, b_cols, batch_count, b_cols * a_rows);
}

// Runs the configured matmul: out = alpha * (A @ B) + beta * C.
// On the first call (heuristic_.state was set to CUBLAS_STATUS_NOT_INITIALIZED
// in init_base) the cublasLt algorithm heuristic is queried once and cached;
// subsequent calls reuse it. When `c` is null the beta term reads from `out`
// and the output descriptor doubles as the C descriptor.
void CublasMatmulBase::execute_matmul(
    cu::CommandEncoder& encoder,
    void* out,
    const void* a,
    const void* b,
    const void* c,
    const void* alpha_ptr,
    const void* beta_ptr) {
  if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
    int ret = 0;
    CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
        handle_,
        matmul_desc_,
        a_desc_,
        b_desc_,
        c ? c_desc_ : out_desc_,
        out_desc_,
        pref_,
        1, // request only the single best candidate
        &heuristic_,
        &ret));
    if (ret == 0) {
      throw std::runtime_error("Can not find algorithm for matmul.");
    }
  }

  // May be nullptr when the chosen algorithm needs no workspace; the buffer
  // is tracked as an encoder temporary so it outlives the kernel launch.
  void* workspace_ptr =
      cublas_utils::allocate_workspace(encoder, heuristic_.workspaceSize);

  // Execute matmul
  auto capture = encoder.capture_context();
  CHECK_CUBLAS_ERROR(cublasLtMatmul(
      handle_,
      matmul_desc_,
      alpha_ptr,
      b, // a and b are swapped for row-major layout
      a_desc_,
      a,
      b_desc_,
      beta_ptr,
      c ? c : out,
      c ? c_desc_ : out_desc_,
      out,
      out_desc_,
      &heuristic_.algo,
      workspace_ptr,
      heuristic_.workspaceSize,
      encoder.stream()));
}

// Enables the fused bias epilogue on the matmul descriptor and points it at
// the bias array, which is registered as an input with the encoder.
void CublasMatmulBase::set_bias(
    cu::CommandEncoder& encoder,
    const array& bias) {
  encoder.set_input_array(bias);

  // Switch the epilogue to bias addition.
  cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(epi)));

  // Attach the device pointer the epilogue reads the bias from.
  void* ptr = gpu_ptr<void>(bias);
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &ptr, sizeof(ptr)));
}

} // namespace mlx::core
Loading
Loading