Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

* Added new 8-bit float types following IEEE 754 convention:
`ml_dtypes.float8_e4m3` and `ml_dtypes.float8_e3m4`.
* Added new 4-bit and 6-bit float types:
`ml_dtypes.float4_e2m1fn`, `ml_dtypes.float6_e2m3fn` and `ml_dtypes.float6_e3m2fn`.
* Fix outputs of float `divmod` and `floor_divide` when denominator is zero.

## [0.4.0] - 2024-04-01
Expand Down
37 changes: 37 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
* `float8_e4m3fnuz`
* `float8_e5m2`
* `float8_e5m2fnuz`
- Microscaling (MX) sub-byte floating point representations including:
* `float4_e2m1fn`
* `float6_e2m3fn`
* `float6_e3m2fn`
- `int2`, `int4`, `uint2` and `uint4`: low precision integer types.

See below for specifications of these number formats.
Expand Down Expand Up @@ -66,6 +70,39 @@ A `bfloat16` number is a single-precision float truncated at 16 bits.

Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.

### `float4_e2m1fn`

Exponent: 2, Mantissa: 1, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 4 bits (encoding: `0bSEEM`) using byte storage (higher 4
bits are unused). NaN representation is undefined.

Possible absolute values: [`0`, `0.5`, `1`, `1.5`, `2`, `3`, `4`, `6`]

### `float6_e2m3fn`

Exponent: 2, Mantissa: 3, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 6 bits (encoding: `0bSEEMMM`) using byte storage (higher 2
bits are unused). NaN representation is undefined.

Possible values range: [`-7.5`; `7.5`]

### `float6_e3m2fn`

Exponent: 3, Mantissa: 2, bias: 3.

Extended range: no inf, no NaN.

Microscaling format, 6 bits (encoding: `0bSEEEMM`) using byte storage (higher 2
bits are unused). NaN representation is undefined.

Possible values range: [`-28`; `28`]

### `float8_e3m4`

Exponent: 3, Mantissa: 4, bias: 3. IEEE 754, with NaN and inf.
Expand Down
9 changes: 9 additions & 0 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
"__version__",
"bfloat16",
"finfo",
"float4_e2m1fn",
"float6_e2m3fn",
"float6_e3m2fn",
"float8_e3m4",
"float8_e4m3",
"float8_e4m3b11fnuz",
Expand All @@ -36,6 +39,9 @@
from ml_dtypes._finfo import finfo
from ml_dtypes._iinfo import iinfo
from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float4_e2m1fn
from ml_dtypes._ml_dtypes_ext import float6_e2m3fn
from ml_dtypes._ml_dtypes_ext import float6_e3m2fn
from ml_dtypes._ml_dtypes_ext import float8_e3m4
from ml_dtypes._ml_dtypes_ext import float8_e4m3
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
Expand All @@ -50,6 +56,9 @@
import numpy as np

bfloat16: Type[np.generic]
float4_e2m1fn: Type[np.generic]
float6_e2m3fn: Type[np.generic]
float6_e3m2fn: Type[np.generic]
float8_e3m4: Type[np.generic]
float8_e4m3: Type[np.generic]
float8_e4m3b11fnuz: Type[np.generic]
Expand Down
254 changes: 187 additions & 67 deletions ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from typing import Dict

from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float4_e2m1fn
from ml_dtypes._ml_dtypes_ext import float6_e2m3fn
from ml_dtypes._ml_dtypes_ext import float6_e3m2fn
from ml_dtypes._ml_dtypes_ext import float8_e3m4
from ml_dtypes._ml_dtypes_ext import float8_e4m3
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
Expand All @@ -27,6 +30,9 @@
import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
_float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
_float6_e2m3fn_dtype = np.dtype(float6_e2m3fn)
_float6_e3m2fn_dtype = np.dtype(float6_e3m2fn)
_float8_e3m4_dtype = np.dtype(float8_e3m4)
_float8_e4m3_dtype = np.dtype(float8_e4m3)
_float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz)
Expand All @@ -45,6 +51,33 @@ def __init__(self):
self.smallest_subnormal = bfloat16(smallest_subnormal)


class _Float4E2m1fnMachArLike:
  """Holder for the smallest positive `float4_e2m1fn` values.

  An instance is attached to the hand-built finfo record as `_machar`, which
  is read back for `smallest_normal` and `smallest_subnormal`.
  """

  def __init__(self):
    # 0x1p0 == 1.0 and 0x0.8p0 == 0.5, both converted to float4_e2m1fn.
    self.smallest_normal = float4_e2m1fn(float.fromhex("0x1p0"))
    self.smallest_subnormal = float4_e2m1fn(float.fromhex("0x0.8p0"))


