
Commit c2c6f4a

saumishr (Saurabh Mishra) and co-author authored
Faster intermediate checkpoints with DCP async save in TorchTune (#2006)
Co-authored-by: Saurabh Mishra <[email protected]>
1 parent 096881d commit c2c6f4a

19 files changed: +1178, -179 lines
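In brief, this commit routes checkpointing in the distributed full-finetune recipe through a new CheckpointClient, adds an `enable_async_checkpointing` flag so intermediate checkpoints can be saved asynchronously via DCP, and renames the checkpointer argument `resume_from_checkpoint` to `should_load_recipe_state` in the other recipes. A minimal sketch of how those flags reach a recipe is below; only the key names come from this diff, and the config layout shown is an assumption, not a real torchtune config.

# Sketch only: the key names ("enable_async_checkpointing", "resume_from_checkpoint")
# appear in this diff; the surrounding config layout is illustrative.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "resume_from_checkpoint": True,      # also restore recipe state on startup
        "enable_async_checkpointing": True,  # intermediate saves go through DCP async save
        "gradient_accumulation_steps": 1,
        "optimizer_in_bwd": False,
        "clip_grad_norm": None,
    }
)

# Mirrors how the updated recipe __init__ reads these keys:
resume_from_checkpoint = cfg.resume_from_checkpoint
enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
print(resume_from_checkpoint, enable_async_checkpointing)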

recipes/full_finetune_distributed.py

Lines changed: 50 additions & 107 deletions
@@ -26,6 +26,10 @@
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
 from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.training.checkpointing._checkpoint_client import (
+    CheckpointClient,
+    TrainingProgress,
+)
 from torchtune.training.lr_schedulers import get_lr

 from tqdm import tqdm
@@ -138,9 +142,11 @@ def __init__(self, cfg: DictConfig) -> None:

         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
+        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+        self._checkpoint_client = CheckpointClient(cfg)

         # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
         if self._optimizer_in_bwd:
@@ -189,21 +195,6 @@ def __init__(self, cfg: DictConfig) -> None:
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
         self.global_step = 0

-    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
-        """
-        Extract the checkpoint state from file and validate. If resume_from_checkpoint
-        is True, this also includes the recipe state.
-        """
-        self._checkpointer = config.instantiate(
-            cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
-        )
-        checkpoint_dict = self._checkpointer.load_checkpoint()
-
-        if self._resume_from_checkpoint:
-            self._update_recipe_state(checkpoint_dict)
-        return checkpoint_dict
-
     def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
         """
         Updates the recipe state from checkpoint.
@@ -255,7 +246,8 @@ def setup(self, cfg: DictConfig) -> None:
         # log config with parameter override
         self._metric_logger.log_config(cfg)

-        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+        # Load the base model
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

         self._compile = cfg.get("compile", False)
         self._model = self._setup_model(
@@ -276,11 +268,36 @@ def setup(self, cfg: DictConfig) -> None:
             optimizer_in_bwd=self._optimizer_in_bwd,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
-                if self._resume_from_checkpoint
+                if training.OPT_KEY in checkpoint_dict
                 else None
             ),
         )

+        if self._resume_from_checkpoint:
+            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
+            # using the DistributedCheckpointer.
+            # Therefore the recipe needs to load the distributed checkpoint to restore the training
+            # progress.
+            if self._enable_async_checkpointing:
+                try:
+                    checkpoint_dict = (
+                        self._checkpoint_client.load_distributed_checkpoint(
+                            self._model,
+                            (
+                                self._optim_ckpt_wrapper
+                                if self._optimizer_in_bwd
+                                else self._optimizer
+                            ),
+                        )
+                    )
+                except Exception as e:
+                    log.warning(
+                        f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
+                    )
+
+            # Update the recipe state from the checkpoint state dict.
+            self._update_recipe_state(checkpoint_dict)
+
         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)

@@ -547,6 +564,7 @@ def _setup_model(
             log,
             f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
         )
+
         if self._is_rank_zero:
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)
@@ -661,95 +679,6 @@ def _setup_data(

         return sampler, dataloader

-    def save_checkpoint(
-        self,
-        epoch: int,
-    ) -> None:
-        """
-        Checkpoint the state of the recipe. The constructed checkpoint state dict
-        contains the following information:
-        - Model weights with key training.MODEL_KEY
-        - Relevant recipe state if training is not complete
-
-        Checkpointer will save the model weights and recipe state in
-        different checkpoint files. To correctly resume training from an intermediate checkpoint,
-        the model weights and recipe state must be provided.
-        """
-        # final dict passed onto the checkpointer
-        checkpoint_dict = {}
-
-        intermediate_checkpoint = epoch + 1 < self.total_epochs
-
-        utils.log_rank_zero(
-            log,
-            "Saving checkpoint. This may take some time. Retrieving full model state dict...",
-        )
-        start = time.perf_counter()
-
-        # To prevent GPU memory from spiking during checkpoint save,
-        # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.gather_cpu_state_dict(
-            self._model.state_dict(),
-            self._is_rank_zero,
-            device=self._device,
-        )
-
-        utils.log_rank_zero(
-            log,
-            f"Getting full model state dict took {time.perf_counter() - start:.2f} secs",
-        )
-
-        if intermediate_checkpoint:
-            start = time.perf_counter()
-            utils.log_rank_zero(log, "Getting optimizer state dict...")
-            if not self._optimizer_in_bwd:
-                opt_state_dict = training.get_full_optimizer_state_dict(
-                    self._optimizer,
-                    self._is_rank_zero,
-                    device=self._device,
-                )
-            else:
-                opt_state_dict = {}
-                for param, opt in self._optim_ckpt_wrapper.optim_map.items():
-                    opt_state_dict[param] = training.get_full_optimizer_state_dict(
-                        opt, self._is_rank_zero, device=self._device
-                    )
-            utils.log_rank_zero(
-                log,
-                f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs",
-            )
-        else:
-            opt_state_dict = None
-
-        # Now that we have the model and opt state dict, create the actual checkpoint dict
-        # to be sent to the checkpointer and ultimately written to file
-
-        if self._is_rank_zero:
-            start = time.perf_counter()
-            checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict})
-
-            # if training is in-progress, checkpoint the optimizer state and recipe state
-            # as well.
-            if intermediate_checkpoint:
-                checkpoint_dict.update(
-                    {
-                        training.OPT_KEY: opt_state_dict,
-                        training.SEED_KEY: self.seed,
-                        training.EPOCHS_KEY: self.epochs_run,
-                        training.TOTAL_EPOCHS_KEY: self.total_epochs,
-                        training.MAX_STEPS_KEY: self.max_steps_per_epoch,
-                    }
-                )
-
-            self._checkpointer.save_checkpoint(
-                checkpoint_dict,
-                epoch=epoch,
-                intermediate_checkpoint=intermediate_checkpoint,
-            )
-            log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")
-
-        torch.distributed.barrier()
-
     def train(self) -> None:
         """
         The core training loop.
@@ -922,7 +851,21 @@ def train(self) -> None:
                 self._profiler.step()

             self.epochs_run += 1
-            self.save_checkpoint(epoch=curr_epoch)
+            self._checkpoint_client.save_checkpoint(
+                model=self._model,
+                optimizer=(
+                    self._optimizer
+                    if not self._optimizer_in_bwd
+                    else self._optim_ckpt_wrapper
+                ),
+                training_progress=TrainingProgress(
+                    seed=self.seed,
+                    epochs_run=self.epochs_run,
+                    total_epochs=self.total_epochs,
+                    max_steps_per_epoch=self.max_steps_per_epoch,
+                ),
+                epoch=curr_epoch,
+            )

         self._profiler.stop()
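Taken together, the full_finetune_distributed.py changes replace the recipe-local load_checkpoint/save_checkpoint methods with calls into the new CheckpointClient, which is constructed as CheckpointClient(cfg) in __init__. Below is a condensed sketch of the resulting load/resume/save flow; the "recipe" placeholder and its helper functions are illustrative, and only the CheckpointClient / TrainingProgress call pattern is taken from this diff.

# Condensed sketch of the new recipe flow shown above. "recipe" is a hypothetical
# object exposing the same attributes as the real recipe class.
from torchtune.training.checkpointing._checkpoint_client import TrainingProgress


def setup_and_resume(recipe) -> None:
    # Always load the base checkpoint first.
    checkpoint_dict = recipe._checkpoint_client.load_base_checkpoint()

    if recipe._resume_from_checkpoint:
        if recipe._enable_async_checkpointing:
            # Intermediate checkpoints were written asynchronously via the
            # DistributedCheckpointer, so restore training progress from the
            # distributed checkpoint, falling back to the base checkpoint on failure.
            try:
                checkpoint_dict = recipe._checkpoint_client.load_distributed_checkpoint(
                    recipe._model,
                    recipe._optim_ckpt_wrapper
                    if recipe._optimizer_in_bwd
                    else recipe._optimizer,
                )
            except Exception as e:
                print(f"Failed to load distributed checkpoint: {e}")
        recipe._update_recipe_state(checkpoint_dict)


def save_epoch_checkpoint(recipe, curr_epoch: int) -> None:
    # End-of-epoch save: recipe state travels as a TrainingProgress record.
    recipe._checkpoint_client.save_checkpoint(
        model=recipe._model,
        optimizer=(
            recipe._optimizer
            if not recipe._optimizer_in_bwd
            else recipe._optim_ckpt_wrapper
        ),
        training_progress=TrainingProgress(
            seed=recipe.seed,
            epochs_run=recipe.epochs_run,
            total_epochs=recipe.total_epochs,
            max_steps_per_epoch=recipe.max_steps_per_epoch,
        ),
        epoch=curr_epoch,
    )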

recipes/full_finetune_single_device.py

Lines changed: 1 addition & 1 deletion
@@ -197,7 +197,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
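The remaining recipe files below carry the same one-line change: the checkpointer keyword argument resume_from_checkpoint is renamed to should_load_recipe_state, which says more directly that the flag controls whether recipe state is loaded alongside the checkpoint. A minimal sketch of the updated call follows; the standalone function and its parameters are illustrative, not part of this diff.

# Sketch of the renamed keyword argument; cfg_checkpointer is whatever checkpointer
# config section a recipe would normally receive.
from typing import Any, Dict

from omegaconf import DictConfig
from torchtune import config


def load_checkpoint(cfg_checkpointer: DictConfig, resume_from_checkpoint: bool) -> Dict[str, Any]:
    checkpointer = config.instantiate(
        cfg_checkpointer,
        # Previously: resume_from_checkpoint=resume_from_checkpoint
        should_load_recipe_state=resume_from_checkpoint,
    )
    return checkpointer.load_checkpoint()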

recipes/knowledge_distillation_distributed.py

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()

recipes/knowledge_distillation_single_device.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()

recipes/lora_dpo_distributed.py

Lines changed: 1 addition & 1 deletion
@@ -188,7 +188,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()

recipes/lora_dpo_single_device.py

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()

recipes/lora_finetune_distributed.py

Lines changed: 1 addition & 1 deletion
@@ -197,7 +197,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()

recipes/lora_finetune_single_device.py

Lines changed: 1 addition & 1 deletion
@@ -188,7 +188,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()

recipes/ppo_full_finetune_single_device.py

Lines changed: 4 additions & 4 deletions
@@ -377,22 +377,22 @@ def _setup_checkpointers(

         policy_checkpointer = config.instantiate(
             policy_cfg,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )

         ref_policy_checkpointer = config.instantiate(
             ref_policy_cfg,
-            resume_from_checkpoint=False,
+            should_load_recipe_state=False,
         )

         value_checkpointer = config.instantiate(
             value_cfg,
-            resume_from_checkpoint=False,
+            should_load_recipe_state=False,
         )

         reward_checkpointer = config.instantiate(
             reward_cfg,
-            resume_from_checkpoint=False,
+            should_load_recipe_state=False,
         )

         return (

recipes/qat_distributed.py

Lines changed: 1 addition & 1 deletion
@@ -209,7 +209,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
