Skip to content

Commit 7f8a01a

Browse files
jeffrey-fong and maximegmd
authored and committed
Add safe-serialization to FullModelHFCheckpointer (meta-pytorch#1096)
1 parent 65c280f commit 7f8a01a

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

torchtune/utils/_checkpointing/_checkpointer.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import Any, Dict, List, Optional, Protocol
1313

1414
import torch
15+
from safetensors.torch import save_file
1516
from torchtune import utils
1617

1718
from torchtune.models import convert_weights
@@ -305,6 +306,7 @@ class FullModelHFCheckpointer(_CheckpointerInterface):
305306
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None
306307
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
307308
resume training from a previous run. Default is False
309+
safe_serialization (bool): If True, the checkpointer will save the checkpoint file using `safetensors`
308310
309311
Raises:
310312
ValueError: If ``resume_from_checkpoint`` is True but ``recipe_checkpoint`` is None
@@ -319,6 +321,7 @@ def __init__(
319321
adapter_checkpoint: Optional[str] = None,
320322
recipe_checkpoint: Optional[str] = None,
321323
resume_from_checkpoint: bool = False,
324+
safe_serialization: bool = False,
322325
) -> None:
323326
self._checkpoint_dir = Path(checkpoint_dir)
324327
self._checkpoint_paths = self._validate_hf_checkpoint_files(checkpoint_files)
@@ -331,6 +334,7 @@ def __init__(
331334
self._model_type = ModelType[model_type]
332335
self._output_dir = Path(output_dir)
333336
self._resume_from_checkpoint = resume_from_checkpoint
337+
self._safe_serialization = safe_serialization
334338

335339
# weight_map contains the state_dict key -> checkpoint file mapping so we can correctly
336340
# parition the state dict into output checkpoint files. This is updated during checkpoint
@@ -508,10 +512,17 @@ def save_checkpoint(
508512

509513
# write the partitioned state dicts to the right checkpoint file
510514
for cpt_idx, model_state_dict in split_state_dicts.items():
511-
output_path = Path.joinpath(
512-
self._output_dir, f"hf_model_{cpt_idx}_{epoch}"
513-
).with_suffix(".pt")
514-
torch.save(model_state_dict, output_path)
515+
if not self._safe_serialization:
516+
output_path = Path.joinpath(
517+
self._output_dir, f"hf_model_{cpt_idx}_{epoch}"
518+
).with_suffix(".pt")
519+
torch.save(model_state_dict, output_path)
520+
else:
521+
output_path = Path.joinpath(
522+
self._output_dir,
523+
f"model-0{cpt_idx}-of-0{list(split_state_dicts.keys())[-1]}_{epoch}",
524+
).with_suffix(".safetensors")
525+
save_file(model_state_dict, output_path)
515526
logger.info(
516527
"Model checkpoint of size "
517528
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "

0 commit comments

Comments (0)