Commit dcb9531

Only merge model weights in LoRA recipe when save_adapter_weights_only=False (#1476)
1 parent 62b0c79 commit dcb9531

9 files changed, +285 -56 lines changed
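
The change repeated across the four recipes below is a guard around the merge step in save_checkpoint: adapter weights are always written, but the merged full-model checkpoint is only built when save_adapter_weights_only is False. A minimal sketch of that control flow (build_checkpoint, the dictionary keys, and the merge stand-in are illustrative placeholders, not the torchtune API):

from typing import Dict, Set

import torch


def build_checkpoint(
    state_dict: Dict[str, torch.Tensor],
    adapter_keys: Set[str],
    save_adapter_weights_only: bool,
) -> Dict[str, Dict[str, torch.Tensor]]:
    ckpt: Dict[str, Dict[str, torch.Tensor]] = {}
    # The adapter weights are written in both modes.
    ckpt["adapter"] = {k: v for k, v in state_dict.items() if k in adapter_keys}
    if not save_adapter_weights_only:
        # Stand-in for get_merged_lora_ckpt: fold the LoRA A/B factors into the
        # base weights and emit a full model state dict.
        ckpt["model"] = {k: v for k, v in state_dict.items() if k not in adapter_keys}
    return ckpt

With save_adapter_weights_only=True the checkpointer then receives only the adapter entry, so no merged model file is produced.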

recipes/lora_dpo_distributed.py

Lines changed: 7 additions & 6 deletions

@@ -529,12 +529,13 @@ def save_checkpoint(
             checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

             # merge the adapter weights and base weights to create the model checkpoint
-            merged_state_dict = get_merged_lora_ckpt(
-                cpu_state_dict,
-                rank=self._lora_rank,
-                alpha=self._lora_alpha,
-            )
-            checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
+            if not self._save_adapter_weights_only:
+                merged_state_dict = get_merged_lora_ckpt(
+                    cpu_state_dict,
+                    rank=self._lora_rank,
+                    alpha=self._lora_alpha,
+                )
+                checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})

             # if training is in-progress, checkpoint the optimizer state and recipe state
             # as well.

recipes/lora_dpo_single_device.py

Lines changed: 26 additions & 15 deletions

@@ -410,22 +410,33 @@ def save_checkpoint(self, epoch: int) -> None:
                }
            )

-        # Move to CPU to avoid a copy on GPU
-        state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
-
-        # Construct the full state dict with LoRA weights merged into base LLM weights
-        merged_state_dict = get_merged_lora_ckpt(
-            state_dict,
-            rank=self._lora_rank,
-            alpha=self._lora_alpha,
-        )
-        ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
-
-        # Construct the adapter weights
        adapter_key_filter = lambda x: x in self.adapter_params
-        adapter_state_dict = {
-            k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k)
-        }
+        if not self._save_adapter_weights_only:
+            # Construct the full state dict with LoRA weights merged into base LLM weights
+
+            # Move to CPU to avoid a copy on GPU
+            state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
+
+            # Construct the adapter weights
+            # Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
+            # Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
+            adapter_state_dict = {
+                k: v for k, v in state_dict.items() if adapter_key_filter(k)
+            }
+
+            merged_state_dict = get_merged_lora_ckpt(
+                state_dict,
+                rank=self._lora_rank,
+                alpha=self._lora_alpha,
+            )
+
+            ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
+        else:
+            # No need to merge state dict if we're only saving adapter weights
+            adapter_state_dict = {
+                k: v.cpu() for k, v in get_adapter_params(self._model).items()
+            }
+
        ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

        self._checkpointer.save_checkpoint(

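The ordering comment in the hunk above is the subtle part: get_merged_lora_ckpt removes the LoRA keys from the state dict it returns, so the adapter state dict has to be captured before merging. A self-contained toy illustration of that constraint (the tensors and the fake_merge helper are stand-ins, not torchtune code):

import torch

state_dict = {
    "attn.q_proj.weight": torch.zeros(4, 4),
    "attn.q_proj.lora_a.weight": torch.zeros(2, 4),
    "attn.q_proj.lora_b.weight": torch.zeros(4, 2),
}


def fake_merge(sd):
    # Stand-in for get_merged_lora_ckpt: fold the LoRA factors into the base
    # weight and drop the LoRA keys from the returned dict.
    return {k: v for k, v in sd.items() if "lora" not in k}


# Filtering adapter keys BEFORE merging captures the LoRA weights ...
adapter_before = {k: v for k, v in state_dict.items() if "lora" in k}
# ... filtering AFTER merging would find nothing left to save.
adapter_after = {k: v for k, v in fake_merge(state_dict).items() if "lora" in k}

assert len(adapter_before) == 2 and len(adapter_after) == 0
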
recipes/lora_finetune_distributed.py

Lines changed: 7 additions & 6 deletions

@@ -628,12 +628,13 @@ def save_checkpoint(
             checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

             # merge the adapter weights and base weights to create the model checkpoint
-            merged_state_dict = get_merged_lora_ckpt(
-                cpu_state_dict,
-                rank=self._lora_rank,
-                alpha=self._lora_alpha,
-            )
-            checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
+            if not self._save_adapter_weights_only:
+                merged_state_dict = get_merged_lora_ckpt(
+                    cpu_state_dict,
+                    rank=self._lora_rank,
+                    alpha=self._lora_alpha,
+                )
+                checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})

             # if training is in-progress, checkpoint the optimizer state and recipe state
             # as well.

recipes/lora_finetune_single_device.py

Lines changed: 34 additions & 24 deletions

@@ -486,14 +486,16 @@ def _setup_data(
            dataset=ds,
            sampler=sampler,
            batch_size=batch_size,
-            collate_fn=partial(
-                padded_collate_sft,
-                padding_idx=self._tokenizer.pad_id,
-                ignore_idx=self._loss_fn.ignore_index,
-            )
-            if not packed
-            else partial(
-                padded_collate_packed,
+            collate_fn=(
+                partial(
+                    padded_collate_sft,
+                    padding_idx=self._tokenizer.pad_id,
+                    ignore_idx=self._loss_fn.ignore_index,
+                )
+                if not packed
+                else partial(
+                    padded_collate_packed,
+                )
            ),
        )

@@ -527,24 +529,32 @@ def save_checkpoint(self, epoch: int) -> None:
                }
            )

-        # Move to CPU to avoid a copy on GPU
-        state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
+        if not self._save_adapter_weights_only:
+            # Construct the full state dict with LoRA weights merged into base LLM weights

-        # Construct the adapter weights
-        # Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
-        # Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
-        adapter_key_filter = lambda x: x in self.adapter_params
-        adapter_state_dict = {
-            k: v for k, v in state_dict.items() if adapter_key_filter(k)
-        }
+            # Move to CPU to avoid a copy on GPU
+            state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}

-        # Construct the full state dict with LoRA weights merged into base LLM weights
-        merged_state_dict = get_merged_lora_ckpt(
-            state_dict,
-            rank=self._lora_rank,
-            alpha=self._lora_alpha,
-        )
-        ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
+            # Construct the adapter weights
+            # Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
+            # Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
+            adapter_key_filter = lambda x: x in self.adapter_params
+            adapter_state_dict = {
+                k: v for k, v in state_dict.items() if adapter_key_filter(k)
+            }
+
+            merged_state_dict = get_merged_lora_ckpt(
+                state_dict,
+                rank=self._lora_rank,
+                alpha=self._lora_alpha,
+            )
+
+            ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
+        else:
+            # No need to merge state dict if we're only saving adapter weights
+            adapter_state_dict = {
+                k: v.cpu() for k, v in get_adapter_params(self._model).items()
+            }

        ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
        adapter_config = {

tests/assets/stack_exchange_paired_tiny.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.
New test file: class TestLoRADPOSingleDeviceRecipe

Lines changed: 182 additions & 0 deletions

@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import runpy
import sys
from pathlib import Path

import pytest
import torch
from omegaconf import OmegaConf
from tests.common import TUNE_PATH
from tests.recipes.utils import (
    dummy_stack_exchange_dataset_config,
    MODEL_TEST_CONFIGS,
    write_hf_ckpt_config,
)
from tests.test_utils import (
    CKPT_MODEL_PATHS,
    gen_log_file_name,
    get_loss_values_from_metric_logger,
)
from torchtune import config


class TestLoRADPOSingleDeviceRecipe:
    def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
        return [
            "batch_size=8",
            "device=cpu",
            f"dtype={dtype_str}",
            "enable_activation_checkpointing=False",
            "dataset.train_on_input=False",
            "seed=9",
            f"epochs={epochs}",
            "max_steps_per_epoch=2",
            "optimizer.lr=2e-5",
            "log_every_n_steps=1",
            "gradient_accumulation_steps=1",
            "clip_grad_norm=100",
            "tokenizer.max_seq_len=512",
        ] + dummy_stack_exchange_dataset_config()

    @pytest.mark.parametrize("save_adapter_weights_only", [False, True])
    @pytest.mark.integration_test
    def test_training_state_on_resume(
        self, tmpdir, monkeypatch, save_adapter_weights_only
    ):
        """Test whether the recipe state is correctly updated on resume. Since this
        is model agnostic, we should run this on the small model only. The test
        consists of three stages:
        - Train a model for 2 epochs
        - Resume training after epoch 1
        - Make sure final loss matches the expected value of a model successfully resumed from a ckpt
        Unlike `tests.recipes.test_lora_finetune_single_device`, this test does not use pre-computed loss
        values to benchmark against. This test just ensures the loss values are identical when resuming.
        """

        ckpt = "llama2_hf"
        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
        ckpt_dir = ckpt_path.parent
        log_file = gen_log_file_name(tmpdir)

        # Config file needed for model conversion.
        # Create a second copy for training resume
        write_hf_ckpt_config(ckpt_dir)
        write_hf_ckpt_config(tmpdir)

        # Train for two epochs
        cmd_1 = f"""
        tune run lora_dpo_single_device \
            --config llama2/7B_lora_dpo_single_device \
            output_dir={tmpdir} \
            checkpointer=torchtune.training.FullModelHFCheckpointer \
            checkpointer.checkpoint_dir='{ckpt_dir}' \
            checkpointer.checkpoint_files=[{ckpt_path}]\
            checkpointer.output_dir={tmpdir} \
            checkpointer.model_type=LLAMA2 \
            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
            tokenizer.prompt_template=null \
            save_adapter_weights_only={save_adapter_weights_only} \
            metric_logger.filename={log_file} \
        """.split()

        model_config = MODEL_TEST_CONFIGS["llama2_lora"]

        cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
        monkeypatch.setattr(sys, "argv", cmd_1)
        with pytest.raises(SystemExit, match=""):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        expected_loss_values = get_loss_values_from_metric_logger(log_file)

        resumed_log_dir = (tmpdir / "resumed/").mkdir()
        resumed_log_file = gen_log_file_name(resumed_log_dir)
        # Resume training
        cmd_2 = f"""
        tune run lora_dpo_single_device \
            --config llama2/7B_lora_dpo_single_device \
            output_dir={tmpdir} \
            checkpointer=torchtune.training.FullModelHFCheckpointer \
            checkpointer.checkpoint_dir={tmpdir} \
            checkpointer.checkpoint_files=[{ckpt_path}]\
            checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
            checkpointer.output_dir={tmpdir} \
            checkpointer.model_type=LLAMA2 \
            resume_from_checkpoint=True \
            metric_logger.filename={resumed_log_file} \
            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
            tokenizer.prompt_template=null \
        """.split()
        cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config
        monkeypatch.setattr(sys, "argv", cmd_2)
        with pytest.raises(SystemExit, match=""):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        # Second epoch only
        resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file)

        torch.testing.assert_close(
            resumed_loss_values[:2], expected_loss_values[2:], rtol=1e-5, atol=1e-5
        )

    @pytest.mark.integration_test
    def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
        ckpt = "llama2_tune"
        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
        ckpt_dir = ckpt_path.parent

        cmd = f"""
        tune run lora_dpo_single_device \
            --config llama2/7B_lora_dpo_single_device \
            output_dir={tmpdir} \
            checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
            checkpointer.checkpoint_dir='{ckpt_dir}' \
            checkpointer.checkpoint_files=[{ckpt_path}]\
            checkpointer.output_dir={tmpdir} \
            checkpointer.model_type=LLAMA2 \
            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
            tokenizer.prompt_template=null \
        """.split()

        model_config = MODEL_TEST_CONFIGS["llama2_lora"]

        cmd = cmd + self._get_test_config_overrides() + model_config
        monkeypatch.setattr(sys, "argv", cmd)
        with pytest.raises(SystemExit, match=""):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        # Next load both the merged weights in a Llama2 base model
        # and the base model weights + trained adapter weights in the LoRA Llama 2 model
        # The results of calling forward on dummy inputs should be the same.
        inputs = torch.randint(low=0, high=32_000, size=(2, 100))

        # Build LoRA model for loading base + adapter weights separately
        lora_model = config.instantiate(OmegaConf.from_dotlist(model_config).model)

        # Build base llama2 model for loading merged weights
        base_llama2_config = MODEL_TEST_CONFIGS["llama2"]
        llama2_model = config.instantiate(
            OmegaConf.from_dotlist(base_llama2_config).model
        )

        # Load base model and trained adapter weights into LoRA model and call fwd
        with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
            lora_sd = torch.load(f, weights_only=True)
        with open(ckpt_path, "rb") as f:
            base_model_sd = torch.load(f, weights_only=True)
        lora_model.load_state_dict(lora_sd, strict=False)
        lora_model.load_state_dict(base_model_sd, strict=False)
        baseline_out = lora_model(inputs)

        # Load merged final ckpt directly into llama2 and call fwd
        with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
            sd = torch.load(f, weights_only=True)
        llama2_model.load_state_dict(sd)
        merged_ckpt_out = llama2_model(inputs)
        torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5)

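When only adapter weights are saved, the checkpoint holds just the LoRA parameters, so it is applied non-strictly on top of the base weights, much as test_save_and_load_merged_weights above loads the adapter and base checkpoints with strict=False. A self-contained toy sketch of that load path (the module and the two checkpoint dicts are stand-ins, not torchtune models or files):

import torch
from torch import nn


class ToyLoRALinear(nn.Module):
    def __init__(self, dim: int = 4, rank: int = 2):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(dim, dim))   # "base" weight
        self.lora_a = nn.Parameter(torch.zeros(rank, dim))  # adapter A
        self.lora_b = nn.Parameter(torch.zeros(dim, rank))  # adapter B


model = ToyLoRALinear()
base_sd = {"weight": torch.ones(4, 4)}                                  # base-only checkpoint
adapter_sd = {"lora_a": torch.ones(2, 4), "lora_b": torch.ones(4, 2)}   # adapter-only checkpoint

# strict=False because each checkpoint covers only a subset of the model's keys.
model.load_state_dict(base_sd, strict=False)
model.load_state_dict(adapter_sd, strict=False)
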
tests/recipes/test_lora_finetune_distributed.py

Lines changed: 12 additions & 4 deletions

@@ -97,14 +97,21 @@ def test_loss(self, reshard_after_forward, tmpdir, monkeypatch):
    @pytest.mark.integration_test
    @gpu_test(gpu_count=2)
    @pytest.mark.parametrize(
-        "config, model_type, ckpt_type",
+        "config, model_type, ckpt_type, save_adapter_weights_only",
        [
-            ("llama2/7B_lora", "llama2", "hf"),
-            ("llama3/8B_lora", "llama3", "tune"),
+            ("llama2/7B_lora", "llama2", "hf", False),
+            ("llama3/8B_lora", "llama3", "tune", False),
+            ("llama2/7B_lora", "llama2", "hf", True),
        ],
    )
    def test_training_state_on_resume(
-        self, config, model_type, ckpt_type, tmpdir, monkeypatch
+        self,
+        config,
+        model_type,
+        ckpt_type,
+        tmpdir,
+        monkeypatch,
+        save_adapter_weights_only,
    ):
        """Test whether the recipe state is correctly updated on resume. Since this
        is model agnostic, we should run this on the small model only. The test

@@ -139,6 +146,7 @@ def test_training_state_on_resume(
            checkpointer.model_type={model_type.upper()} \
            tokenizer.path='{tokenizer_path}' \
            tokenizer.prompt_template=null \
+            save_adapter_weights_only={save_adapter_weights_only} \
        """.split()

        model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
