Skip to content

Commit 9455a2b

Browse files
committed
Fixing the version check for uint support in torch.
1 parent 5eecd29 commit 9455a2b

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,14 @@ repos:
3131
"--",
3232
"-Dwarnings",
3333
]
34-
- repo: https://github.com/psf/black
35-
rev: 22.3.0
34+
- repo: https://github.com/astral-sh/ruff-pre-commit
35+
# Ruff version.
36+
rev: v0.12.8
3637
hooks:
37-
- id: black
38-
name: "Python (black)"
39-
args: ["--line-length", "119", "--target-version", "py35"]
40-
types: ["python"]
38+
# Run the linter.
39+
- id: ruff-check
40+
# Run the formatter.
41+
- id: ruff-format
4142
- repo: https://github.com/astral-sh/ruff-pre-commit
4243
# Ruff version.
4344
rev: v0.11.11

bindings/python/py_src/safetensors/torch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def load(data: bytes) -> Dict[str, torch.Tensor]:
433433
_float8_e8m0: 1,
434434
_float4_e2m1_x2: 1,
435435
}
436-
if Version(torch.__version__) > Version("2.0.0"):
436+
if Version(torch.__version__) >= Version("2.3.0"):
437437
_SIZE.update(
438438
{
439439
torch.uint64: 8,
@@ -456,7 +456,7 @@ def load(data: bytes) -> Dict[str, torch.Tensor]:
456456
"F8_E4M3": _float8_e4m3fn,
457457
"F8_E5M2": _float8_e5m2,
458458
}
459-
if Version(torch.__version__) > Version("2.0.0"):
459+
if Version(torch.__version__) >= Version("2.3.0"):
460460
_TYPES.update(
461461
{
462462
"U64": torch.uint64,

0 commit comments

Comments
 (0)