Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 90 additions & 14 deletions bindings/python/py_src/safetensors/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ def deserialize(bytes):
Opens a safetensors lazily and returns tensors as asked

Args:
data (:obj:`bytes`):
data (`bytes`):
The byte content of a file

Returns:
(:obj:`List[str, Dict[str, Dict[str, any]]]`):
(`List[str, Dict[str, Dict[str, any]]]`):
The deserialized content is like:
[("tensor_name", {"shape": [2, 3], "dtype": "F32", "data": b"\0\0.." }), (...)]
"""
Expand All @@ -21,14 +21,14 @@ def serialize(tensor_dict, metadata=None):
Serializes raw data.

Args:
tensor_dict (:obj:`Dict[str, Dict[Any]]`):
tensor_dict (`Dict[str, Dict[Any]]`):
The tensor dict is like:
{"tensor_name": {"dtype": "F32", "shape": [2, 3], "data": b"\0\0"}}
metadata (:obj:`Dict[str, str]`, *optional*):
metadata (`Dict[str, str]`, *optional*):
The optional purely text annotations

Returns:
(:obj:`bytes`):
(`bytes`):
The serialized content.
"""
pass
Expand All @@ -39,16 +39,16 @@ def serialize_file(tensor_dict, filename, metadata=None):
Serializes raw data.

Args:
tensor_dict (:obj:`Dict[str, Dict[Any]]`):
tensor_dict (`Dict[str, Dict[Any]]`):
The tensor dict is like:
{"tensor_name": {"dtype": "F32", "shape": [2, 3], "data": b"\0\0"}}
filename (:obj:`str`):
filename (`str`, or `os.PathLike`):
The name of the file to write into.
metadata (:obj:`Dict[str, str]`, *optional*):
metadata (`Dict[str, str]`, *optional*):
The optional purely text annotations

Returns:
(:obj:`bytes`):
(`bytes`):
The serialized content.
"""
pass
Expand All @@ -58,16 +58,92 @@ class safe_open:
Opens a safetensors lazily and returns tensors as asked

Args:
filename (:obj:`str`):
filename (`str`, or `os.PathLike`):
The filename to open

framework (:obj:`str`):
The framework you want your tensors in. Supported values:
framework (`str`):
The framework you want your tensors in. Supported values:
`pt`, `tf`, `flax`, `numpy`.

device (:obj:`str`, defaults to :obj:`"cpu"`):
device (`str`, defaults to `"cpu"`):
The device on which you want the tensors.
"""

def __init__(self, filename, framework, device="cpu"):
def __init__(filename, framework, device=...):
pass
def __enter__(self):
"""
Start the context manager
"""
pass
def __exit__(self, _exc_type, _exc_value, _traceback):
"""
Exits the context manager
"""
pass
def get_slice(self, name):
"""
Returns a full slice view object

Args:
name (`str`):
The name of the tensor you want

Returns:
(`PySafeSlice`):
A dummy object you can slice into to get a real tensor
Example:
```python
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt", device=0) as f:
tensor_part = f.get_slice("embedding")[:, ::8]

```
"""
pass
def get_tensor(self, name):
"""
Returns a full tensor

Args:
name (`str`):
The name of the tensor you want

Returns:
(`Tensor`):
The tensor in the framework you opened the file for.

Example:
```python
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt", device=0) as f:
tensor = f.get_tensor("embedding")

```
"""
pass
def keys(self):
"""
Returns the names of the tensors in the file.

Returns:
(`List[str]`):
The names of the tensors contained in that file
"""
pass
def metadata(self):
"""
Returns the special non-tensor information in the header

Returns:
(`Dict[str, str]`):
The freeform metadata.
"""
pass

class SafetensorError(Exception):
"""
Custom Python Exception for Safetensor errors.
"""
2 changes: 2 additions & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,10 +755,12 @@ impl safe_open {
self.inner()?.get_slice(name)
}

/// Start the context manager
pub fn __enter__(slf: Py<Self>) -> Py<Self> {
slf
}

/// Exits the context manager
pub fn __exit__(&mut self, _exc_type: PyObject, _exc_value: PyObject, _traceback: PyObject) {
self.inner = None;
}
Expand Down
37 changes: 30 additions & 7 deletions bindings/python/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,14 @@ def member_sort(member):
def fn_predicate(obj):
value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj)
if value:
return obj.__doc__ and obj.__text_signature__ and not obj.__name__.startswith("_")
return (
obj.__doc__
and obj.__text_signature__
and (
not obj.__name__.startswith("_")
or obj.__name__ in {"__enter__", "__exit__"}
)
)
if inspect.isgetsetdescriptor(obj):
return obj.__doc__ and not obj.__name__.startswith("_")
return False
Expand Down Expand Up @@ -74,7 +81,9 @@ def pyi_file(obj, indent=""):

body = ""
if obj.__doc__:
body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'
body += (
f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'
)

fns = inspect.getmembers(obj, fn_predicate)

Expand All @@ -84,7 +93,7 @@ def pyi_file(obj, indent=""):
body += f"{indent+INDENT}pass\n"
body += "\n"

for (name, fn) in fns:
for name, fn in fns:
body += pyi_file(fn, indent=indent)

if not body:
Expand Down Expand Up @@ -130,13 +139,18 @@ def do_black(content, is_pyi):
experimental_string_processing=False,
)
try:
content = content.replace("$self", "self")
return black.format_file_contents(content, fast=True, mode=mode)
except black.NothingChanged:
return content


def write(module, directory, origin, check=False):
submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)]
submodules = [
(name, member)
for name, member in inspect.getmembers(module)
if inspect.ismodule(member)
]

filename = os.path.join(directory, "__init__.pyi")
pyi_content = pyi_file(module)
Expand All @@ -145,7 +159,9 @@ def write(module, directory, origin, check=False):
if check:
with open(filename, "r") as f:
data = f.read()
assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
assert (
data == pyi_content
), f"The content of {filename} seems outdated, please run `python stub.py`"
else:
with open(filename, "w") as f:
f.write(pyi_content)
Expand All @@ -168,7 +184,9 @@ def write(module, directory, origin, check=False):
if check:
with open(filename, "r") as f:
data = f.read()
assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
assert (
data == py_content
), f"The content of {filename} seems outdated, please run `python stub.py`"
else:
with open(filename, "w") as f:
f.write(py_content)
Expand All @@ -184,4 +202,9 @@ def write(module, directory, origin, check=False):
args = parser.parse_args()
import safetensors

write(safetensors.safetensors_rust, "py_src/safetensors/", "safetensors", check=args.check)
write(
safetensors._safetensors_rust,
"py_src/safetensors/",
"safetensors",
check=args.check,
)
Loading