class _Float6E2m3fnMachArLike:
  """Holder for the smallest positive `float6_e2m3fn` values.

  An instance is attached to the hand-built finfo record as `_machar`, which
  is read back for `smallest_normal` and `smallest_subnormal`.
  """

  def __init__(self):
    # 0x1p0 == 1.0 and 0x0.2p0 == 0.125, both converted to float6_e2m3fn.
    self.smallest_normal = float6_e2m3fn(float.fromhex("0x1p0"))
    self.smallest_subnormal = float6_e2m3fn(float.fromhex("0x0.2p0"))


class _Float6E3m2fnMachArLike:
  """Holder for the smallest positive `float6_e3m2fn` values.

  An instance is attached to the hand-built finfo record as `_machar`, which
  is read back for `smallest_normal` and `smallest_subnormal`.
  """

  def __init__(self):
    # 0x1p-2 == 0.25 and 0x0.4p-2 == 0.0625, both converted to float6_e3m2fn.
    self.smallest_normal = float6_e3m2fn(float.fromhex("0x1p-2"))
    self.smallest_subnormal = float6_e3m2fn(float.fromhex("0x0.4p-2"))


class _Float8E3m4MachArLike:

def __init__(self):
Expand Down Expand Up @@ -110,7 +143,7 @@ def __init__(self):

class finfo(np.finfo): # pylint: disable=invalid-name,missing-class-docstring
__doc__ = np.finfo.__doc__
_finfo_cache: Dict[np.dtype, np.finfo] = {}
_finfo_cache: Dict[type, np.finfo] = {}

@staticmethod
def _bfloat16_finfo():
Expand Down Expand Up @@ -157,6 +190,129 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

@staticmethod
def _float4_e2m1fn_finfo():
  """Assemble the finfo record for `float4_e2m1fn` by hand.

  Returns:
    An `np.finfo` instance allocated with `object.__new__` (so NumPy's own
    constructor logic is not run) whose fields are filled in below.
  """
  spacing_at_one = float.fromhex("0x0.8p0")  # 0.5
  largest = float.fromhex("0x1.8p2")  # 6.0

  info = object.__new__(np.finfo)
  info.dtype = _float4_e2m1fn_dtype
  info.bits = 4
  info.eps = spacing_at_one
  info.epsneg = spacing_at_one
  info.machep = -1
  info.negep = -1
  info.max = float4_e2m1fn(largest)
  info.min = float4_e2m1fn(-largest)
  info.nexp = 2
  info.nmant = 1
  info.iexp = info.nexp
  info.maxexp = 3
  info.minexp = 0
  info.precision = 0
  info.resolution = float4_e2m1fn(1.0)
  # pylint: disable=protected-access
  # `_machar` is assigned before the hasattr() probes below.
  info._machar = _Float4E2m1fnMachArLike()
  smallest_normal = info._machar.smallest_normal
  # Only assign these when the attribute is not already available on the
  # instance.
  if not hasattr(info, "tiny"):
    info.tiny = smallest_normal
  if not hasattr(info, "smallest_normal"):
    info.smallest_normal = smallest_normal
  info.smallest_subnormal = info._machar.smallest_subnormal

  # Pre-rendered string forms of the fields above.
  info._str_tiny = str(smallest_normal)
  info._str_smallest_normal = str(smallest_normal)
  info._str_smallest_subnormal = str(info.smallest_subnormal)
  info._str_max = str(info.max)
  info._str_epsneg = str(info.epsneg)
  info._str_eps = str(info.eps)
  info._str_resolution = str(info.resolution)
  # pylint: enable=protected-access
  return info

@staticmethod
def _float6_e2m3fn_finfo():
  """Assemble the finfo record for `float6_e2m3fn` by hand.

  Returns:
    An `np.finfo` instance allocated with `object.__new__` (so NumPy's own
    constructor logic is not run) whose fields are filled in below.
  """
  spacing_at_one = float.fromhex("0x0.2p0")  # 0.125
  largest = float.fromhex("0x1.Ep2")  # 7.5

  info = object.__new__(np.finfo)
  info.dtype = _float6_e2m3fn_dtype
  info.bits = 6
  info.eps = spacing_at_one
  info.epsneg = spacing_at_one
  info.machep = -3
  info.negep = -3
  info.max = float6_e2m3fn(largest)
  info.min = float6_e2m3fn(-largest)
  info.nexp = 2
  info.nmant = 3
  info.iexp = info.nexp
  info.maxexp = 3
  info.minexp = 0
  info.precision = 0
  info.resolution = float6_e2m3fn(1.0)
  # pylint: disable=protected-access
  # `_machar` is assigned before the hasattr() probes below.
  info._machar = _Float6E2m3fnMachArLike()
  smallest_normal = info._machar.smallest_normal
  # Only assign these when the attribute is not already available on the
  # instance.
  if not hasattr(info, "tiny"):
    info.tiny = smallest_normal
  if not hasattr(info, "smallest_normal"):
    info.smallest_normal = smallest_normal
  info.smallest_subnormal = info._machar.smallest_subnormal

  # Pre-rendered string forms of the fields above.
  info._str_tiny = str(smallest_normal)
  info._str_smallest_normal = str(smallest_normal)
  info._str_smallest_subnormal = str(info.smallest_subnormal)
  info._str_max = str(info.max)
  info._str_epsneg = str(info.epsneg)
  info._str_eps = str(info.eps)
  info._str_resolution = str(info.resolution)
  # pylint: enable=protected-access
  return info

