File tree Expand file tree Collapse file tree 1 file changed +15
-1
lines changed
Expand file tree Collapse file tree 1 file changed +15
-1
lines changed Original file line number Diff line number Diff line change @@ -851,7 +851,21 @@ def train(self) -> None:
851851 self ._profiler .step ()
852852
853853 if self ._save_interval is not None and ((idx + 1 ) / self ._gradient_accumulation_steps ) % self ._save_interval == 0 :
854- self .save_checkpoint (epoch = curr_epoch , step = int ((idx + 1 ) / self ._gradient_accumulation_steps ))
854+ self ._checkpoint_client .save_checkpoint (
855+ model = self ._model ,
856+ optimizer = (
857+ self ._optimizer
858+ if not self ._optimizer_in_bwd
859+ else self ._optim_ckpt_wrapper
860+ ),
861+ training_progress = TrainingProgress (
862+ seed = self .seed ,
863+ epochs_run = self .epochs_run ,
864+ total_epochs = self .total_epochs ,
865+ max_steps_per_epoch = self .max_steps_per_epoch ,
866+ ),
867+ epoch = curr_epoch ,
868+ )
855869
856870 self .epochs_run += 1
857871 self ._checkpoint_client .save_checkpoint (
You can’t perform that action at this time.
0 commit comments