@@ -154,43 +154,41 @@ std::pair<int, int> extract_quantized_matmul_dims(
154154 return {w_inner_dims, w_outer_dims};
155155}
156156
157- std::pair<std::pair< int , int >, std::pair< int , int > > extract_qqmm_dims (
157+ std::pair<int , int > extract_qqmm_dims (
158158 std::string_view tag,
159159 const array& x,
160160 const array& w,
161- const array& scales_x,
162161 const array& scales_w,
163162 bool transpose,
164163 int group_size,
165164 int bits) {
166- // Validate x and scales_x
167- validate_quantized_input (
168- tag, x, scales_x, " x matrix" , " scales_x" , group_size, bits);
169-
170165 // Validate w and scales_w
171166 validate_quantized_input (
172167 tag, w, scales_w, " weight matrix" , " scales_w" , group_size, bits);
173168
174169 // For narrow precision types (mxfp4, nvfp4) the only supported layout is TN
175170 // A is MxK, B is NxK (transposed)
176- int x_inner_dims = x.shape (-1 ); // K // (32 / bits)
177- int x_outer_dims = x.shape (-2 ); // M
171+ int x_inner_dims = x.shape (-1 ) / (32 / bits); // K
178172
179173 // Calculate the expanded w's dimensions
180174 int w_inner_dims = (transpose) ? w.shape (-1 ) : w.shape (-2 );
181175 int w_outer_dims = (transpose) ? w.shape (-2 ) : w.shape (-1 );
182176
183177 if (w_inner_dims != x_inner_dims) {
184178 std::ostringstream msg;
185- msg << " [" << tag << " ] Last dimension of first quantized input with "
186- << " shape (..., " << x_inner_dims << " ) does not match "
187- << " the quantized matrix (" << w_inner_dims << " , " << w_outer_dims
188- << " ) computed with transpose=" << std::boolalpha << transpose;
179+ msg << " [" << tag << " ] Inner dimension of second input with "
180+ << " shape (" << w_inner_dims << " , " << w_outer_dims << " )"
181+ << " computed with transpose=" << std::boolalpha << transpose
182+ << " does not match the packed inner dimension of the first"
183+ << " input (...," << x_inner_dims << " ) computed with bits=" << bits
184+ << " and transpose=" << std::boolalpha << transpose;
189185
190186 throw std::invalid_argument (msg.str ());
191187 }
192188
193- return {{x_inner_dims, x_outer_dims}, {w_inner_dims, w_outer_dims}};
189+ return {
190+ w_inner_dims, w_outer_dims
191+  };
194192}
195193
196194} // namespace
@@ -4231,9 +4229,8 @@ array qqmm(
42314229 StreamOrDevice s /* = {} */ ) {
42324230  // currently only symmetric quantization is supported for qqmm
42334231 auto qmode = string_to_quantization_mode (mode, " qqmm" );
4234- // For narrow precision MMAs on B200 only TN layout is supported:
4235- // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html
4236- // TODO: handle it better
4232+  // here we need to check that inputs and outputs will be quantized in the same
4233+ // way...
42374234 if ((qmode == QuantizationMode::Nvfp4 || qmode == QuantizationMode::Mxfp4) &&
42384235 !transpose) {
42394236 std::ostringstream msg;
@@ -4247,21 +4244,14 @@ array qqmm(
42474244 msg << " [qqmm] Affine quantization is not supported for qqmm." ;
42484245 throw std::invalid_argument (msg.str ());
42494246 }
4250- auto quantized_x = quantize (x, group_size_, bits, mode, s);
4251- auto x_q = quantized_x[0 ];
4252- auto scales_x = quantized_x[1 ];
4253- encoder.add_temporary (x_q);
4254- encoder.add_temporary (scales_x);
42554247 auto [group_size, bits] =
42564248 quantization_params_from_mode (qmode, group_size_, bits_);
4257- // Check and extract the quantized matrix shape against x
4258- auto [x_dims, w_dims] = extract_qqmm_dims (
4259- " qqmm" , x_q, w_q, scales_x, scales_w, transpose, group_size, bits);
4260- auto [x_inner_dims, x_outer_dims] = x_dims;
4261- auto [w_inner_dims, w_outer_dims] = w_dims;
4249+ //
4250+ auto [w_inner_dims, w_outer_dims] =
4251+ extract_qqmm_dims (" qqmm" , x, w_q, scales_w, transpose, group_size, bits);
42624252
4263- std::vector<array> inputs = {x_q , w_q, scales_x , scales_w};
4264- if (x_q .ndim () > 2 && w_q.ndim () > 2 ) {
4253+ std::vector<array> inputs = {x , w_q, scales_w};
4254+ if (x .ndim () > 2 && w_q.ndim () > 2 ) {
42654255 inputs = broadcast_arrays (inputs, {-2 , -1 }, s);
42664256 }
42674257
@@ -5969,4 +5959,4 @@ array contiguous(
59695959 {a});
59705960}
59715961
5972- } // namespace mlx::core
5962+ } // namespace mlx::core
0 commit comments