Skip to content

Commit daaa3af

Browse files
committed
Add F4E2M1FN type: python interface
1 parent c479f09 commit daaa3af

File tree

17 files changed

+79
-31
lines changed

17 files changed

+79
-31
lines changed

xla/pjrt/c/pjrt_c_api.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,9 @@ typedef enum {
649649
// More truncated 8 bit floating-point formats.
650650
PJRT_Buffer_Type_F8E4M3,
651651
PJRT_Buffer_Type_F8E3M4,
652+
653+
// 4-bit MX floating-point format.
654+
PJRT_Buffer_Type_F4E2M1FN,
652655
} PJRT_Buffer_Type;
653656

654657
typedef enum {

xla/pjrt/c/pjrt_c_api_helpers.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) {
294294
return PJRT_Buffer_Type::PJRT_Buffer_Type_BF16;
295295
case xla::PrimitiveType::F64:
296296
return PJRT_Buffer_Type::PJRT_Buffer_Type_F64;
297+
case xla::PrimitiveType::F4E2M1FN:
298+
return PJRT_Buffer_Type::PJRT_Buffer_Type_F4E2M1FN;
297299
case xla::PrimitiveType::F8E5M2:
298300
return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2;
299301
case xla::PrimitiveType::F8E4M3:
@@ -361,6 +363,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) {
361363
return xla::PrimitiveType::C64;
362364
case PJRT_Buffer_Type::PJRT_Buffer_Type_C128:
363365
return xla::PrimitiveType::C128;
366+
case PJRT_Buffer_Type::PJRT_Buffer_Type_F4E2M1FN:
367+
return xla::PrimitiveType::F4E2M1FN;
364368
case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2:
365369
return xla::PrimitiveType::F8E5M2;
366370
case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3:

xla/python/ifrt/dtype.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ std::optional<int> DType::byte_size() const {
3232
case kU2:
3333
case kS4:
3434
case kU4:
35+
case kF4E2M1FN:
3536
// Smaller than a byte.
3637
return std::nullopt;
3738
case kPred:
@@ -77,6 +78,7 @@ std::optional<int> DType::bit_size() const {
7778
return 2;
7879
case kS4:
7980
case kU4:
81+
case kF4E2M1FN:
8082
return 4;
8183
case kPred:
8284
case kS8:
@@ -142,6 +144,7 @@ absl::StatusOr<DType> DType::FromProto(const DTypeProto& dtype_proto) {
142144
CASE(C64);
143145
CASE(C128);
144146
// TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
147+
// CASE(F4E2M1FN);
145148
// CASE(F8E3M4);
146149
// CASE(F8E4M3);
147150
CASE(F8E4M3FN);
@@ -190,6 +193,7 @@ DTypeProto DType::ToProto() const {
190193
CASE(C64);
191194
CASE(C128);
192195
// TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
196+
// CASE(F4E2M1FN);
193197
// CASE(F8E3M4);
194198
// CASE(F8E4M3);
195199
CASE(F8E4M3FN);

xla/python/ifrt/dtype.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,10 @@ class DType {
8989
kF8E5M2 = 19,
9090
kF8E5M2FNUZ = 24,
9191

92-
// Next = 30
92+
// MX floating point types.
93+
kF4E2M1FN = 30,
94+
95+
// Next = 31
9396

9497
// Variable-length string represented as raw bytes, as in `bytes` in Python,
9598
// i.e., no encoding enforcement. String is not supported in XLA. DType.Kind

xla/python/ifrt/dtype.proto

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,16 @@ message DTypeProto {
7171
KIND_F8E5M2 = 19;
7272
KIND_F8E5M2FNUZ = 24;
7373

74+
// MX floating point types.
75+
KIND_F4E2M1FN = 30;
76+
7477
// Variable-length string represented as raw bytes, as in `bytes` in Python,
7578
// i.e., no encoding enforcement. String is not supported in XLA. DType.Kind
7679
// needs to match xla.PrimitiveType enum, so choose a large enum to avoid
7780
// collision.
7881
KIND_STRING = 99;
82+
83+
// Next: 31
7984
}
8085
// LINT.ThenChange()
8186
Kind kind = 1;

xla/python/ifrt/dtype_test.cc

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -42,35 +42,21 @@ TEST(DTypeTest, FromToFromProto) {
4242
TEST(DTypeTest, ByteSize) {
4343
for (const auto& [kind, byte_size] :
4444
std::vector<std::tuple<DType::Kind, int>>({
45-
{DType::kS2, -1},
46-
{DType::kU2, -1},
47-
{DType::kS4, -1},
48-
{DType::kU4, -1},
49-
{DType::kPred, 1},
50-
{DType::kS8, 1},
51-
{DType::kU8, 1},
52-
{DType::kF8E3M4, 1},
53-
{DType::kF8E4M3, 1},
54-
{DType::kF8E4M3FN, 1},
55-
{DType::kF8E4M3B11FNUZ, 1},
56-
{DType::kF8E4M3FNUZ, 1},
57-
{DType::kF8E5M2, 1},
58-
{DType::kF8E5M2FNUZ, 1},
59-
{DType::kS16, 2},
60-
{DType::kU16, 2},
61-
{DType::kF16, 2},
62-
{DType::kBF16, 2},
63-
{DType::kS32, 4},
64-
{DType::kU32, 4},
65-
{DType::kF32, 4},
66-
{DType::kS64, 8},
67-
{DType::kU64, 8},
68-
{DType::kF64, 8},
69-
{DType::kC64, 8},
70-
{DType::kC128, 16},
71-
{DType::kToken, -1},
72-
{DType::kInvalid, -1},
73-
{DType::kString, -1},
45+
{DType::kS2, -1}, {DType::kU2, -1},
46+
{DType::kS4, -1}, {DType::kU4, -1},
47+
{DType::kPred, 1}, {DType::kS8, 1},
48+
{DType::kU8, 1}, {DType::kF4E2M1FN, -1},
49+
{DType::kF8E3M4, 1}, {DType::kF8E4M3, 1},
50+
{DType::kF8E4M3FN, 1}, {DType::kF8E4M3B11FNUZ, 1},
51+
{DType::kF8E4M3FNUZ, 1}, {DType::kF8E5M2, 1},
52+
{DType::kF8E5M2FNUZ, 1}, {DType::kS16, 2},
53+
{DType::kU16, 2}, {DType::kF16, 2},
54+
{DType::kBF16, 2}, {DType::kS32, 4},
55+
{DType::kU32, 4}, {DType::kF32, 4},
56+
{DType::kS64, 8}, {DType::kU64, 8},
57+
{DType::kF64, 8}, {DType::kC64, 8},
58+
{DType::kC128, 16}, {DType::kToken, -1},
59+
{DType::kInvalid, -1}, {DType::kString, -1},
7460
})) {
7561
EXPECT_EQ(DType(kind).byte_size(),
7662
byte_size == -1 ? std::nullopt : std::make_optional(byte_size));
@@ -87,6 +73,7 @@ TEST(DTypeTest, BitSize) {
8773
{DType::kPred, 8},
8874
{DType::kS8, 8},
8975
{DType::kU8, 8},
76+
{DType::kF4E2M1FN, 4},
9077
{DType::kF8E3M4, 8},
9178
{DType::kF8E4M3, 8},
9279
{DType::kF8E4M3FN, 8},

xla/python/pjrt_ifrt/pjrt_dtype.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ absl::StatusOr<xla::PrimitiveType> ToPrimitiveType(DType dtype) {
4444
CASE(DType::kU16, xla::PrimitiveType::U16);
4545
CASE(DType::kU32, xla::PrimitiveType::U32);
4646
CASE(DType::kU64, xla::PrimitiveType::U64);
47+
CASE(DType::kF4E2M1FN, xla::PrimitiveType::F4E2M1FN);
4748
CASE(DType::kF8E3M4, xla::PrimitiveType::F8E3M4);
4849
CASE(DType::kF8E4M3, xla::PrimitiveType::F8E4M3);
4950
CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN);
@@ -83,6 +84,7 @@ absl::StatusOr<DType> ToDType(xla::PrimitiveType primitive_type) {
8384
case xla::PrimitiveType::U16:
8485
case xla::PrimitiveType::U32:
8586
case xla::PrimitiveType::U64:
87+
case xla::PrimitiveType::F4E2M1FN:
8688
case xla::PrimitiveType::F8E3M4:
8789
case xla::PrimitiveType::F8E4M3:
8890
case xla::PrimitiveType::F8E4M3FN:

xla/python/py_values.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ absl::StatusOr<DevicePutResultFn> HandleNumpyScalar(
184184
} else if (std::is_same<T, bfloat16>()) {
185185
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
186186
type = BF16;
187+
} else if (std::is_same<T, tsl::float4_e2m1fn>()) {
188+
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
189+
type = F4E2M1FN;
187190
} else if (std::is_same<T, tsl::float8_e3m4>()) {
188191
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
189192
type = F8E3M4;
@@ -398,6 +401,10 @@ absl::StatusOr<DevicePutResultFn> DevicePut(nb::handle arg,
398401
(*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar<uint16_t>;
399402
(*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar<uint32_t>;
400403
(*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar<uint64_t, uint32_t>;
404+
if (dtypes.np_float4_e2m1fn.has_value()) {
405+
(*p)[dtypes.np_float4_e2m1fn->ptr()] =
406+
HandleNumpyScalar<tsl::float4_e2m1fn>;
407+
}
401408
if (dtypes.np_float8_e3m4.has_value()) {
402409
(*p)[dtypes.np_float8_e3m4->ptr()] =
403410
HandleNumpyScalar<tsl::float8_e3m4>;
@@ -595,6 +602,7 @@ absl::StatusOr<PyArgSignature> PyArgSignatureOfValue(nb::handle arg,
595602
(*p)[dtypes.np_uint32.ptr()] = numpy_array_handler;
596603
(*p)[dtypes.np_uint64.ptr()] = np_uint64_handler;
597604
// TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
605+
// (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler;
598606
// (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler;
599607
// (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler;
600608
(*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler;

xla/python/types.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ namespace {
5858

5959
struct CustomDtypes {
6060
nb_dtype bfloat16;
61+
std::optional<nb_dtype> float4_e2m1fn;
6162
std::optional<nb_dtype> float8_e3m4;
6263
std::optional<nb_dtype> float8_e4m3;
6364
nb_dtype float8_e4m3fn;
@@ -76,6 +77,10 @@ const CustomDtypes& GetCustomDtypes() {
7677
nb::module_ ml_dtypes = nb::module_::import_("ml_dtypes");
7778
auto* dtypes = new CustomDtypes;
7879
dtypes->bfloat16 = nb_dtype::from_args(ml_dtypes.attr("bfloat16"));
80+
if (nb::hasattr(ml_dtypes, "float4_e2m1fn")) {
81+
dtypes->float4_e2m1fn =
82+
nb_dtype::from_args(ml_dtypes.attr("float4_e2m1fn"));
83+
}
7984
if (nb::hasattr(ml_dtypes, "float8_e3m4")) {
8085
dtypes->float8_e3m4 = nb_dtype::from_args(ml_dtypes.attr("float8_e3m4"));
8186
}
@@ -147,6 +152,9 @@ absl::StatusOr<PrimitiveType> DtypeToPrimitiveType(const nb_dtype& np_type) {
147152
auto* map =
148153
new absl::flat_hash_map<nb_dtype, PrimitiveType, DtypeHash, DtypeEq>();
149154
map->emplace(custom_dtypes.bfloat16, BF16);
155+
if (custom_dtypes.float4_e2m1fn.has_value()) {
156+
map->emplace(*custom_dtypes.float4_e2m1fn, F4E2M1FN);
157+
}
150158
if (custom_dtypes.float8_e3m4.has_value()) {
151159
map->emplace(*custom_dtypes.float8_e3m4, F8E3M4);
152160
}
@@ -217,6 +225,11 @@ absl::StatusOr<nb_dtype> PrimitiveTypeToNbDtype(PrimitiveType type) {
217225
return to_nb_dtype(NPY_UINT32);
218226
case U64:
219227
return to_nb_dtype(NPY_UINT64);
228+
case F4E2M1FN:
229+
if (custom_dtypes.float4_e2m1fn.has_value()) {
230+
return *custom_dtypes.float4_e2m1fn;
231+
}
232+
break;
220233
case F8E3M4:
221234
if (custom_dtypes.float8_e3m4.has_value()) {
222235
return *custom_dtypes.float8_e3m4;
@@ -307,6 +320,11 @@ absl::StatusOr<nb_dtype> IfrtDtypeToNbDtype(ifrt::DType dtype) {
307320
return to_nb_dtype(NPY_COMPLEX64);
308321
case ifrt::DType::kC128:
309322
return to_nb_dtype(NPY_COMPLEX128);
323+
case ifrt::DType::kF4E2M1FN:
324+
if (custom_dtypes.float4_e2m1fn.has_value()) {
325+
return *custom_dtypes.float4_e2m1fn;
326+
}
327+
break;
310328
case ifrt::DType::kF8E3M4:
311329
if (custom_dtypes.float8_e3m4.has_value()) {
312330
return *custom_dtypes.float8_e3m4;
@@ -380,6 +398,9 @@ const NumpyScalarTypes& GetNumpyScalarTypes() {
380398
dtypes->np_uint32 = nb::object(numpy.attr("uint32"));
381399
dtypes->np_uint64 = nb::object(numpy.attr("uint64"));
382400
dtypes->np_bfloat16 = nb::object(ml_dtypes.attr("bfloat16"));
401+
if (nb::hasattr(ml_dtypes, "float4_e2m1fn")) {
402+
dtypes->np_float4_e2m1fn = nb::object(ml_dtypes.attr("float4_e2m1fn"));
403+
}
383404
if (nb::hasattr(ml_dtypes, "float8_e3m4")) {
384405
dtypes->np_float8_e3m4 = nb::object(ml_dtypes.attr("float8_e3m4"));
385406
}

xla/python/types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ struct NumpyScalarTypes {
8181
nanobind::object np_uint64;
8282
nanobind::object np_bfloat16;
8383
// Remove std::optional once the minimum ml_dtypes in JAX is >= 0.5.0.
84+
std::optional<nanobind::object> np_float4_e2m1fn;
8485
std::optional<nanobind::object> np_float8_e3m4;
8586
std::optional<nanobind::object> np_float8_e4m3;
8687
nanobind::object np_float8_e4m3fn;

0 commit comments

Comments
 (0)