Skip to content

Commit 30687d3

Browse files
SalmanMohammadiFelipe Mello
authored andcommitted
Fix adapter_config.json saving in DPO recipes (#2162)
1 parent 8600c49 commit 30687d3

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

recipes/lora_dpo_distributed.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
DoRALinear,
2727
get_adapter_params,
2828
get_adapter_state_dict,
29+
get_lora_module_names,
2930
get_merged_lora_ckpt,
3031
LoRALinear,
3132
set_trainable_params,
@@ -595,6 +596,17 @@ def save_checkpoint(
595596
}
596597
)
597598

599+
adapter_config = {
600+
"r": self._lora_rank,
601+
"lora_alpha": self._lora_alpha,
602+
"target_modules": get_lora_module_names(
603+
self._lora_attn_modules,
604+
self._apply_lora_to_mlp,
605+
self._apply_lora_to_output,
606+
),
607+
"peft_type": "LORA",
608+
}
609+
checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config})
598610
self._checkpointer.save_checkpoint(
599611
checkpoint_dict,
600612
epoch=epoch,

recipes/lora_dpo_single_device.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
disable_adapter,
2525
get_adapter_params,
2626
get_adapter_state_dict,
27+
get_lora_module_names,
2728
get_merged_lora_ckpt,
2829
set_trainable_params,
2930
validate_missing_and_unexpected_for_lora,
@@ -448,6 +449,18 @@ def save_checkpoint(self, epoch: int) -> None:
448449

449450
ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
450451

452+
adapter_config = {
453+
"r": self._lora_rank,
454+
"lora_alpha": self._lora_alpha,
455+
"target_modules": get_lora_module_names(
456+
self._lora_attn_modules,
457+
self._apply_lora_to_mlp,
458+
self._apply_lora_to_output,
459+
),
460+
"peft_type": "LORA",
461+
}
462+
ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
463+
451464
self._checkpointer.save_checkpoint(
452465
ckpt_dict,
453466
epoch=epoch,

0 commit comments

Comments
 (0)