@@ -4335,7 +4335,7 @@ std::pair<int, int> extract_qqmm_dims(
43354335}
43364336
43374337array qqmm (
4338- array x ,
4338+ array in_x ,
43394339 array w,
43404340 std::optional<array> scales_w,
43414341 std::optional<int > group_size_ /* = std::nullopt */ ,
@@ -4360,6 +4360,16 @@ array qqmm(
43604360 // 2. w is not quantized, scales is not provided
43614361 auto [group_size, bits] =
43624362 quantization_params_from_mode (qmode, group_size_, bits_);
4363+
4364+ // Allow gemv
4365+ auto x = in_x;
4366+ if (x.ndim () == 1 ) {
4367+ // Insert a singleton dim at the beginning
4368+ x = expand_dims (x, 0 , s);
4369+ } else if (w.ndim () == 2 && x.ndim () > 2 ) {
4370+ x = flatten (x, 0 , -2 , s);
4371+ }
4372+
43634373 // validate inputs
43644374 validate_qqmm_inputs (x, w, scales_w, group_size, bits);
43654375 // validate and extract shapes
@@ -4374,11 +4384,19 @@ array qqmm(
43744384 }
43754385 auto out_shape = inputs[0 ].shape ();
43764386 out_shape.back () = w_outer_dims;
4377- return array (
4387+ auto out = array (
43784388 std::move (out_shape),
43794389 x.dtype (), // output dtype is the same as x dtype
43804390 std::make_shared<QQMatmul>(stream, group_size, bits, qmode),
43814391 std::move (inputs));
4392+ if (in_x.ndim () > 2 ) {
4393+ auto orig_shape = in_x.shape ();
4394+ orig_shape.pop_back ();
4395+ out = unflatten (out, 0 , std::move (orig_shape), s);
4396+ } else if (in_x.ndim () == 1 ) {
4397+ out = squeeze (out, 0 , s);
4398+ }
4399+ return out;
43824400}
43834401
43844402array pack_and_quantize (
0 commit comments