 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
 from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.training.checkpointing._checkpoint_client import (
+    CheckpointClient,
+    TrainingProgress,
+)
 from torchtune.training.lr_schedulers import get_lr

 from tqdm import tqdm
@@ -138,9 +142,11 @@ def __init__(self, cfg: DictConfig) -> None:

         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
+        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+        self._checkpoint_client = CheckpointClient(cfg)

         # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
         if self._optimizer_in_bwd:
@@ -189,21 +195,6 @@ def __init__(self, cfg: DictConfig) -> None:
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
         self.global_step = 0

-    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
-        """
-        Extract the checkpoint state from file and validate. If resume_from_checkpoint
-        is True, this also includes the recipe state.
-        """
-        self._checkpointer = config.instantiate(
-            cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
-        )
-        checkpoint_dict = self._checkpointer.load_checkpoint()
-
-        if self._resume_from_checkpoint:
-            self._update_recipe_state(checkpoint_dict)
-        return checkpoint_dict
-
     def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
         """
         Updates the recipe state from checkpoint.
@@ -255,7 +246,8 @@ def setup(self, cfg: DictConfig) -> None:
         # log config with parameter override
         self._metric_logger.log_config(cfg)

-        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+        # Load the base model
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

         self._compile = cfg.get("compile", False)
         self._model = self._setup_model(
@@ -276,11 +268,36 @@ def setup(self, cfg: DictConfig) -> None:
             optimizer_in_bwd=self._optimizer_in_bwd,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
-                if self._resume_from_checkpoint
+                if training.OPT_KEY in checkpoint_dict
                 else None
             ),
         )

+        if self._resume_from_checkpoint:
+            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
+            # using the DistributedCheckpointer.
+            # Therefore the recipe needs to load the distributed checkpoint to restore the training
+            # progress.
+            if self._enable_async_checkpointing:
+                try:
+                    checkpoint_dict = (
+                        self._checkpoint_client.load_distributed_checkpoint(
+                            self._model,
+                            (
+                                self._optim_ckpt_wrapper
+                                if self._optimizer_in_bwd
+                                else self._optimizer
+                            ),
+                        )
+                    )
+                except Exception as e:
+                    log.warning(
+                        f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
+                    )
+
+            # Update the recipe state from the checkpoint state dict.
+            self._update_recipe_state(checkpoint_dict)
+
         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)

@@ -547,6 +564,7 @@ def _setup_model(
             log,
             f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
         )
+
         if self._is_rank_zero:
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)
@@ -661,95 +679,6 @@ def _setup_data(

         return sampler, dataloader

-    def save_checkpoint(
-        self,
-        epoch: int,
-    ) -> None:
-        """
-        Checkpoint the state of the recipe. The constructed checkpoint state dict
-        contains the following information:
-        - Model weights with key training.MODEL_KEY
-        - Relevant recipe state if training is not complete
-
-        Checkpointer will save the model weights and recipe state in
-        different checkpoint files. To correctly resume training from an intermediate checkpoint,
-        the model weights and recipe state must be provided.
-        """
-        # final dict passed onto the checkpointer
-        checkpoint_dict = {}
-
-        intermediate_checkpoint = epoch + 1 < self.total_epochs
-
-        utils.log_rank_zero(
-            log,
-            "Saving checkpoint. This may take some time. Retrieving full model state dict...",
-        )
-        start = time.perf_counter()
-
-        # To prevent GPU memory from spiking during checkpoint save,
-        # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.gather_cpu_state_dict(
-            self._model.state_dict(),
-            self._is_rank_zero,
-            device=self._device,
-        )
-
-        utils.log_rank_zero(
-            log,
-            f"Getting full model state dict took {time.perf_counter() - start:.2f} secs",
-        )
-
-        if intermediate_checkpoint:
-            start = time.perf_counter()
-            utils.log_rank_zero(log, "Getting optimizer state dict...")
-            if not self._optimizer_in_bwd:
-                opt_state_dict = training.get_full_optimizer_state_dict(
-                    self._optimizer,
-                    self._is_rank_zero,
-                    device=self._device,
-                )
-            else:
-                opt_state_dict = {}
-                for param, opt in self._optim_ckpt_wrapper.optim_map.items():
-                    opt_state_dict[param] = training.get_full_optimizer_state_dict(
-                        opt, self._is_rank_zero, device=self._device
-                    )
-            utils.log_rank_zero(
-                log,
-                f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs",
-            )
-        else:
-            opt_state_dict = None
-
-        # Now that we have the model and opt state dict, create the actual checkpoint dict
-        # to be sent to the checkpointer and ultimately written to file
-
-        if self._is_rank_zero:
-            start = time.perf_counter()
-            checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict})
-
-            # if training is in-progress, checkpoint the optimizer state and recipe state
-            # as well.
-            if intermediate_checkpoint:
-                checkpoint_dict.update(
-                    {
-                        training.OPT_KEY: opt_state_dict,
-                        training.SEED_KEY: self.seed,
-                        training.EPOCHS_KEY: self.epochs_run,
-                        training.TOTAL_EPOCHS_KEY: self.total_epochs,
-                        training.MAX_STEPS_KEY: self.max_steps_per_epoch,
-                    }
-                )
-
-            self._checkpointer.save_checkpoint(
-                checkpoint_dict,
-                epoch=epoch,
-                intermediate_checkpoint=intermediate_checkpoint,
-            )
-            log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")
-
-        torch.distributed.barrier()
-
     def train(self) -> None:
         """
         The core training loop.
@@ -922,7 +851,21 @@ def train(self) -> None:
                 self._profiler.step()

             self.epochs_run += 1
-            self.save_checkpoint(epoch=curr_epoch)
+            self._checkpoint_client.save_checkpoint(
+                model=self._model,
+                optimizer=(
+                    self._optimizer
+                    if not self._optimizer_in_bwd
+                    else self._optim_ckpt_wrapper
+                ),
+                training_progress=TrainingProgress(
+                    seed=self.seed,
+                    epochs_run=self.epochs_run,
+                    total_epochs=self.total_epochs,
+                    max_steps_per_epoch=self.max_steps_per_epoch,
+                ),
+                epoch=curr_epoch,
+            )

         self._profiler.stop()

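Note (not part of the diff): after this change, all checkpoint I/O in the recipe flows through CheckpointClient, and recipe state travels as a TrainingProgress object instead of loose dict keys assembled by hand. The self-contained sketch below uses hypothetical stand-in classes (not torchtune's implementations) solely to make the call shape introduced in this diff readable in isolation:

from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class TrainingProgressSketch:
    # Hypothetical stand-in mirroring the fields passed in this diff.
    seed: int
    epochs_run: int
    total_epochs: int
    max_steps_per_epoch: Optional[int]


class CheckpointClientSketch:
    # Hypothetical stand-in; the real CheckpointClient wraps torchtune's
    # checkpointers and handles both synchronous and async saves.
    def load_base_checkpoint(self) -> Dict[str, Any]:
        # Base model weights; recipe state is only present when resuming.
        return {"model": {}}

    def save_checkpoint(self, model, optimizer, training_progress, epoch) -> None:
        print(f"epoch={epoch}: epochs_run={training_progress.epochs_run}")


# Same call shape as the recipe's end-of-epoch save in this diff.
client = CheckpointClientSketch()
checkpoint_dict = client.load_base_checkpoint()
client.save_checkpoint(
    model=None,       # the recipe passes self._model
    optimizer=None,   # or the in-backward optimizer wrapper
    training_progress=TrainingProgressSketch(
        seed=0, epochs_run=1, total_epochs=3, max_steps_per_epoch=None
    ),
    epoch=0,
)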