Skip to content

Commit 099dcc0

Browse files
author
Awni Hannun
authored
Expose to/from fp8 in Python and don't auto-convert fp8 when loading from safetensors (#2985)
1 parent 8654b82 commit 099dcc0

File tree

6 files changed

+63
-38
lines changed

6 files changed

+63
-38
lines changed

mlx/backend/cpu/unary_ops.h

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -154,24 +154,12 @@ struct ToFP8 {
154154
struct FromFP8 {
155155
template <int N>
156156
Simd<float, N> operator()(Simd<uint8_t, N> x) {
157-
auto w = Simd<uint32_t, N>(x) << 24;
158-
auto sign = w & 0x80000000;
159-
auto nonsign = w & 0x7FFFFFFF;
160-
161-
auto renorm_shift = clz(nonsign);
162-
renorm_shift = simd::select(
163-
renorm_shift > Simd<uint32_t, N>{4},
164-
renorm_shift - Simd<uint32_t, N>{4},
165-
Simd<uint32_t, N>{0});
166-
167-
Simd<int32_t, N> inf_nan_mask =
168-
(Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
169-
auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
170-
auto result = sign |
171-
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
172-
inf_nan_mask) &
173-
~zero_mask);
174-
return fp32_from_bits(result);
157+
auto v = Simd<uint16_t, N>(x & 127) << 7;
158+
auto converted = *(Simd<float16_t, N>*)(&v);
159+
converted = converted * 256.0;
160+
auto sign = Simd<bool, N>(x & 128);
161+
Simd<float, N> out = select(sign, -converted, converted);
162+
return out;
175163
}
176164
float operator()(uint8_t x) {
177165
return (*this)(Simd<uint8_t, 1>(x)).value;

mlx/io/safetensors.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
9595
} else if (str == ST_C64) {
9696
return complex64;
9797
} else if (str == ST_F8_E4M3) {
98-
// We convert this manually later
9998
return uint8;
10099
} else {
101100
throw std::runtime_error(
@@ -148,16 +147,14 @@ SafetensorsLoad load_safetensors(
148147
const Shape& shape = item.value().at("shape");
149148
const std::vector<size_t>& data_offsets = item.value().at("data_offsets");
150149
Dtype type = dtype_from_safetensor_str(dtype);
151-
auto loaded_array = array(
152-
shape,
153-
type,
154-
std::make_shared<Load>(
155-
stream, in_stream, offset + data_offsets.at(0), false),
156-
std::vector<array>{});
157-
if (dtype == ST_F8_E4M3) {
158-
loaded_array = from_fp8(loaded_array, bfloat16, s);
159-
}
160-
res.insert({item.key(), loaded_array});
150+
res.insert(
151+
{item.key(),
152+
array(
153+
shape,
154+
type,
155+
std::make_shared<Load>(
156+
stream, in_stream, offset + data_offsets.at(0), false),
157+
std::vector<array>{})});
161158
}
162159
return {res, metadata_map};
163160
}

python/src/ops.cpp

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5479,10 +5479,10 @@ void init_ops(nb::module_& m) {
54795479
If ``w`` is expected to receive gradients, it must be provided in
54805480
non-quantized form.
54815481
5482-
If ``x`` and `w`` are not quantized, their data types must be ``float32``,
5482+
If ``x`` and ``w`` are not quantized, their data types must be ``float32``,
54835483
``float16``, or ``bfloat16``.
54845484
If ``w`` is quantized, it must be packed in unsigned integers.
5485-
5485+
54865486
Args:
54875487
x (array): Input array.
54885488
w (array): Weight matrix. If quantized, it is packed in unsigned integers.
@@ -5502,4 +5502,40 @@ void init_ops(nb::module_& m) {
55025502
array: The result of the multiplication of quantized ``x`` with quantized ``w``.
55035503
needed).
55045504
)pbdoc");
5505+
m.def(
5506+
"from_fp8",
5507+
&mx::from_fp8,
5508+
nb::arg(),
5509+
"dtype"_a = mx::bfloat16,
5510+
nb::kw_only(),
5511+
"stream"_a = nb::none(),
5512+
nb::sig(
5513+
"def from_fp8(x: array, dtype: Dtype = bfloat16, *, stream: Union[None, Stream, Device] = None) -> array"),
5514+
R"pbdoc(
5515+
Convert the array from fp8 (e4m3) to another floating-point type.
5516+
5517+
Args:
5518+
x (array): The input fp8 array with type ``uint8``.
5519+
dtype (Dtype): The data type to convert to. Default: ``bfloat16``.
5520+
5521+
Returns:
5522+
array: The array converted from fp8.
5523+
)pbdoc");
5524+
m.def(
5525+
"to_fp8",
5526+
&mx::to_fp8,
5527+
nb::arg(),
5528+
nb::kw_only(),
5529+
"stream"_a = nb::none(),
5530+
nb::sig(
5531+
"def to_fp8(x: array, *, stream: Union[None, Stream, Device] = None) -> array"),
5532+
R"pbdoc(
5533+
Convert the array to fp8 (e4m3) from another floating-point type.
5534+
5535+
Args:
5536+
x (array): The input array.
5537+
5538+
Returns:
5539+
array: The array converted to fp8 with type ``uint8``.
5540+
)pbdoc");
55055541
}

python/tests/cuda_skip.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
cuda_skip = {
2-
"TestLoad.test_load_f8_e4m3",
32
"TestLayers.test_quantized_embedding",
43
# Block masked matmul NYI
54
"TestBlas.test_block_masked_matmul",

python/tests/test_load.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ def test_load_f8_e4m3(self):
168168

169169
expected = [
170170
0,
171-
mx.nan,
172-
mx.nan,
171+
448,
172+
-448,
173173
-0.875,
174174
0.4375,
175175
-0.005859,
@@ -179,12 +179,12 @@ def test_load_f8_e4m3(self):
179179
-0.0039,
180180
]
181181
expected = mx.array(expected, dtype=mx.bfloat16)
182-
contents = b'H\x00\x00\x00\x00\x00\x00\x00{"tensor":{"dtype":"F8_E4M3","shape":[10],"data_offsets":[0,10]}} \x00\x7f\xff\xb6.\x83\xba\xba\xbc\x82'
182+
contents = b'H\x00\x00\x00\x00\x00\x00\x00{"tensor":{"dtype":"F8_E4M3","shape":[10],"data_offsets":[0,10]}} \x00~\xfe\xb6.\x83\xba\xba\xbc\x82'
183183
with tempfile.NamedTemporaryFile(suffix=".safetensors") as f:
184184
f.write(contents)
185185
f.seek(0)
186186
out = mx.load(f)["tensor"]
187-
self.assertTrue(mx.allclose(out[0], expected[0], equal_nan=True))
187+
self.assertTrue(mx.allclose(mx.from_fp8(out), expected))
188188

189189
def test_save_and_load_gguf_metadata_basic(self):
190190
if not os.path.isdir(self.test_dir):

python/tests/test_ops.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3197,8 +3197,6 @@ def test_masked_scatter(self):
31973197
)
31983198
)
31993199

3200-
3201-
class TestBroadcast(mlx_tests.MLXTestCase):
32023200
def test_broadcast_shapes(self):
32033201
# Basic broadcasting
32043202
self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3))
@@ -3243,6 +3241,13 @@ def test_sort_nan(self):
32433241
self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True))
32443242
x = mx.array([3.0, mx.nan, 2.0, 0.0]) + 1j * mx.array([1.0] * 4)
32453243

3244+
def test_to_from_fp8(self):
3245+
vals = mx.array(
3246+
[448, 256, 192, 128, 96, 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 0.015625]
3247+
)
3248+
self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(vals)), vals))
3249+
self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(-vals)), -vals))
3250+
32463251

32473252
if __name__ == "__main__":
32483253
mlx_tests.MLXTestRunner()

0 commit comments

Comments (0)