Skip to content

Commit 7a012a7

Browse files
committed
quantize in eval
1 parent df45b39 commit 7a012a7

File tree

2 files changed

+26
-31
lines changed

2 files changed

+26
-31
lines changed

mlx/backend/cuda/quantized/quantized.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,16 @@ void DualQuantizedMatmul::eval_gpu(
180180
auto& s = stream();
181181
auto& encoder = cu::get_command_encoder(s);
182182

183-
assert(inputs.size() == 4);
183+
assert(inputs.size() == 3);
184184
auto& a_pre = inputs[0]; // activations are not quantized, only weights are
185185
auto& b = inputs[1];
186186

187-
auto a_q = quantize(a_pre, group_size_, bits_, mode_, s);
187+
auto a_q = fp_quantize(
188+
a_pre,
189+
group_size_,
190+
bits_,
191+
mode_,
192+
s); // here I assume that it is only for nvfp4/mxfp8
188193
encoder.add_temporary(a_q[0]);
189194
encoder.add_temporary(a_q[1]);
190195

mlx/ops.cpp

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -154,43 +154,41 @@ std::pair<int, int> extract_quantized_matmul_dims(
154154
return {w_inner_dims, w_outer_dims};
155155
}
156156

157-
std::pair<std::pair<int, int>, std::pair<int, int>> extract_qqmm_dims(
157+
std::pair<int, int> extract_qqmm_dims(
158158
std::string_view tag,
159159
const array& x,
160160
const array& w,
161-
const array& scales_x,
162161
const array& scales_w,
163162
bool transpose,
164163
int group_size,
165164
int bits) {
166-
// Validate x and scales_x
167-
validate_quantized_input(
168-
tag, x, scales_x, "x matrix", "scales_x", group_size, bits);
169-
170165
// Validate w and scales_w
171166
validate_quantized_input(
172167
tag, w, scales_w, "weight matrix", "scales_w", group_size, bits);
173168

174169
// For narrow precision types (mxfp4, nvfp4) the only supported layout is TN
175170
// A is MxK, B is NxK (transposed)
176-
int x_inner_dims = x.shape(-1); // K // (32 / bits)
177-
int x_outer_dims = x.shape(-2); // M
171+
int x_inner_dims = x.shape(-1) / (32 / bits); // K
178172

179173
// Calculate the expanded w's dimensions
180174
int w_inner_dims = (transpose) ? w.shape(-1) : w.shape(-2);
181175
int w_outer_dims = (transpose) ? w.shape(-2) : w.shape(-1);
182176

183177
if (w_inner_dims != x_inner_dims) {
184178
std::ostringstream msg;
185-
msg << "[" << tag << "] Last dimension of first quantized input with "
186-
<< "shape (..., " << x_inner_dims << ") does not match "
187-
<< "the quantized matrix (" << w_inner_dims << ", " << w_outer_dims
188-
<< ") computed with transpose=" << std::boolalpha << transpose;
179+
msg << "[" << tag << "] Inner dimension of second input with "
180+
<< "shape (" << w_inner_dims << ", " << w_outer_dims << ")"
181+
<< " computed with transpose=" << std::boolalpha << transpose
182+
<< " does not match the packed inner dimension of the first "
183+
<< "input (...," << x_inner_dims << ") computed with bits=" << bits
184+
<< " and transpose=" << std::boolalpha << transpose;
189185

190186
throw std::invalid_argument(msg.str());
191187
}
192188

193-
return {{x_inner_dims, x_outer_dims}, {w_inner_dims, w_outer_dims}};
189+
return {
190+
w_inner_dims, w_outer_dims
191+
}; // NOTE(review): semicolon appears missing in the scraped diff — confirm against the actual commit
194192
}
195193

196194
} // namespace
@@ -4231,9 +4229,8 @@ array qqmm(
42314229
StreamOrDevice s /* = {} */) {
42324230
// currently only symmetric quantization is supported for qqmm
42334231
auto qmode = string_to_quantization_mode(mode, "qqmm");
4234-
// For narrow precision MMAs on B200 only TN layout is supported:
4235-
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html
4236-
// TODO: handle it better
4232+
// here we need to check that inputs and outputs will be quantized in the same
4233+
// way...
42374234
if ((qmode == QuantizationMode::Nvfp4 || qmode == QuantizationMode::Mxfp4) &&
42384235
!transpose) {
42394236
std::ostringstream msg;
@@ -4247,21 +4244,14 @@ array qqmm(
42474244
msg << "[qqmm] Affine quantization is not supported for qqmm.";
42484245
throw std::invalid_argument(msg.str());
42494246
}
4250-
auto quantized_x = quantize(x, group_size_, bits, mode, s);
4251-
auto x_q = quantized_x[0];
4252-
auto scales_x = quantized_x[1];
4253-
encoder.add_temporary(x_q);
4254-
encoder.add_temporary(scales_x);
42554247
auto [group_size, bits] =
42564248
quantization_params_from_mode(qmode, group_size_, bits_);
4257-
// Check and extract the quantized matrix shape against x
4258-
auto [x_dims, w_dims] = extract_qqmm_dims(
4259-
"qqmm", x_q, w_q, scales_x, scales_w, transpose, group_size, bits);
4260-
auto [x_inner_dims, x_outer_dims] = x_dims;
4261-
auto [w_inner_dims, w_outer_dims] = w_dims;
4249+
//
4250+
auto [w_inner_dims, w_outer_dims] =
4251+
extract_qqmm_dims("qqmm", x, w_q, scales_w, transpose, group_size, bits);
42624252

4263-
std::vector<array> inputs = {x_q, w_q, scales_x, scales_w};
4264-
if (x_q.ndim() > 2 && w_q.ndim() > 2) {
4253+
std::vector<array> inputs = {x, w_q, scales_w};
4254+
if (x.ndim() > 2 && w_q.ndim() > 2) {
42654255
inputs = broadcast_arrays(inputs, {-2, -1}, s);
42664256
}
42674257

@@ -5969,4 +5959,4 @@ array contiguous(
59695959
{a});
59705960
}
59715961

5972-
} // namespace mlx::core
5962+
} // namespace mlx::core

0 commit comments

Comments
 (0)