@@ -508,7 +508,6 @@ def _setup_model(
             model,
             model_state_dict,
             self._device,
-            self._is_rank_zero,
             strict=True,
             cpu_offload=fsdp_cpu_offload,
         )
@@ -562,6 +561,7 @@ def _setup_optimizer(
                 for param in opt_state_dict.keys():
                     try:
                         training.load_from_full_optimizer_state_dict(
+                            self._model,
                             self._optim_ckpt_wrapper.state_dict()[param],
                             opt_state_dict[param],
                             self._device,
@@ -577,6 +577,7 @@ def _setup_optimizer(
             optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
             if opt_state_dict:
                 training.load_from_full_optimizer_state_dict(
+                    self._model,
                     optimizer,
                     opt_state_dict,
                     self._device,
@@ -667,7 +668,7 @@ def save_checkpoint(
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
         cpu_state_dict = training.gather_cpu_state_dict(
-            self._model.state_dict(),
+            self._model,
             self._is_rank_zero,
             device=self._device,
         )
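
The hunk above passes the sharded model itself to `training.gather_cpu_state_dict` instead of a materialized `state_dict()`. As a rough illustration of what consolidating FSDP2/DTensor shards into a full CPU state dict involves, here is a minimal sketch; the helper name and loop are illustrative, not torchtune's implementation, and they assume parameters are `DTensor`s (the `DTensor` import path varies across PyTorch releases).

```python
import torch
from torch.distributed.tensor import DTensor  # older releases: torch.distributed._tensor


def gather_full_cpu_state_dict(model: torch.nn.Module, is_rank_zero: bool) -> dict:
    # Illustrative sketch only: reassemble DTensor shards into full tensors,
    # keeping the consolidated copy on CPU and only on rank zero.
    cpu_state_dict = {}
    for name, tensor in model.state_dict().items():
        if isinstance(tensor, DTensor):
            # full_tensor() is a collective; every rank must participate.
            tensor = tensor.full_tensor()
        if is_rank_zero:
            cpu_state_dict[name] = tensor.cpu()
    return cpu_state_dict
```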
@@ -682,6 +683,7 @@ def save_checkpoint(
             utils.log_rank_zero(log, "Getting optimizer state dict...")
             if not self._optimizer_in_bwd:
                 opt_state_dict = training.get_full_optimizer_state_dict(
+                    self._model,
                     self._optimizer,
                     self._is_rank_zero,
                     device=self._device,
@@ -690,7 +692,7 @@ def save_checkpoint(
                 opt_state_dict = {}
                 for param, opt in self._optim_ckpt_wrapper.optim_map.items():
                     opt_state_dict[param] = training.get_full_optimizer_state_dict(
-                        opt, self._is_rank_zero, device=self._device
+                        self._model, opt, self._is_rank_zero, device=self._device
                     )
             utils.log_rank_zero(
                 log,
@@ -835,7 +837,9 @@ def train(self) -> None:
                 if self._optimizer_in_bwd:
                     torch.distributed.all_reduce(num_tokens)
                     torch.distributed.all_reduce(running_loss)
-                    current_loss = current_loss / num_tokens
+
+                    # We multiply by world_size to undo FSDP2 gradient normalization.
+                    current_loss = current_loss * (world_size / num_tokens)

                 current_loss.backward()

@@ -847,7 +851,8 @@ def train(self) -> None:
                     # This will ensure that the logged loss matches what we're optimizing
                     torch.distributed.all_reduce(running_loss)
                     # Manually scale the gradients from unnormalized loss by total # of tokens
-                    training.scale_grads(self._model, 1 / num_tokens)
+                    # We multiply by world_size to undo FSDP2 gradient normalization.
+                    training.scale_grads(self._model, world_size / num_tokens)
                     if self._clip_grad_norm is not None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
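
The two hunks above replace the `1 / num_tokens` scaling with `world_size / num_tokens`. A toy arithmetic check of that factor, under the assumption stated in the new comments that FSDP2 mean-reduces gradients across ranks (divides the cross-rank sum by `world_size`) while the loss is an unnormalized sum over tokens; variable names below are hypothetical:

```python
# Toy check, not recipe code: verify that multiplying FSDP2's mean-reduced
# gradient by world_size / num_tokens yields grad(total_loss) / num_tokens.
world_size = 4
per_rank_grads = [2.0, 3.0, 1.0, 2.0]  # hypothetical d(loss_rank)/d(param) per rank
num_tokens = 100                       # global token count after all_reduce

fsdp_grad = sum(per_rank_grads) / world_size    # what FSDP2 leaves on each rank
target_grad = sum(per_rank_grads) / num_tokens  # gradient of total_loss / num_tokens
assert abs(fsdp_grad * (world_size / num_tokens) - target_grad) < 1e-12
```

Multiplying by `world_size / num_tokens` therefore undoes the mean reduction and applies the global token normalization in one step.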