Commit 515efbe

Prevent OOM during checkpoint save on colab for llama3-8b qlora recipe (#1315)
1 parent 66590b4 commit 515efbe

5 files changed: +153 −15 lines changed

recipes/configs/llama3/8B_qlora_single_device.yaml

Lines changed: 3 additions & 0 deletions

@@ -104,3 +104,6 @@ profiler:
   warmup_steps: 5
   active_steps: 2
   num_cycles: 1
+
+# For colab use True
+low_cpu_ram: False
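
For context, a minimal sketch (assumed usage, not part of this commit) of flipping the new flag from Python, e.g. in a Colab cell, before handing the config to the recipe; the same effect should be achievable with a low_cpu_ram=True override on the tune run command line:

# Minimal sketch (assumed usage): enable the flag programmatically before
# passing the config to the recipe.
from omegaconf import OmegaConf

cfg = OmegaConf.load("recipes/configs/llama3/8B_qlora_single_device.yaml")
cfg.low_cpu_ram = True  # "For colab use True", per the comment added above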

recipes/lora_finetune_single_device.py

Lines changed: 20 additions & 5 deletions

@@ -12,6 +12,7 @@
 from warnings import warn
 
 import torch
+import torchtune.modules.common_utils as common_utils
 from omegaconf import DictConfig, ListConfig
 
 from torch import nn
@@ -213,6 +214,10 @@ def setup(self, cfg: DictConfig) -> None:
         self._compile = cfg.compile
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
 
+        # hack to toggle to the low cpu ram version of the reparametrize_as_dtype
+        # hook based on the config.
+        common_utils._use_low_cpu_ram = cfg.get("low_cpu_ram", False)
+
         # set up model
         self._model = self._setup_model(
             cfg_model=cfg.model,
@@ -525,6 +530,14 @@ def save_checkpoint(self, epoch: int) -> None:
         # Move to CPU to avoid a copy on GPU
         state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
 
+        # Construct the adapter weights
+        # Do this using state_dict to avoid running the upcast and D2H in the state_dict post hook twice
+        # Must come before get_merged_lora_ckpt, which removes the LoRA keys
+        adapter_key_filter = lambda x: x in self.adapter_params
+        adapter_state_dict = {
+            k: v for k, v in state_dict.items() if adapter_key_filter(k)
+        }
+
         # Construct the full state dict with LoRA weights merged into base LLM weights
         merged_state_dict = get_merged_lora_ckpt(
             state_dict,
@@ -533,11 +546,6 @@ def save_checkpoint(self, epoch: int) -> None:
         )
         ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
 
-        # Construct the adapter weights
-        adapter_key_filter = lambda x: x in self.adapter_params
-        adapter_state_dict = {
-            k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k)
-        }
         ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
         adapter_config = {
             "r": self._lora_rank,
@@ -698,7 +706,14 @@ def train(self) -> None:
                 prof.step()
 
             self.epochs_run += 1
+            start_save_checkpoint = time.perf_counter()
+            log.info("Starting checkpoint save...")
             self.save_checkpoint(epoch=curr_epoch)
+            log.info(
+                "Checkpoint saved in {:.2f} seconds.".format(
+                    time.perf_counter() - start_save_checkpoint
+                )
+            )
 
     def cleanup(self) -> None:
         self._metric_logger.close()
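
The adapter weights are now filtered out of the single state_dict snapshot taken at the top of save_checkpoint, instead of calling self._model.state_dict() a second time, because every call re-runs the registered reparametrize post hook (the NF4 upcast plus the device-to-host copy). A minimal sketch of that behavior, with a hypothetical counting hook standing in for the real one:

# Minimal sketch (hypothetical hook, not from this commit): each state_dict()
# call re-runs any registered post hook, so the recipe takes one snapshot and
# filters the adapter keys out of it rather than calling state_dict() twice.
import torch.nn as nn

calls = {"n": 0}

def counting_hook(module, state_dict, *args, **kwargs):
    calls["n"] += 1  # the real hook upcasts NF4 weights and offloads them to CPU

m = nn.Linear(4, 4)
m._register_state_dict_hook(counting_hook)

snapshot = m.state_dict()                          # hook runs once
adapter_sd = {k: v for k, v in snapshot.items()}   # reusing the snapshot is free
assert calls["n"] == 1
m.state_dict()                                     # a second call pays the hook cost again
assert calls["n"] == 2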

torchtune/models/llama3/_component_builders.py

Lines changed: 2 additions & 4 deletions

@@ -22,7 +22,7 @@
     TransformerSelfAttentionLayer,
 )
 
-from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
+from torchtune.modules.common_utils import _register_reparametrize_state_dict_hooks
 
 from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
 
@@ -256,9 +256,7 @@ def lora_llama3(
     if quantize_base:
         # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly
        # so as to not increase peak memory
-        model._register_state_dict_hook(
-            partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True)
-        )
+        _register_reparametrize_state_dict_hooks(model)
 
     return model
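
With this change, hook selection is deferred to _register_reparametrize_state_dict_hooks, which reads the module-level _use_low_cpu_ram flag at model-build time. A rough sketch of the intended flow (the lora_llama3_8b builder and its arguments are abbreviated here and are an assumption about torchtune's public API, not part of this commit):

# Rough sketch (abbreviated arguments): the flag must be set before the model is
# built, because the hook is registered inside the component builder.
import torchtune.modules.common_utils as common_utils
from torchtune.models.llama3 import lora_llama3_8b

common_utils._use_low_cpu_ram = True  # what the recipe now sets from cfg.low_cpu_ram
model = lora_llama3_8b(
    lora_attn_modules=["q_proj", "v_proj"],
    quantize_base=True,  # QLoRA path: NF4 base weights, so the hook gets registered
)
# _register_reparametrize_state_dict_hooks(model) chose the low-RAM hook above.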

torchtune/modules/common_utils.py

Lines changed: 115 additions & 0 deletions

@@ -4,13 +4,20 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import mmap
+import sys
+from collections import OrderedDict
+from functools import partial
 from typing import Any, Dict, Tuple
 
 import torch
 
 import torch.nn as nn
+from torch._subclasses.fake_tensor import FakeTensorConverter, FakeTensorMode
 from torchao.dtypes.nf4tensor import NF4Tensor
 
+_use_low_cpu_ram: bool = False
+
 
 def reparametrize_as_dtype_state_dict_post_hook(
     model: nn.Module,
@@ -48,3 +55,111 @@ def reparametrize_as_dtype_state_dict_post_hook(
             state_dict[k] = v.to(dtype)
         if offload_to_cpu:
             state_dict[k] = state_dict[k].cpu()
+
+
+def _low_ram_reparametrize_as_dtype_state_dict_post_hook(
+    model: nn.Module,
+    state_dict: Dict[str, Any],
+    *args: Tuple[Any, ...],
+    dtype: torch.dtype = torch.bfloat16,
+    offload_to_cpu: bool = True,
+    **kwargs: Dict[Any, Any],
+):
+    """
+    A state_dict hook that replaces NF4 tensors with their restored
+    higher-precision weight and optionally offloads the restored weight to CPU.
+    Use this hook to avoid increased peak GPU memory usage during checkpoint
+    save when training with QLoRA.
+
+    This hook is similar to ``reparametrize_as_dtype_state_dict_post_hook`` but uses
+    FakeTensor and mmap(2) to avoid CPU OOM on colab.
+
+    This function is meant to be used with PyTorch's ``nn.Module._register_state_dict_hook``, i.e.
+
+    >>> m = MyModule()
+    >>> m._register_state_dict_hook(_low_ram_reparametrize_as_dtype_state_dict_post_hook)
+
+    If the hook is registered per the above process, this hook will be called _after_ the module's
+    ``state_dict`` method is called. The hook will replace all ``NF4Tensor`` instances by unquantizing
+    them to the original dtype, and optionally offload the restored weight to CPU.
+
+    Args:
+        model (nn.Module): the model to take ``state_dict()`` on
+        state_dict (Dict[str, Any]): the state dict to modify
+        *args (Tuple[Any, ...]): Unused args passed when running this as a state_dict hook.
+        dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``.
+        offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``.
+        **kwargs (Dict[Any, Any]): Unused keyword args passed when running this as a state_dict hook.
+    """
+    # Create a state dict of FakeTensors that matches the state_dict
+    mode = FakeTensorMode()
+    converter = FakeTensorConverter()
+    fake_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        if isinstance(v, NF4Tensor):
+            fake_state_dict[k] = converter.from_real_tensor(mode, v).to(dtype)
+        else:
+            fake_state_dict[k] = converter.from_real_tensor(mode, v)
+
+        if offload_to_cpu:
+            fake_state_dict[k] = fake_state_dict[k].cpu()
+
+    # Create a state_dict on disk with space reserved for storage bytes,
+    # then load it with mmap and MAP_SHARED (so writes can be written back to the disk file)
+    dest_state_dict_path = "/tmp/fake_state_dict.pt"
+    with torch.serialization.skip_data(materialize_fake_tensors=True):
+        torch.save(fake_state_dict, dest_state_dict_path)
+    with torch.serialization.set_default_mmap_options(mmap.MAP_SHARED):
+        dest_state_dict = torch.load(dest_state_dict_path, mmap=True, weights_only=True)
+
+    # Do the D2H copy and upcast one by one; since dest_state_dict is backed by mmap, this won't OOM
+    # even when there is no swap space (e.g. on colab)
+    for k in state_dict.keys():
+        if isinstance(state_dict[k], NF4Tensor):
+            dest_state_dict[k].copy_(state_dict[k].to(dtype))
+        else:
+            dest_state_dict[k].copy_(state_dict[k])
+
+    # Update the original state_dict object in place. Although the private state dict
+    # post hook supports out-of-place behavior, its semantics are buggy. We eventually want
+    # to use the public state_dict post hook, which does not support out-of-place behavior.
+    for k in state_dict.keys():
+        state_dict[k] = dest_state_dict[k]
+
+
+def _register_reparametrize_state_dict_hooks(
+    module: nn.Module,
+    dtype: torch.dtype = torch.bfloat16,
+    offload_to_cpu: bool = True,
+):
+    """
+    Register the reparametrize state dict hooks to the module and its submodules.
+
+    This function is a wrapper that is meant to toggle between the low_cpu_ram
+    and regular versions of the ``reparametrize_as_dtype`` state dict hooks.
+
+    Args:
+        module (nn.Module): the module to register the hooks to.
+        dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``.
+        offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``.
+
+    Raises:
+        RuntimeError: If the low RAM reparametrize hook is used on Windows or an incompatible torch version.
+    """
+    if _use_low_cpu_ram:
+        if torch.__version__ < "2.5.0.dev20240906":
+            raise RuntimeError(
+                "Low RAM reparametrize_as_dtype_state_dict_post_hook requires PyTorch 2.5.0.dev20240906 or later."
+            )
+        elif sys.platform == "win32":
+            # mmap.MAP_SHARED is not supported on Windows, but this change targets colab.
+            raise RuntimeError(
+                "Low RAM reparametrize_as_dtype_state_dict_post_hook is not supported on Windows."
+            )
+        else:
+            hook = _low_ram_reparametrize_as_dtype_state_dict_post_hook
+    else:
+        hook = reparametrize_as_dtype_state_dict_post_hook
+    module._register_state_dict_hook(
+        partial(hook, dtype=dtype, offload_to_cpu=offload_to_cpu)
+    )
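
The core trick: serialize a FakeTensor copy of the state dict so the checkpoint file has its storage bytes reserved but never materialized in RAM, then load it back memory-mapped with MAP_SHARED so that filling it tensor by tensor only dirties file-backed pages the OS can write out, rather than allocating anonymous memory that would OOM without swap. A standalone sketch under the same assumptions the wrapper checks for (PyTorch >= 2.5.0.dev20240906, non-Windows); the tensors and path below are illustrative:

# Standalone sketch of the mmap-backed copy (illustrative tensors and path;
# assumes a torch nightly new enough to provide torch.serialization.skip_data).
import mmap
from collections import OrderedDict

import torch
from torch._subclasses.fake_tensor import FakeTensorConverter, FakeTensorMode

real_sd = OrderedDict(w=torch.randn(1024, 1024), b=torch.randn(1024))

# 1) FakeTensor copy: same shapes/dtypes/devices, but no storage is allocated.
mode, converter = FakeTensorMode(), FakeTensorConverter()
fake_sd = OrderedDict((k, converter.from_real_tensor(mode, v)) for k, v in real_sd.items())

# 2) Save with storage bytes reserved but not written, then mmap it back with
#    MAP_SHARED so writes land in file-backed pages rather than anonymous RAM.
path = "/tmp/mmap_sketch.pt"
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(fake_sd, path)
with torch.serialization.set_default_mmap_options(mmap.MAP_SHARED):
    disk_sd = torch.load(path, mmap=True, weights_only=True)

# 3) Fill it one tensor at a time; peak resident memory stays around one tensor.
for k, v in real_sd.items():
    disk_sd[k].copy_(v)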

torchtune/modules/peft/_utils.py

Lines changed: 13 additions & 6 deletions

@@ -243,21 +243,28 @@ def get_merged_lora_ckpt(
     for module in lora_modules:
         lora_a_weight = state_dict[f"{module}.lora_a.weight"]
         lora_b_weight = state_dict[f"{module}.lora_b.weight"]
-        base_weight = state_dict[f"{module}.weight"].to(lora_a_weight.dtype)
         lora_magnitude = state_dict.get(f"{module}.magnitude", None)
 
-        lora_weight = (alpha / rank) * lora_b_weight @ lora_a_weight
-        merged_weight = base_weight + lora_weight
+        # If magnitude is present, calculate merged DoRA weight
         if lora_magnitude is not None:
+            base_weight = state_dict[f"{module}.weight"].to(lora_a_weight.dtype)
+
+            lora_weight = (alpha / rank) * lora_b_weight @ lora_a_weight
+            merged_weight = base_weight + lora_weight
             weight_norm = torch.linalg.norm(base_weight + lora_weight, dim=1)
             mag_norm_scale = (lora_magnitude / weight_norm).view(-1, 1)
             merged_weight *= mag_norm_scale
-        state_dict[f"{module}.weight"] = merged_weight
+            state_dict[f"{module}.weight"] = merged_weight
+            del state_dict[f"{module}.magnitude"]
+
+        # Otherwise it is just vanilla LoRA
+        else:
+            state_dict[f"{module}.weight"] += (
+                (alpha / rank) * lora_b_weight @ lora_a_weight
+            )
 
         del state_dict[f"{module}.lora_a.weight"]
         del state_dict[f"{module}.lora_b.weight"]
-        if lora_magnitude is not None:
-            del state_dict[f"{module}.magnitude"]
 
     return state_dict
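
On the vanilla-LoRA path the merge W <- W + (alpha / rank) * B @ A is now applied in place, avoiding the upcast copy of the base weight that the old code always created; the DoRA path still materializes the merged weight because it needs its row norms for the magnitude rescale. A toy sketch of the two paths on random tensors (shapes and values are illustrative, not from this commit):

# Toy sketch of the two merge paths (illustrative shapes).
import torch

rank, alpha, out_dim, in_dim = 8, 16, 32, 64
base = torch.randn(out_dim, in_dim)
lora_a = torch.randn(rank, in_dim)
lora_b = torch.randn(out_dim, rank)

# Vanilla LoRA: in-place update of the base weight, no extra full-size copy.
merged_lora = base.clone()
merged_lora += (alpha / rank) * lora_b @ lora_a

# DoRA: the merged weight is materialized because its row norms are needed.
magnitude = torch.randn(out_dim).abs()
merged_dora = base + (alpha / rank) * lora_b @ lora_a
weight_norm = torch.linalg.norm(merged_dora, dim=1)
merged_dora *= (magnitude / weight_norm).view(-1, 1)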
