Commit b14ba65

rahul_sarvam_ai authored and committed
save checkpoint every k steps
1 parent 8580959 commit b14ba65

File tree

1 file changed: +15 −1

recipes/full_finetune_distributed.py

Lines changed: 15 additions & 1 deletion
@@ -851,7 +851,21 @@ def train(self) -> None:
                 self._profiler.step()
 
                 if self._save_interval is not None and ((idx + 1) / self._gradient_accumulation_steps) % self._save_interval == 0:
-                    self.save_checkpoint(epoch=curr_epoch, step=int((idx + 1) / self._gradient_accumulation_steps))
+                    self._checkpoint_client.save_checkpoint(
+                        model=self._model,
+                        optimizer=(
+                            self._optimizer
+                            if not self._optimizer_in_bwd
+                            else self._optim_ckpt_wrapper
+                        ),
+                        training_progress=TrainingProgress(
+                            seed=self.seed,
+                            epochs_run=self.epochs_run,
+                            total_epochs=self.total_epochs,
+                            max_steps_per_epoch=self.max_steps_per_epoch,
+                        ),
+                        epoch=curr_epoch,
+                    )
 
             self.epochs_run += 1
             self._checkpoint_client.save_checkpoint(
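
The guard on self._save_interval turns the batch index into an optimizer-step count: under gradient accumulation, only every _gradient_accumulation_steps-th batch completes an optimizer step, so (idx + 1) / self._gradient_accumulation_steps lands on a whole number exactly at step boundaries, and the modulo test passes only on multiples of the save interval. Below is a minimal sketch of the same "save every k optimizer steps" condition; the names (should_save, grad_accum_steps, save_interval) are illustrative rather than torchtune's API, and it uses integer arithmetic in place of the float division in the diff.

    # Minimal sketch of the "save every k optimizer steps" condition under
    # gradient accumulation. Names are hypothetical, not torchtune's API.

    def should_save(batch_idx: int, grad_accum_steps: int, save_interval: int) -> bool:
        """True once every `save_interval` optimizer steps.

        An optimizer step completes after every `grad_accum_steps` batches,
        so batch index `idx` corresponds to optimizer step
        (idx + 1) // grad_accum_steps when (idx + 1) is an exact multiple
        of grad_accum_steps.
        """
        n = batch_idx + 1
        # Require a completed optimizer step AND a step count that is a
        # multiple of the save interval. Matches the float test in the diff
        # at step boundaries, without comparing a float modulo result to zero.
        return n % grad_accum_steps == 0 and (n // grad_accum_steps) % save_interval == 0

    # Example: grad_accum_steps=4, save_interval=100 -> a checkpoint at
    # batches 400, 800, 1200, ... (optimizer steps 100, 200, 300, ...).
    assert should_save(399, grad_accum_steps=4, save_interval=100)      # batch 400 -> step 100
    assert not should_save(398, grad_accum_steps=4, save_interval=100)  # mid-accumulation

At exact step boundaries this behaves identically to the committed condition; the integer form is simply less sensitive to floating-point rounding as step counts grow.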
