diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 9fbbf2a7..f132901f 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -12,7 +12,7 @@ jobs: os: [ubuntu-latest, macos-13, windows-latest] # Lowest and highest, no version specified so that # new releases get automatically tested against - version: [{torch: torch==1.10, python: "3.8"}, {torch: torch, python: "3.12"}] + version: [{torch: torch==1.10, python: "3.8", arch: "x64"}, {torch: torch, python: "3.12", arch: "x64"}] # TODO this would include macos ARM target. # however jax has an illegal instruction issue # that exists only in CI (probably difference in instruction support). @@ -21,6 +21,12 @@ jobs: # version: # torch: torch # python: "3.11" + include: + - os: ubuntu-latest + version: + torch: torch + python: "3.13" + arch: "x64-freethreaded" defaults: run: working-directory: ./bindings/python @@ -46,7 +52,7 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.version.python }} - architecture: "x64" + architecture: ${{ matrix.version.arch }} - name: Lint with RustFmt run: cargo fmt -- --check @@ -60,12 +66,29 @@ jobs: - name: Install run: | pip install -U pip - pip install .[numpy,tensorflow] + pip install .[numpy] + + - name: Install (torch) + if: matrix.version.arch != 'x64-freethreaded' + run: | pip install ${{ matrix.version.torch }} + shell: bash - - name: Install (jax, flax) - if: matrix.os != 'windows-latest' + - name: Install (torch freethreaded) + if: matrix.version.arch == 'x64-freethreaded' + run: | + pip install ${{ matrix.version.torch }} --index-url https://download.pytorch.org/whl/cu126 + shell: bash + + - name: Install (tensorflow) + if: matrix.version.arch != 'x64-freethreaded' run: | + pip install .[tensorflow] + shell: bash + + - name: Install (jax, flax) + if: matrix.os != 'windows-latest' && matrix.version.arch != 'x64-freethreaded' + run: pip install .[jax] shell: bash diff --git a/bindings/python/Cargo.toml 
b/bindings/python/Cargo.toml index de7692bc..301cc9e7 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -10,7 +10,7 @@ name = "safetensors_rust" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.23", features = ["abi3", "abi3-py38"] } +pyo3 = { version = "0.24", features = ["abi3", "abi3-py38"] } memmap2 = "0.9" serde_json = "1.0" diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index a65b1aab..e16d7fdd 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -3,7 +3,7 @@ use memmap2::{Mmap, MmapOptions}; use pyo3::exceptions::{PyException, PyFileNotFoundError}; use pyo3::prelude::*; -use pyo3::sync::GILOnceCell; +use pyo3::sync::OnceLockExt; use pyo3::types::IntoPyDict; use pyo3::types::{PyBool, PyByteArray, PyBytes, PyDict, PyList, PySlice}; use pyo3::Bound as PyBound; @@ -18,12 +18,13 @@ use std::iter::FromIterator; use std::ops::Bound; use std::path::PathBuf; use std::sync::Arc; +use std::sync::OnceLock; -static TORCH_MODULE: GILOnceCell> = GILOnceCell::new(); -static NUMPY_MODULE: GILOnceCell> = GILOnceCell::new(); -static TENSORFLOW_MODULE: GILOnceCell> = GILOnceCell::new(); -static FLAX_MODULE: GILOnceCell> = GILOnceCell::new(); -static MLX_MODULE: GILOnceCell> = GILOnceCell::new(); +static TORCH_MODULE: OnceLock> = OnceLock::new(); +static NUMPY_MODULE: OnceLock> = OnceLock::new(); +static TENSORFLOW_MODULE: OnceLock> = OnceLock::new(); +static FLAX_MODULE: OnceLock> = OnceLock::new(); +static MLX_MODULE: OnceLock> = OnceLock::new(); struct PyView<'a> { shape: Vec, @@ -342,7 +343,7 @@ enum Storage { /// This allows us to not manage it /// so Pytorch can handle the whole lifecycle. /// https://pytorch.org/docs/stable/storage.html#torch.TypedStorage.from_file. 
- TorchStorage(GILOnceCell), + TorchStorage(OnceLock), } #[derive(Debug, PartialEq, Eq, PartialOrd)] @@ -422,11 +423,11 @@ impl Open { match framework { Framework::Pytorch => { let module = PyModule::import(py, intern!(py, "torch"))?; - TORCH_MODULE.get_or_init(py, || module.into()) + TORCH_MODULE.get_or_init_py_attached(py, || module.into()) } _ => { let module = PyModule::import(py, intern!(py, "numpy"))?; - NUMPY_MODULE.get_or_init(py, || module.into()) + NUMPY_MODULE.get_or_init_py_attached(py, || module.into()) } }; @@ -444,7 +445,13 @@ impl Open { // Same for torch.asarray which is necessary for zero-copy tensor if version >= Version::new(1, 11, 0) { // storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped() - let py_filename: PyObject = filename.into_pyobject(py)?.into(); + let py_filename: PyObject = filename + .to_str() + .ok_or_else(|| { + SafetensorError::new_err(format!("Path {filename:?} is not a string")) + })? + .into_pyobject(py)? + .into(); let size: PyObject = buffer.len().into_pyobject(py)?.into(); let shared: PyObject = PyBool::new(py, false).to_owned().into(); let (size_name, storage_name) = if version >= Version::new(2, 0, 0) { @@ -466,8 +473,8 @@ impl Open { Err(_) => storage.getattr(intern!(py, "_untyped"))?, }; let storage = untyped.call0()?.into_pyobject(py)?.into(); - let gil_storage = GILOnceCell::new(); - gil_storage.get_or_init(py, || storage); + let gil_storage = OnceLock::new(); + gil_storage.get_or_init_py_attached(py, || storage); Ok(Storage::TorchStorage(gil_storage)) } else { @@ -579,7 +586,7 @@ impl Open { let stop = (info.data_offsets.1 + self.offset) as isize; let slice = PySlice::new(py, start, stop, 1); let storage: &PyObject = storage - .get(py) + .get() .ok_or_else(|| SafetensorError::new_err("Could not find storage"))?; let storage: &PyBound = storage.bind(py); let storage_slice = storage @@ -954,7 +961,7 @@ impl PySafeSlice { let stop = (self.info.data_offsets.1 + self.offset) as isize; let 
slice = PySlice::new(py, start, stop, 1); let storage: &PyObject = storage - .get(py) + .get() .ok_or_else(|| SafetensorError::new_err("Could not find storage"))?; let storage: &PyBound<'_, PyAny> = storage.bind(py); @@ -1025,10 +1032,10 @@ impl PySafeSlice { fn get_module<'a>( py: Python<'a>, - cell: &'static GILOnceCell>, + cell: &'static OnceLock>, ) -> PyResult<&'a PyBound<'a, PyModule>> { let module: &PyBound<'a, PyModule> = cell - .get(py) + .get() .ok_or_else(|| SafetensorError::new_err("Could not find module"))? .bind(py); Ok(module) @@ -1045,7 +1052,7 @@ fn create_tensor<'a>( let (module, is_numpy): (&PyBound<'_, PyModule>, bool) = match framework { Framework::Pytorch => ( TORCH_MODULE - .get(py) + .get() .ok_or_else(|| { SafetensorError::new_err(format!("Could not find module {framework:?}",)) })? @@ -1054,7 +1061,7 @@ fn create_tensor<'a>( ), _ => ( NUMPY_MODULE - .get(py) + .get() .ok_or_else(|| { SafetensorError::new_err(format!("Could not find module {framework:?}",)) })? @@ -1097,7 +1104,7 @@ fn create_tensor<'a>( Framework::Flax => { let module = Python::with_gil(|py| -> PyResult<&Py> { let module = PyModule::import(py, intern!(py, "jax"))?; - Ok(FLAX_MODULE.get_or_init(py, || module.into())) + Ok(FLAX_MODULE.get_or_init_py_attached(py, || module.into())) })? .bind(py); module @@ -1108,7 +1115,7 @@ fn create_tensor<'a>( Framework::Tensorflow => { let module = Python::with_gil(|py| -> PyResult<&Py> { let module = PyModule::import(py, intern!(py, "tensorflow"))?; - Ok(TENSORFLOW_MODULE.get_or_init(py, || module.into())) + Ok(TENSORFLOW_MODULE.get_or_init_py_attached(py, || module.into())) })? .bind(py); module @@ -1118,7 +1125,7 @@ fn create_tensor<'a>( Framework::Mlx => { let module = Python::with_gil(|py| -> PyResult<&Py> { let module = PyModule::import(py, intern!(py, "mlx"))?; - Ok(MLX_MODULE.get_or_init(py, || module.into())) + Ok(MLX_MODULE.get_or_init_py_attached(py, || module.into())) })? 
.bind(py); module @@ -1192,7 +1199,7 @@ pyo3::create_exception!( ); /// A Python module implemented in Rust. -#[pymodule] +#[pymodule(gil_used = false)] fn _safetensors_rust(m: &PyBound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(serialize, m)?)?; m.add_function(wrap_pyfunction!(serialize_file, m)?)?; diff --git a/bindings/python/tests/test_simple.py b/bindings/python/tests/test_simple.py index 8e840138..d198391f 100644 --- a/bindings/python/tests/test_simple.py +++ b/bindings/python/tests/test_simple.py @@ -1,5 +1,6 @@ import os import tempfile +import threading import unittest from pathlib import Path @@ -117,9 +118,10 @@ def test_accept_path(self): "a": torch.zeros((2, 2)), "b": torch.zeros((2, 3), dtype=torch.uint8), } - save_file_pt(tensors, Path("./out.safetensors")) - load_file_pt(Path("./out.safetensors")) - os.remove(Path("./out.safetensors")) + filename = f"./out_{threading.get_ident()}.safetensors" + save_file_pt(tensors, Path(filename)) + load_file_pt(Path(filename)) + os.remove(Path(filename)) def test_pt_sf_save_model_overlapping_storage(self): m = torch.randn(10) @@ -157,14 +159,14 @@ def test_get_correctly_dropped(self): "a": torch.zeros((2, 2)), "b": torch.zeros((2, 3), dtype=torch.uint8), } - save_file_pt(tensors, "./out.safetensors") - with safe_open("./out.safetensors", framework="pt") as f: + save_file_pt(tensors, "./out_windows.safetensors") + with safe_open("./out_windows.safetensors", framework="pt") as f: pass with self.assertRaises(SafetensorError): print(f.keys()) - with open("./out.safetensors", "w") as g: + with open("./out_windows.safetensors", "w") as g: g.write("something") @@ -188,11 +190,11 @@ def assertTensorEqual(self, tensors1, tensors2, equality_fn): def test_numpy_example(self): tensors = {"a": np.zeros((2, 2)), "b": np.zeros((2, 3), dtype=np.uint8)} - save_file(tensors, "./out.safetensors") + save_file(tensors, "./out_np.safetensors") out = save(tensors) # Now loading - loaded = 
load_file("./out.safetensors") + loaded = load_file("./out_np.safetensors") self.assertTensorEqual(tensors, loaded, np.allclose) loaded = load(out) @@ -220,10 +222,11 @@ def test_torch_example(self): # test to be correct. tensors2 = tensors.copy() - save_file_pt(tensors, "./out.safetensors") + filename = f"./out_pt_{threading.get_ident()}.safetensors" + save_file_pt(tensors, filename) # Now loading - loaded = load_file_pt("./out.safetensors") + loaded = load_file_pt(filename) self.assertTensorEqual(tensors2, loaded, torch.allclose) def test_exception(self): @@ -237,10 +240,11 @@ def test_torch_slice(self): tensors = { "a": A, } - save_file_pt(tensors, "./slice.safetensors") + ident = threading.get_ident() + save_file_pt(tensors, f"./slice_{ident}.safetensors") # Now loading - with safe_open("./slice.safetensors", framework="pt", device="cpu") as f: + with safe_open(f"./slice_{ident}.safetensors", framework="pt", device="cpu") as f: slice_ = f.get_slice("a") tensor = slice_[:] self.assertEqual(list(tensor.shape), [10, 5]) @@ -283,10 +287,11 @@ def test_numpy_slice(self): tensors = { "a": A, } - save_file(tensors, "./slice.safetensors") + filename = f"./slice_{threading.get_ident()}.safetensors" + save_file(tensors, filename) # Now loading - with safe_open("./slice.safetensors", framework="np", device="cpu") as f: + with safe_open(filename, framework="np", device="cpu") as f: slice_ = f.get_slice("a") tensor = slice_[:] self.assertEqual(list(tensor.shape), [10, 5])