
Commit 78b7d0d

Merge branch 'pytorch:main' into patch-2
2 parents 3bc7237 + e420bc0 commit 78b7d0d

18 files changed: +895 -205 lines

recipes/dev/early_exit_finetune_distributed.py

Lines changed: 3 additions & 3 deletions
@@ -556,7 +556,6 @@ def _setup_model(
     model,
     model_state_dict,
     self._device,
-    self._is_rank_zero,
     strict=True,
     cpu_offload=fsdp_cpu_offload,
 )
@@ -757,7 +756,7 @@ def save_checkpoint(
 # To prevent GPU memory from spiking during checkpoint save,
 # we consolidate the full model and optim state dicts on CPU for rank 0
 cpu_state_dict = training.gather_cpu_state_dict(
-    self._model.state_dict(),
+    self._model,
     self._is_rank_zero,
     device=self._device,
 )
@@ -773,6 +772,7 @@ def save_checkpoint(
 log.info("Getting optimizer state dict...")
 if not self._optimizer_in_bwd:
     opt_state_dict = training.get_full_optimizer_state_dict(
+        self._model,
         self._optimizer,
         self._is_rank_zero,
         device=self._device,
@@ -781,7 +781,7 @@ def save_checkpoint(
     opt_state_dict = {}
     for param, opt in self._optim_ckpt_wrapper.optim_map.items():
         opt_state_dict[param] = training.get_full_optimizer_state_dict(
-            opt, self._is_rank_zero, device=self._device
+            self._model, opt, self._is_rank_zero, device=self._device
         )
 if self._is_rank_zero:
     log.info(
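
The pattern above repeats throughout this commit: `training.gather_cpu_state_dict` now takes the (FSDP-sharded) model module instead of a materialized `state_dict()`, and `training.get_full_optimizer_state_dict` gains the model as a new leading argument. A minimal sketch of the new calling convention, condensed from the hunks above into a hypothetical helper (`save_full_checkpoint` and its parameters are placeholders, not part of the recipe):

from torchtune import training

def save_full_checkpoint(model, optimizer, is_rank_zero, device):
    # Pass the sharded module itself; full tensors are consolidated on CPU
    # for rank 0 inside the helper (previously the caller passed
    # model.state_dict()).
    cpu_state_dict = training.gather_cpu_state_dict(
        model,
        is_rank_zero,
        device=device,
    )
    # The model is now the leading argument here as well.
    opt_state_dict = training.get_full_optimizer_state_dict(
        model,
        optimizer,
        is_rank_zero,
        device=device,
    )
    return cpu_state_dict, opt_state_dict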

recipes/full_finetune_distributed.py

Lines changed: 7 additions & 3 deletions
@@ -547,7 +547,6 @@ def _setup_model(
     model,
     model_state_dict,
     self._device,
-    self._is_rank_zero,
     strict=True,
     cpu_offload=fsdp_cpu_offload,
 )
@@ -602,6 +601,7 @@ def _setup_optimizer(
 for param in opt_state_dict.keys():
     try:
         training.load_from_full_optimizer_state_dict(
+            self._model,
             self._optim_ckpt_wrapper.state_dict()[param],
             opt_state_dict[param],
             self._device,
@@ -617,6 +617,7 @@ def _setup_optimizer(
 optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
 if opt_state_dict:
     training.load_from_full_optimizer_state_dict(
+        self._model,
         optimizer,
         opt_state_dict,
         self._device,
@@ -765,7 +766,9 @@ def train(self) -> None:
 if self._optimizer_in_bwd:
     torch.distributed.all_reduce(num_tokens)
     torch.distributed.all_reduce(running_loss)
-    current_loss = current_loss / num_tokens
+
+    # We multiply by world_size to undo FSDP2 gradient normalization.
+    current_loss = current_loss * (world_size / num_tokens)

 current_loss.backward()

@@ -777,7 +780,8 @@ def train(self) -> None:
 # This will ensure that the logged loss matches what we're optimizing
 torch.distributed.all_reduce(running_loss)
 # Manually scale the gradients from unnormalized loss by total # of tokens
-training.scale_grads(self._model, 1 / num_tokens)
+# We multiply by world_size to undo FSDP2 gradient normalization.
+training.scale_grads(self._model, world_size / num_tokens)
 if self._clip_grad_norm is not None:
     grad_norm = torch.nn.utils.clip_grad_norm_(
         self._model.parameters(),
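
The `world_size / num_tokens` factor introduced in the two `train()` hunks above can be sanity-checked with a small arithmetic sketch. It assumes, per the in-diff comment, that FSDP2's gradient reduction averages (divides by `world_size`) across ranks, while `num_tokens` has already been all-reduced to the global token count. Toy values only; nothing below comes from the recipe:

import torch

world_size = 4
tokens_per_rank = torch.tensor([100.0, 120.0, 80.0, 100.0])
num_tokens = tokens_per_rank.sum()  # all-reduced global token count = 400

# Each rank's loss is a sum over its local tokens, so after backward and
# FSDP2's averaging all-reduce the gradient is roughly
#   grad ~ sum_over_ranks(local_grad) / world_size.
# Scaling by world_size / num_tokens therefore yields
#   sum_over_ranks(local_grad) / num_tokens,
# i.e. gradients normalized by the total number of tokens, matching what
# 1 / num_tokens achieved before FSDP2's averaging entered the picture.
scale = world_size / num_tokens
print(float(scale))  # 0.01, i.e. 1 / 400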

recipes/knowledge_distillation_distributed.py

Lines changed: 5 additions & 6 deletions
@@ -461,7 +461,6 @@ def _setup_model(
     model,
     lora_weights_state_dict,
     self._device,
-    self._is_rank_zero,
     cpu_offload=fsdp_cpu_offload,
 )
 else:
@@ -486,7 +485,6 @@ def _setup_model(
     model,
     base_model_state_dict,
     self._device,
-    self._is_rank_zero,
     cpu_offload=fsdp_cpu_offload,
 )
 for m in model.modules():
@@ -574,7 +572,6 @@ def _setup_teacher_model(
     model,
     model_state_dict,
     self._device,
-    self._is_rank_zero,
     strict=True,
     cpu_offload=fsdp_cpu_offload,
 )
@@ -611,6 +608,7 @@ def _setup_optimizer(
 optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
 if opt_state_dict:
     training.load_from_full_optimizer_state_dict(
+        self._model,
         optimizer,
         opt_state_dict,
         self._device,
@@ -705,13 +703,14 @@ def save_checkpoint(self, epoch: int) -> None:
 # To prevent GPU memory from spiking during checkpoint save,
 # we consolidate the full model and optim state dicts on CPU for rank 0
 cpu_state_dict = training.gather_cpu_state_dict(
-    self._model.state_dict(),
+    self._model,
     self._is_rank_zero,
     device=self._device,
 )

 if intermediate_checkpoint:
     opt_state_dict = training.get_full_optimizer_state_dict(
+        self._model,
         self._optimizer,
         self._is_rank_zero,
         device=self._device,
@@ -770,7 +769,6 @@ def save_checkpoint(self, epoch: int) -> None:
 def _loss_step(
     self, batch: Dict[str, torch.Tensor]
 ) -> (torch.Tensor, torch.Tensor):
-
     # Both are shape [b, s]
     tokens, labels = batch["tokens"], batch["labels"]

@@ -876,7 +874,8 @@ def train(self) -> None:
 torch.distributed.all_reduce(running_class_loss)
 torch.distributed.all_reduce(running_kd_loss)
 # Manually scale the gradients from unnormalized loss by total # of tokens
-training.scale_grads(self._model, 1 / num_tokens)
+# We multiply by world_size to undo FSDP2 gradient normalization.
+training.scale_grads(self._model, world_size / num_tokens)
 class_loss_to_log = running_class_loss.item() / num_tokens
 kd_loss_to_log = running_kd_loss.item() / num_tokens
 self._optimizer.step()
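
The `_setup_optimizer` change above is the same across recipes: `training.load_from_full_optimizer_state_dict` now receives the model, presumably so the helper can map the full optimizer state onto the model's sharded parameters when resuming. A condensed, hypothetical sketch of that resume path (the function and its parameters are placeholders, not the recipe's attributes):

from torchtune import config, training

def setup_optimizer(cfg_optimizer, model, opt_state_dict, device):
    # Instantiate the optimizer from config, then restore its state if a
    # checkpointed optimizer state dict was loaded.
    optimizer = config.instantiate(cfg_optimizer, model.parameters())
    if opt_state_dict:
        training.load_from_full_optimizer_state_dict(
            model,          # new leading argument in this commit
            optimizer,
            opt_state_dict,
            device,
        )
    return optimizer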

recipes/lora_dpo_distributed.py

Lines changed: 4 additions & 7 deletions
@@ -385,7 +385,6 @@ def _setup_model(
     model,
     lora_weights_state_dict,
     self._device,
-    self._is_rank_zero,
     cpu_offload=fsdp_cpu_offload,
 )
 else:
@@ -410,7 +409,6 @@ def _setup_model(
     model,
     base_model_state_dict,
     self._device,
-    self._is_rank_zero,
     cpu_offload=fsdp_cpu_offload,
 )
 is_dora = False
@@ -458,6 +456,7 @@ def _setup_optimizer(
 optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
 if opt_state_dict:
     training.load_from_full_optimizer_state_dict(
+        self._model,
         optimizer,
         opt_state_dict,
         self._device,
@@ -546,17 +545,15 @@ def save_checkpoint(
 intermediate_checkpoint = epoch + 1 < self.total_epochs
 # To prevent GPU memory from spiking during checkpoint save,
 # we consolidate the full model and optim state dicts on CPU for rank 0
-state_dict = self._model.state_dict()
-if self._save_adapter_weights_only:
-    state_dict = get_adapter_state_dict(state_dict, device=None)
-
 cpu_state_dict = training.gather_cpu_state_dict(
-    state_dict,
+    self._model,
     self._is_rank_zero,
     device=self._device,
+    adapter_weights_only=self._save_adapter_weights_only,
 )
 if intermediate_checkpoint:
     opt_state_dict = training.get_full_optimizer_state_dict(
+        self._model,
         self._optimizer,
         self._is_rank_zero,
         device=self._device,
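
The LoRA-specific change here (repeated in lora_finetune_distributed.py below): instead of materializing `self._model.state_dict()` and filtering it with `get_adapter_state_dict` inside the recipe, adapter-only filtering now happens inside `gather_cpu_state_dict` via the new `adapter_weights_only` keyword. A hypothetical sketch of the new call with placeholder names:

from torchtune import training

def gather_checkpoint_state(model, is_rank_zero, device, save_adapter_weights_only):
    # When save_adapter_weights_only is True, only the LoRA adapter weights
    # are gathered; the recipe no longer filters the state dict itself.
    return training.gather_cpu_state_dict(
        model,
        is_rank_zero,
        device=device,
        adapter_weights_only=save_adapter_weights_only,
    )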

recipes/lora_finetune_distributed.py

Lines changed: 6 additions & 8 deletions
@@ -480,7 +480,6 @@ def _setup_model(
     model,
     lora_weights_state_dict,
     self._device,
-    self._is_rank_zero,
     cpu_offload=fsdp_cpu_offload,
 )
 else:
@@ -505,7 +504,6 @@ def _setup_model(
     model,
     base_model_state_dict,
     self._device,
-    self._is_rank_zero,
     cpu_offload=fsdp_cpu_offload,
 )
 for m in model.modules():
@@ -549,6 +547,7 @@ def _setup_optimizer(
 optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
 if opt_state_dict:
     training.load_from_full_optimizer_state_dict(
+        self._model,
         optimizer,
         opt_state_dict,
         self._device,
@@ -656,14 +655,11 @@ def save_checkpoint(

 # To prevent GPU memory from spiking during checkpoint save,
 # we consolidate the full model and optim state dicts on CPU for rank 0
-state_dict = self._model.state_dict()
-if self._save_adapter_weights_only:
-    state_dict = get_adapter_state_dict(state_dict, device=None)
-
 cpu_state_dict = training.gather_cpu_state_dict(
-    state_dict,
+    self._model,
     self._is_rank_zero,
     device=self._device,
+    adapter_weights_only=self._save_adapter_weights_only,
 )
 utils.log_rank_zero(
     log,
@@ -673,6 +669,7 @@ def save_checkpoint(
 if intermediate_checkpoint:
     utils.log_rank_zero(log, "Retrieving optimizer state dict...")
     opt_state_dict = training.get_full_optimizer_state_dict(
+        self._model,
         self._optimizer,
         self._is_rank_zero,
         device=self._device,
@@ -825,7 +822,8 @@ def train(self) -> None:
 # This will ensure that the logged loss matches what we're optimizing
 torch.distributed.all_reduce(running_loss)
 # Manually scale the gradients from unnormalized loss by total # of tokens
-training.scale_grads(self._model, 1 / num_tokens)
+# We multiply by world_size to undo FSDP2 gradient normalization.
+training.scale_grads(self._model, world_size / num_tokens)
 if self._clip_grad_norm is not None:
     grad_norm = torch.nn.utils.clip_grad_norm_(
         self._model.parameters(),

recipes/lora_finetune_distributed_multi_dataset.py

Lines changed: 2 additions & 3 deletions
@@ -473,7 +473,6 @@ def _setup_model(
     model,
     lora_weights_state_dict,
     self._device,
-    self._is_rank_zero,
     cpu_offload=fsdp_cpu_offload,
 )
 else:
@@ -500,7 +499,6 @@ def _setup_model(
     model,
     base_model_state_dict,
     self._device,
-    self._is_rank_zero,
     cpu_offload=fsdp_cpu_offload,
 )
 for m in model.modules():
@@ -853,7 +851,8 @@ def train(self) -> None:
 # This will ensure that the logged loss matches what we're optimizing
 torch.distributed.all_reduce(running_loss)
 # Manually scale the gradients from unnormalized loss by total # of tokens
-training.scale_grads(self._model, 1 / num_tokens)
+# We multiply by world_size to undo FSDP2 gradient normalization.
+training.scale_grads(self._model, world_size / num_tokens)
 if self._clip_grad_norm is not None:
     grad_norm = torch.nn.utils.clip_grad_norm_(
         self._model.parameters(),

recipes/qat_distributed.py

Lines changed: 10 additions & 5 deletions
@@ -508,7 +508,6 @@ def _setup_model(
     model,
     model_state_dict,
     self._device,
-    self._is_rank_zero,
     strict=True,
     cpu_offload=fsdp_cpu_offload,
 )
@@ -562,6 +561,7 @@ def _setup_optimizer(
 for param in opt_state_dict.keys():
     try:
         training.load_from_full_optimizer_state_dict(
+            self._model,
             self._optim_ckpt_wrapper.state_dict()[param],
             opt_state_dict[param],
             self._device,
@@ -577,6 +577,7 @@ def _setup_optimizer(
 optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
 if opt_state_dict:
     training.load_from_full_optimizer_state_dict(
+        self._model,
         optimizer,
         opt_state_dict,
         self._device,
@@ -667,7 +668,7 @@ def save_checkpoint(
 # To prevent GPU memory from spiking during checkpoint save,
 # we consolidate the full model and optim state dicts on CPU for rank 0
 cpu_state_dict = training.gather_cpu_state_dict(
-    self._model.state_dict(),
+    self._model,
     self._is_rank_zero,
     device=self._device,
 )
@@ -682,6 +683,7 @@ def save_checkpoint(
 utils.log_rank_zero(log, "Getting optimizer state dict...")
 if not self._optimizer_in_bwd:
     opt_state_dict = training.get_full_optimizer_state_dict(
+        self._model,
         self._optimizer,
         self._is_rank_zero,
         device=self._device,
@@ -690,7 +692,7 @@ def save_checkpoint(
     opt_state_dict = {}
     for param, opt in self._optim_ckpt_wrapper.optim_map.items():
         opt_state_dict[param] = training.get_full_optimizer_state_dict(
-            opt, self._is_rank_zero, device=self._device
+            self._model, opt, self._is_rank_zero, device=self._device
         )
 utils.log_rank_zero(
     log,
@@ -835,7 +837,9 @@ def train(self) -> None:
 if self._optimizer_in_bwd:
     torch.distributed.all_reduce(num_tokens)
     torch.distributed.all_reduce(running_loss)
-    current_loss = current_loss / num_tokens
+
+    # We multiply by world_size to undo FSDP2 gradient normalization.
+    current_loss = current_loss * (world_size / num_tokens)

 current_loss.backward()

@@ -847,7 +851,8 @@ def train(self) -> None:
 # This will ensure that the logged loss matches what we're optimizing
 torch.distributed.all_reduce(running_loss)
 # Manually scale the gradients from unnormalized loss by total # of tokens
-training.scale_grads(self._model, 1 / num_tokens)
+# We multiply by world_size to undo FSDP2 gradient normalization.
+training.scale_grads(self._model, world_size / num_tokens)
 if self._clip_grad_norm is not None:
     grad_norm = torch.nn.utils.clip_grad_norm_(
         self._model.parameters(),
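
Note the asymmetry between the two `train()` hunks above: when the optimizer runs in the backward pass, its steps consume gradients while backward is still executing, so they cannot be rescaled afterwards; the commit therefore folds the same `world_size / num_tokens` normalization into the loss before calling `backward()`. A heavily condensed, hypothetical sketch of that branch (placeholder names, not the recipe's exact control flow):

import torch.distributed as dist

def backward_with_in_bwd_optimizer(current_loss, running_loss, num_tokens, world_size):
    # Sketch of the optimizer-in-backward branch shown above.
    dist.all_reduce(num_tokens)    # global token count across all ranks
    dist.all_reduce(running_loss)  # so the logged loss matches what we optimize
    # Normalize the loss itself; world_size undoes FSDP2's gradient averaging.
    current_loss = current_loss * (world_size / num_tokens)
    current_loss.backward()
    return current_loss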
