Skip to content

Commit 4991014

Browse files
updating using adapter only
1 parent 05620fe commit 4991014

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

recipes/lora_dpo_single_device.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,7 @@ def save_checkpoint(self, epoch: int) -> None:
434434
else:
435435
# No need to merge state dict if we're only saving adapter weights
436436
adapter_state_dict = {
437-
k: v
438-
for k, v in self._model.state_dict().items()
439-
if adapter_key_filter(k)
437+
k: v.cpu() for k, v in get_adapter_params(self._model).items()
440438
}
441439

442440
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

recipes/lora_finetune_single_device.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,6 @@ def save_checkpoint(self, epoch: int) -> None:
529529
}
530530
)
531531

532-
adapter_key_filter = lambda x: x in self.adapter_params
533532
if not self._save_adapter_weights_only:
534533
# Construct the full state dict with LoRA weights merged into base LLM weights
535534

@@ -539,6 +538,7 @@ def save_checkpoint(self, epoch: int) -> None:
539538
# Construct the adapter weights
540539
# Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
541540
# Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
541+
adapter_key_filter = lambda x: x in self.adapter_params
542542
adapter_state_dict = {
543543
k: v for k, v in state_dict.items() if adapter_key_filter(k)
544544
}
@@ -553,9 +553,7 @@ def save_checkpoint(self, epoch: int) -> None:
553553
else:
554554
# No need to merge state dict if we're only saving adapter weights
555555
adapter_state_dict = {
556-
k: v
557-
for k, v in self._model.state_dict().items()
558-
if adapter_key_filter(k)
556+
k: v.cpu() for k, v in get_adapter_params(self._model).items()
559557
}
560558

561559
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

0 commit comments

Comments
 (0)