Skip to content

Commit 634500b

Browse files
Fix FSDP2 and distributed checkpointing imports for older PyTorch versions (huggingface#46141)
* Fix PyTorch requirement for FSDP2 to >=2.6 * Fix PyTorch requirement for distributed checkpoint saving to >=2.7 --------- Co-authored-by: Ferdinand Mom <47445085+3outeille@users.noreply.github.com>
1 parent 5a4b70f commit 634500b

2 files changed

Lines changed: 14 additions & 6 deletions

File tree

src/transformers/distributed/fsdp.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
if is_torch_available():
2929
import torch
3030

31-
if is_torch_available() and is_torch_greater_or_equal("2.5"):
31+
if is_torch_available() and is_torch_greater_or_equal("2.6"):
3232
import torch.distributed as dist
3333
from torch.distributed._composable.fsdp import fully_shard
3434
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
@@ -91,8 +91,8 @@ def initialize_fsdp(
9191
if fsdp_plan is None:
9292
return device_map, device_mesh, None
9393

94-
if not is_torch_greater_or_equal("2.5"):
95-
raise OSError("FSDP2 is only supported for `torch>=2.5`.")
94+
if not is_torch_greater_or_equal("2.6"):
95+
raise OSError("FSDP2 is only supported for `torch>=2.6`.")
9696

9797
if device_mesh is None:
9898
# Detect the accelerator on the machine
@@ -338,8 +338,8 @@ def apply_fully_shard_data_parallel(
338338
if not is_torch_available():
339339
raise ImportError("PyTorch is required for FSDP support")
340340

341-
if not is_torch_greater_or_equal("2.5"):
342-
raise OSError("FSDP2 requires torch>=2.5")
341+
if not is_torch_greater_or_equal("2.6"):
342+
raise OSError("FSDP2 requires torch>=2.6")
343343

344344
if fsdp_plan is None:
345345
fsdp_plan = {}

src/transformers/distributed/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,16 @@
3939
if is_torch_available():
4040
import torch
4141
import torch.distributed.checkpoint as dcp
42-
from torch.distributed.checkpoint.hf_storage import HuggingFaceStorageWriter
4342
from torch.distributed.checkpoint.state_dict import (
4443
get_model_state_dict,
4544
get_optimizer_state_dict,
4645
set_optimizer_state_dict,
4746
)
4847
from torch.distributed.tensor import DTensor
4948

49+
if is_torch_greater_or_equal("2.7"):
50+
from torch.distributed.checkpoint.hf_storage import HuggingFaceStorageWriter
51+
5052

5153
def _ensure_torch_distributed(device_type: str):
5254
"""Initialize torch.distributed if not already initialized."""
@@ -103,6 +105,9 @@ def init_device_mesh(distributed_config: DistributedConfig) -> torch.distributed
103105
if not is_torch_greater_or_equal("2.5"):
104106
raise OSError("Distributed training with DistributedConfig requires `torch>=2.5`.")
105107

108+
if distributed_config.fsdp_size > 1 and not is_torch_greater_or_equal("2.6"):
109+
raise OSError("FSDP2 requires `torch>=2.6`.")
110+
106111
device_type = torch._C._get_accelerator().type
107112
_ensure_torch_distributed(device_type)
108113

@@ -205,6 +210,9 @@ def save_model_checkpoint_distributed(model, checkpoint_dir: str) -> None:
205210
gate||up MoE weights) are replicated to a full tensor on every rank
206211
before the save, otherwise DCP cannot encode that placement.
207212
"""
213+
if not is_torch_greater_or_equal("2.7"):
214+
raise OSError("Distributed checkpoint saving requires `torch>=2.7`.")
215+
208216
state_dict = get_model_state_dict(model)
209217
for key, value in list(state_dict.items()):
210218
if (

0 commit comments

Comments
 (0)