@staticmethod
def _float6_e3m2fn_finfo():
  """Assemble the finfo record for `float6_e3m2fn` by hand.

  Returns:
    An `np.finfo` instance allocated with `object.__new__` (so NumPy's own
    constructor logic is not run) whose fields are filled in below.
  """
  spacing_at_one = float.fromhex("0x1p-2")  # 0.25
  largest = float.fromhex("0x1.Cp4")  # 28

  info = object.__new__(np.finfo)
  info.dtype = _float6_e3m2fn_dtype
  info.bits = 6
  info.eps = spacing_at_one
  # Unlike the two-exponent-bit formats, epsneg here is half of eps.
  info.epsneg = spacing_at_one / 2
  info.machep = -2
  info.negep = -3
  info.max = float6_e3m2fn(largest)
  info.min = float6_e3m2fn(-largest)
  info.nexp = 3
  info.nmant = 2
  info.iexp = info.nexp
  info.maxexp = 5
  info.minexp = -2
  info.precision = 0
  info.resolution = float6_e3m2fn(1.0)
  # pylint: disable=protected-access
  # `_machar` is assigned before the hasattr() probes below.
  info._machar = _Float6E3m2fnMachArLike()
  smallest_normal = info._machar.smallest_normal
  # Only assign these when the attribute is not already available on the
  # instance.
  if not hasattr(info, "tiny"):
    info.tiny = smallest_normal
  if not hasattr(info, "smallest_normal"):
    info.smallest_normal = smallest_normal
  info.smallest_subnormal = info._machar.smallest_subnormal

  # Pre-rendered string forms of the fields above.
  info._str_tiny = str(smallest_normal)
  info._str_smallest_normal = str(smallest_normal)
  info._str_smallest_subnormal = str(info.smallest_subnormal)
  info._str_max = str(info.max)
  info._str_epsneg = str(info.epsneg)
  info._str_eps = str(info.eps)
  info._str_resolution = str(info.resolution)
  # pylint: enable=protected-access
  return info

@staticmethod
def _float8_e3m4_finfo():
def float_to_str(f):
Expand Down Expand Up @@ -472,71 +628,35 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

# Registry: scalar type -> staticmethod that builds its finfo record.
# `__new__` looks a type up here and invokes the stored staticmethod object
# through its `__func__` attribute.
_finfo_type_map = {
bfloat16: _bfloat16_finfo,
float4_e2m1fn: _float4_e2m1fn_finfo,
float6_e2m3fn: _float6_e2m3fn_finfo,
float6_e3m2fn: _float6_e3m2fn_finfo,
float8_e3m4: _float8_e3m4_finfo,
float8_e4m3: _float8_e4m3_finfo,
float8_e4m3fn: _float8_e4m3fn_finfo,
float8_e4m3fnuz: _float8_e4m3fnuz_finfo,
float8_e4m3b11fnuz: _float8_e4m3b11fnuz_finfo,
float8_e5m2: _float8_e5m2_finfo,
float8_e5m2fnuz: _float8_e5m2fnuz_finfo,
}
# Reverse index: type `__name__` -> scalar type, used by `__new__` to resolve
# string arguments to a key of the registry above.
_finfo_name_map = {t.__name__: t for t in _finfo_type_map}

def __new__(cls, dtype):
  """Return the finfo record for `dtype`, building and caching it on first use.

  Args:
    dtype: A type name string, an `np.dtype` instance, or a scalar type.

  Returns:
    The cached finfo object when `dtype` resolves to a type registered in
    `_finfo_type_map`; otherwise whatever `np.finfo.__new__` returns for
    `dtype`.
  """
  # Normalise every accepted spelling to the scalar type, which is the key
  # used by both `_finfo_type_map` and `_finfo_cache`. Unknown strings map to
  # None and fall through to NumPy below.
  if isinstance(dtype, str):
    key = cls._finfo_name_map.get(dtype)
  elif isinstance(dtype, np.dtype):
    key = dtype.type
  else:
    key = dtype

  cached = cls._finfo_cache.get(key)
  if cached is not None:
    return cached

  init = cls._finfo_type_map.get(key)
  if init is not None:
    # Store under the normalised key, not the raw argument: the lookup above
    # uses `key`, so caching under `dtype` would never be hit for string or
    # `np.dtype` arguments and would rebuild the record on every call.
    # `init` is a raw staticmethod object taken from the class body, hence
    # the call through `__func__`.
    info = init.__func__()
    cls._finfo_cache[key] = info
    return info
  return super().__new__(cls, dtype)
Loading
Loading