1212from typing import Any , Dict , List , Optional , Protocol
1313
1414import torch
15+ from safetensors .torch import save_file
1516from torchtune import utils
1617
1718from torchtune .models import convert_weights
@@ -305,6 +306,7 @@ class FullModelHFCheckpointer(_CheckpointerInterface):
305306 recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None
306307 resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
307308 resume training from a previous run. Default is False
309+ safe_serialization (bool): If True, the checkpointer will save the checkpoint file using `safetensors`
308310
309311 Raises:
310312 ValueError: If ``resume_from_checkpoint`` is True but ``recipe_checkpoint`` is None
@@ -319,6 +321,7 @@ def __init__(
319321 adapter_checkpoint : Optional [str ] = None ,
320322 recipe_checkpoint : Optional [str ] = None ,
321323 resume_from_checkpoint : bool = False ,
324+ safe_serialization : bool = False ,
322325 ) -> None :
323326 self ._checkpoint_dir = Path (checkpoint_dir )
324327 self ._checkpoint_paths = self ._validate_hf_checkpoint_files (checkpoint_files )
@@ -331,6 +334,7 @@ def __init__(
331334 self ._model_type = ModelType [model_type ]
332335 self ._output_dir = Path (output_dir )
333336 self ._resume_from_checkpoint = resume_from_checkpoint
337+ self ._safe_serialization = safe_serialization
334338
335339 # weight_map contains the state_dict key -> checkpoint file mapping so we can correctly
336340 # parition the state dict into output checkpoint files. This is updated during checkpoint
@@ -508,10 +512,17 @@ def save_checkpoint(
508512
509513 # write the partitioned state dicts to the right checkpoint file
510514 for cpt_idx , model_state_dict in split_state_dicts .items ():
511- output_path = Path .joinpath (
512- self ._output_dir , f"hf_model_{ cpt_idx } _{ epoch } "
513- ).with_suffix (".pt" )
514- torch .save (model_state_dict , output_path )
515+ if not self ._safe_serialization :
516+ output_path = Path .joinpath (
517+ self ._output_dir , f"hf_model_{ cpt_idx } _{ epoch } "
518+ ).with_suffix (".pt" )
519+ torch .save (model_state_dict , output_path )
520+ else :
521+ output_path = Path .joinpath (
522+ self ._output_dir ,
523+ f"model-0{ cpt_idx } -of-0{ list (split_state_dicts .keys ())[- 1 ]} _{ epoch } " ,
524+ ).with_suffix (".safetensors" )
525+ save_file (model_state_dict , output_path )
515526 logger .info (
516527 "Model checkpoint of size "
517528 f"{ os .path .getsize (output_path ) / 1000 ** 3 :.2f} GB "
0 commit comments