Skip to content

Commit 7e8726e

Browse files
authored
[MLIR] Add f8E3M4 IEEE 754 type (#101230)
This PR adds `f8E3M4` type to mlir. `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` Related PRs: - [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type - [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type
1 parent a032bfb commit 7e8726e

File tree

5 files changed

+59
-0
lines changed

5 files changed

+59
-0
lines changed

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
139139
/// context.
140140
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx);
141141

142+
/// Returns the typeID of an Float8E3M4 type.
143+
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E3M4TypeGetTypeID(void);
144+
145+
/// Checks whether the given type is an f8E3M4 type.
146+
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type);
147+
148+
/// Creates an f8E3M4 type in the given context. The type is owned by the
149+
/// context.
150+
MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx);
151+
142152
/// Returns the typeID of an BFloat16 type.
143153
MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
144154

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,26 @@ class PyFloat8E5M2FNUZType
246246
}
247247
};
248248

249+
/// Floating Point Type subclass - Float8E3M4Type.
250+
class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
251+
public:
252+
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
253+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
254+
mlirFloat8E3M4TypeGetTypeID;
255+
static constexpr const char *pyClassName = "Float8E3M4Type";
256+
using PyConcreteType::PyConcreteType;
257+
258+
static void bindDerived(ClassTy &c) {
259+
c.def_static(
260+
"get",
261+
[](DefaultingPyMlirContext context) {
262+
MlirType t = mlirFloat8E3M4TypeGet(context->get());
263+
return PyFloat8E3M4Type(context->getRef(), t);
264+
},
265+
py::arg("context") = py::none(), "Create a float8_e3m4 type.");
266+
}
267+
};
268+
249269
/// Floating Point Type subclass - BF16Type.
250270
class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
251271
public:
@@ -864,6 +884,7 @@ void mlir::python::populateIRTypes(py::module &m) {
864884
PyFloat8E4M3FNUZType::bind(m);
865885
PyFloat8E4M3B11FNUZType::bind(m);
866886
PyFloat8E5M2FNUZType::bind(m);
887+
PyFloat8E3M4Type::bind(m);
867888
PyBF16Type::bind(m);
868889
PyF16Type::bind(m);
869890
PyTF32Type::bind(m);

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,18 @@ MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
157157
return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx)));
158158
}
159159

160+
MlirTypeID mlirFloat8E3M4TypeGetTypeID() {
161+
return wrap(Float8E3M4Type::getTypeID());
162+
}
163+
164+
bool mlirTypeIsAFloat8E3M4(MlirType type) {
165+
return unwrap(type).isFloat8E3M4();
166+
}
167+
168+
MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
169+
return wrap(FloatType::getFloat8E3M4(unwrap(ctx)));
170+
}
171+
160172
MlirTypeID mlirBFloat16TypeGetTypeID() {
161173
return wrap(BFloat16Type::getTypeID());
162174
}

mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ __all__ = [
120120
"F32Type",
121121
"F64Type",
122122
"FlatSymbolRefAttr",
123+
"Float8E3M4Type",
123124
"Float8E4M3B11FNUZType",
124125
"Float8E4M3FNType",
125126
"Float8E4M3FNUZType",
@@ -1537,6 +1538,19 @@ class FlatSymbolRefAttr(Attribute):
15371538
Returns the value of the FlatSymbolRef attribute as a string
15381539
"""
15391540

1541+
class Float8E3M4Type(FloatType):
1542+
static_typeid: ClassVar[TypeID]
1543+
@staticmethod
1544+
def get(context: Optional[Context] = None) -> Float8E3M4Type:
1545+
"""
1546+
Create a float8_e3m4 type.
1547+
"""
1548+
@staticmethod
1549+
def isinstance(other: Type) -> bool: ...
1550+
def __init__(self, cast_from_type: Type) -> None: ...
1551+
@property
1552+
def typeid(self) -> TypeID: ...
1553+
15401554
class Float8E4M3B11FNUZType(FloatType):
15411555
static_typeid: ClassVar[TypeID]
15421556
@staticmethod

mlir/python/mlir/extras/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
F16Type,
1313
F32Type,
1414
F64Type,
15+
Float8E3M4Type,
1516
Float8E4M3B11FNUZType,
1617
Float8E4M3FNType,
1718
Float8E4M3Type,
@@ -72,6 +73,7 @@ def ui(width):
7273
f8E4M3 = lambda: Float8E4M3Type.get()
7374
f8E4M3FN = lambda: Float8E4M3FNType.get()
7475
f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
76+
f8E3M4 = lambda: Float8E3M4Type.get()
7577

7678
none = lambda: NoneType.get()
7779

0 commit comments

Comments
 (0)