|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
| 7 | +import mmap |
| 8 | +import sys |
| 9 | +from collections import OrderedDict |
| 10 | +from functools import partial |
7 | 11 | from typing import Any, Dict, Tuple |
8 | 12 |
|
9 | 13 | import torch |
10 | 14 |
|
11 | 15 | import torch.nn as nn |
| 16 | +from torch._subclasses.fake_tensor import FakeTensorConverter, FakeTensorMode |
12 | 17 | from torchao.dtypes.nf4tensor import NF4Tensor |
13 | 18 |
|
| 19 | +_use_low_cpu_ram: bool = False |
| 20 | + |
14 | 21 |
|
15 | 22 | def reparametrize_as_dtype_state_dict_post_hook( |
16 | 23 | model: nn.Module, |
@@ -48,3 +55,111 @@ def reparametrize_as_dtype_state_dict_post_hook( |
48 | 55 | state_dict[k] = v.to(dtype) |
49 | 56 | if offload_to_cpu: |
50 | 57 | state_dict[k] = state_dict[k].cpu() |
| 58 | + |
| 59 | + |
| 60 | +def _low_ram_reparametrize_as_dtype_state_dict_post_hook( |
| 61 | + model: nn.Module, |
| 62 | + state_dict: Dict[str, Any], |
| 63 | + *args: Tuple[Any, ...], |
| 64 | + dtype: torch.dtype = torch.bfloat16, |
| 65 | + offload_to_cpu: bool = True, |
| 66 | + **kwargs: Dict[Any, Any], |
| 67 | +): |
| 68 | + """ |
| 69 | + A state_dict hook that replaces NF4 tensors with their restored |
| 70 | + higher-precision weight and optionally offloads the restored weight to CPU. |
| 71 | + Use this hook to avoid increased peak GPU memory usage during checkpoint |
| 72 | + save when training with QLoRA. |
| 73 | +
|
| 74 | + This hook is similar to ``reparametrize_as_dtype_state_dict_post_hook`` but uses |
| 75 | + FakeTensor and mmap(2) to avoid CPU OOM on colab. |
| 76 | +
|
| 77 | + This function is meant to be used with PyTorch's ``nn.Module._register_state_dict_hook``, i.e. |
| 78 | +
|
| 79 | + >>> m = MyModule() |
| 80 | + >>> m._register_state_dict_hook(reparametrize_as_dtype_state_dict_post_hook) |
| 81 | +
|
| 82 | + If the hook is registered per the above process, this hook will be called _after_ the module's |
| 83 | + ``state_dict`` method is called. The hook will replace all ``NF4Tensor`` instances by unquantizing |
| 84 | + them to the original dtype, and optionally offload the restored weight to CPU. |
| 85 | +
|
| 86 | + Args: |
| 87 | + model (nn.Module): the model to take ``state_dict()`` on |
| 88 | + state_dict (Dict[str, Any]): the state dict to modify |
| 89 | + *args (Tuple[Any, ...]): Unused args passed when running this as a state_dict hook. |
| 90 | + dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``. |
| 91 | + offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``. |
| 92 | + **kwargs (Dict[Any, Any]): Unused keyword args passed when running this as a state_dict hook. |
| 93 | + """ |
| 94 | + # Create a state dict of FakeTensors that matches the state_dict |
| 95 | + mode = FakeTensorMode() |
| 96 | + converter = FakeTensorConverter() |
| 97 | + fake_state_dict = OrderedDict() |
| 98 | + for k, v in state_dict.items(): |
| 99 | + if isinstance(v, NF4Tensor): |
| 100 | + fake_state_dict[k] = converter.from_real_tensor(mode, v).to(dtype) |
| 101 | + else: |
| 102 | + fake_state_dict[k] = converter.from_real_tensor(mode, v) |
| 103 | + |
| 104 | + if offload_to_cpu: |
| 105 | + fake_state_dict[k] = fake_state_dict[k].cpu() |
| 106 | + |
| 107 | + # Create a state_dict on disk with space reserved for storage bytes |
| 108 | + # Then load with mmap and MAP_SHARED (can writeback to disk file) |
| 109 | + dest_state_dict_path = "/tmp/fake_state_dict.pt" |
| 110 | + with torch.serialization.skip_data(materialize_fake_tensors=True): |
| 111 | + torch.save(fake_state_dict, dest_state_dict_path) |
| 112 | + with torch.serialization.set_default_mmap_options(mmap.MAP_SHARED): |
| 113 | + dest_state_dict = torch.load(dest_state_dict_path, mmap=True, weights_only=True) |
| 114 | + |
| 115 | + # Do D2H and upcast one by one and since dest_state_dict is backed by mmap --> won't OOM |
| 116 | + # even when there is no swap space (e.g. colab) |
| 117 | + for k in state_dict.keys(): |
| 118 | + if isinstance(state_dict[k], NF4Tensor): |
| 119 | + dest_state_dict[k].copy_(state_dict[k].to(dtype)) |
| 120 | + else: |
| 121 | + dest_state_dict[k].copy_(state_dict[k]) |
| 122 | + |
| 123 | + # In place update original state_dict object. Although the private state dict |
| 124 | + # post hook supports out of place behavior, the semantic actually buggy. We eventually want |
| 125 | + # to use the public state_dict post hook which does not support out of place behavior. |
| 126 | + for k in state_dict.keys(): |
| 127 | + state_dict[k] = dest_state_dict[k] |
| 128 | + |
| 129 | + |
| 130 | +def _register_reparametrize_state_dict_hooks( |
| 131 | + module: nn.Module, |
| 132 | + dtype: torch.dtype = torch.bfloat16, |
| 133 | + offload_to_cpu: bool = True, |
| 134 | +): |
| 135 | + """ |
| 136 | + Register the reparametrize state dict hooks to the module and its submodules. |
| 137 | +
|
| 138 | + This function is a wrapper that is meant to toggle between the low_cpu_ram |
| 139 | + and regular versions of the ``reparametrize_as_dtype`` state dict hooks. |
| 140 | +
|
| 141 | + Args: |
| 142 | + module (nn.Module): the module to register the hooks to. |
| 143 | + dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``. |
| 144 | + offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``. |
| 145 | +
|
| 146 | + Raises: |
| 147 | + RuntimeError: If the low RAM reparametrize hook is used on Windows or an incompatible torch version. |
| 148 | + """ |
| 149 | + if _use_low_cpu_ram: |
| 150 | + if torch.__version__ < "2.5.0.dev20240906": |
| 151 | + raise RuntimeError( |
| 152 | + "Low RAM reparametrize_as_dtype_state_dict_post_hook requires PyTorch 2.5.0.dev20240906 or later." |
| 153 | + ) |
| 154 | + elif sys.platform == "win32": |
| 155 | + # mmap.MAP_SHARED is not supported on Windows but this change targets colab. |
| 156 | + raise RuntimeError( |
| 157 | + "Low RAM reparametrize_as_dtype_state_dict_post_hook is not supported on Windows." |
| 158 | + ) |
| 159 | + else: |
| 160 | + hook = _low_ram_reparametrize_as_dtype_state_dict_post_hook |
| 161 | + else: |
| 162 | + hook = reparametrize_as_dtype_state_dict_post_hook |
| 163 | + module._register_state_dict_hook( |
| 164 | + partial(hook, dtype=dtype, offload_to_cpu=offload_to_cpu) |
| 165 | + ) |
0 commit comments