Skip to content

Commit ac26a4c

Browse files
author
Awni Hannun
authored
Allow some non 2D inputs in qqmm (#2981)
1 parent 099dcc0 commit ac26a4c

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

mlx/ops.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4335,7 +4335,7 @@ std::pair<int, int> extract_qqmm_dims(
43354335
}
43364336

43374337
array qqmm(
4338-
array x,
4338+
array in_x,
43394339
array w,
43404340
std::optional<array> scales_w,
43414341
std::optional<int> group_size_ /* = std::nullopt */,
@@ -4360,6 +4360,16 @@ array qqmm(
43604360
// 2. w is not quantized, scales is not provided
43614361
auto [group_size, bits] =
43624362
quantization_params_from_mode(qmode, group_size_, bits_);
4363+
4364+
// Allow gemv
4365+
auto x = in_x;
4366+
if (x.ndim() == 1) {
4367+
// Insert a singleton dim in the beginning
4368+
x = expand_dims(x, 0, s);
4369+
} else if (w.ndim() == 2 && x.ndim() > 2) {
4370+
x = flatten(x, 0, -2, s);
4371+
}
4372+
43634373
// validate inputs
43644374
validate_qqmm_inputs(x, w, scales_w, group_size, bits);
43654375
// validate and extract shapes
@@ -4374,11 +4384,19 @@ array qqmm(
43744384
}
43754385
auto out_shape = inputs[0].shape();
43764386
out_shape.back() = w_outer_dims;
4377-
return array(
4387+
auto out = array(
43784388
std::move(out_shape),
43794389
x.dtype(), // output dtype is the same as x dtype
43804390
std::make_shared<QQMatmul>(stream, group_size, bits, qmode),
43814391
std::move(inputs));
4392+
if (in_x.ndim() > 2) {
4393+
auto orig_shape = in_x.shape();
4394+
orig_shape.pop_back();
4395+
out = unflatten(out, 0, std::move(orig_shape), s);
4396+
} else if (in_x.ndim() == 1) {
4397+
out = squeeze(out, 0, s);
4398+
}
4399+
return out;
43824400
}
43834401

43844402
array pack_and_quantize(

0 commit comments

Comments
 (0)