.. _qlora_finetune_label:

=============================
Finetuning Llama2 with QLoRA
=============================

In this tutorial, we'll learn about `QLoRA <https://arxiv.org/abs/2305.14314>`_, an enhancement on top of
`LoRA <https://arxiv.org/abs/2106.09685>`_ that maintains frozen model parameters in 4-bit quantized precision, thereby reducing memory usage. We'll
walk through how QLoRA can be used in torchtune to finetune a Llama2-7b model in < 10 GB of memory.
It is highly recommended to first develop an understanding of :ref:`LoRA finetuning in torchtune<lora_finetune_label>`.


.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn

       * How QLoRA saves memory over LoRA finetuning
       * An overview of QLoRA in torchtune
       * How to run a QLoRA finetune in torchtune

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

       * Be familiar with :ref:`torchtune<overview_label>`
       * Make sure to :ref:`install torchtune<install_label>`
       * Make sure you have downloaded the :ref:`Llama2-7B model weights<download_llama_label>`
       * Be familiar with :ref:`LoRA in torchtune<lora_finetune_label>`

What is QLoRA?
---------------

`QLoRA <https://arxiv.org/abs/2305.14314>`_ builds on top of `LoRA <https://arxiv.org/abs/2106.09685>`_ to enable further
memory savings. In LoRA, model parameters can be thought of as existing in two partitions: adapters, which are
low-rank matrices added to different layers of a neural network, and base model parameters, which are the parameters of
the original model. In vanilla LoRA-style training, both sets of parameters are held in the same precision (typically fp32 or bf16),
so the activations and intermediate gradients are computed in fp32/bf16 as well.

QLoRA further quantizes the base model parameters into a bespoke 4-bit NormalFloat (NF4) data type, resulting in 4-8x less parameter memory usage while
largely retaining model accuracy. As a result, the vast majority of parameters only take up 4 bits (as opposed to the 16 or 32 bits required by the bf16/fp32 dtypes). This
quantization is done using the method described in the original `QLoRA paper <https://arxiv.org/abs/2305.14314>`_. Adapter
parameters are still held in the original precision, and activations, gradients, and optimizer states are also kept in the higher precision to preserve
accuracy.
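
To put the 4-bit figure in perspective for a model like Llama2-7B with roughly 7 billion base parameters, here is a rough back-of-the-envelope estimate (illustrative only: it ignores quantization metadata, adapter weights, activations, and optimizer state):

.. code-block:: python

    # Rough parameter-memory estimate for ~7 billion base parameters (illustrative numbers).
    num_params = 7e9
    fp32_gb = num_params * 4 / 1e9   # 32 bits per parameter -> ~28 GB
    bf16_gb = num_params * 2 / 1e9   # 16 bits per parameter -> ~14 GB
    nf4_gb = num_params * 0.5 / 1e9  # 4 bits per parameter  -> ~3.5 GB (8x / 4x smaller)
    print(f"fp32: {fp32_gb:.1f} GB, bf16: {bf16_gb:.1f} GB, NF4: {nf4_gb:.1f} GB")
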

The `QLoRA paper <https://arxiv.org/abs/2305.14314>`_ introduces two key abstractions to decrease memory usage and avoid accuracy degradation: the bespoke 4-bit NormalFloat
type, and a double quantization method that quantizes the quantization parameters themselves to save even more memory. torchtune uses
the `NF4Tensor <https://github.com/pytorch-labs/ao/blob/b9beaf351e27133d189b57d6fa725b1a7824a457/torchao/dtypes/nf4tensor.py#L153>`_ abstraction from the `torchao <https://github.com/pytorch-labs/ao>`_ library to build the QLoRA components specified in the paper.
torchao is a PyTorch-native library that allows you to quantize and prune your models.
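
As a minimal sketch of these building blocks (the shapes and dtypes here are arbitrary), ``to_nf4`` converts a regular tensor into an ``NF4Tensor``, and ``linear_nf4`` (both are discussed further in the deep dive at the end of this tutorial) runs a linear forward pass against the quantized weight:

.. code-block:: python

    import torch
    from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

    weight = torch.randn(512, 512, dtype=torch.bfloat16)
    nf4_weight = to_nf4(weight)      # the weight is now stored in the 4-bit NF4 format
    print(type(nf4_weight))          # <class 'torchao.dtypes.nf4tensor.NF4Tensor'>

    x = torch.randn(2, 512, dtype=torch.bfloat16)
    out = linear_nf4(x, nf4_weight)  # de-quantizes under the hood to compute the matmul
    print(out.shape)                 # torch.Size([2, 512])
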

.. _qlora_core_highlevel:

Using QLoRA to save memory
----------------------------------------

In this section, we'll give an overview of how to apply QLoRA to a :class:`~torchtune.modules.peft.LoRALinear` layer in torchtune. For a deeper dive into the details of QLoRA in torchtune and the underlying abstractions,
please see the :ref:`QLoRA in torchtune deepdive <qlora_deepdive_label>` section of this tutorial.

A core idea of QLoRA is the distinction between compute and storage datatypes (dtypes). Specifically, QLoRA stores base model parameters in 4-bit precision (the storage dtype), and runs
computation in the original, higher precision (the compute dtype), generally either fp32 or bf16. As a first step, QLoRA needs to quantize these base model parameters to 4-bit precision
and store them.

To quantize a :class:`~torchtune.modules.peft.LoRALinear` layer in the QLoRA style, simply pass ``quantize_base=True`` to :class:`~torchtune.modules.peft.LoRALinear`. With this flag set,
the base model weights are quantized and backed by the ``NF4Tensor`` dtype. The forward pass is handled automatically as well: the ``NF4`` base weight is de-quantized to the compute precision
to compute the activation, but only the 4-bit parameter is stored for gradient computation in the backward pass, avoiding the extra memory usage that would be incurred by storing the
higher-precision compute dtype.

Here's an example of creating a quantized ``LoRALinear`` layer in comparison to an unquantized ``LoRALinear`` layer. As we can see, the quantized layer consumes
~8x less memory than the unquantized counterpart.

.. code-block:: python

    import torch
    from torchtune.modules.peft import LoRALinear

    torch.set_default_device("cuda")
    qlora_linear = LoRALinear(512, 512, rank=8, alpha=0.1, quantize_base=True)
    print(torch.cuda.memory_allocated())  # 177,152 bytes
    del qlora_linear
    torch.cuda.empty_cache()
    lora_linear = LoRALinear(512, 512, rank=8, alpha=0.1, quantize_base=False)
    print(torch.cuda.memory_allocated())  # 1,081,344 bytes
|
Using QLoRA in torchtune
----------------------------

We'll now cover how you can initialize a QLoRA-enabled Llama2-7b model, as well as some details around
checkpointing with QLoRA.

With torchtune, you can use a simple builder similar to the LoRA builder (:code:`lora_llama2_7b`) to apply QLoRA to Llama2 models. Here's a simple example of
initializing a Llama2-7b model with QLoRA enabled:

.. code-block:: python

    from torchtune.models.llama2 import qlora_llama2_7b

    qlora_model = qlora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

Under the hood, this will apply LoRA to the ``q_proj`` and ``v_proj`` matrices in all attention layers, and further quantize the base parameters
in these matrices to the ``NF4`` dtype. Note that quantization of base model parameters is only applied to layers that are configured to have
LoRA adapters added. For example, in this case, ``k_proj`` and ``output_proj`` in the attention layers don't have LoRA applied to them, so their
base model parameters are not quantized. We can see this by printing the base model parameter dtypes for a particular attention layer:

.. code-block:: python

    attn = qlora_model.layers[0].attn
    print(type(attn.q_proj.weight))  # <class 'torchao.dtypes.nf4tensor.NF4Tensor'>
    print(type(attn.k_proj.weight))  # <class 'torch.nn.parameter.Parameter'>
|
Next, there are a couple of details essential to checkpointing (i.e. saving the ``state_dict``) of QLoRA-enabled models.
To integrate well with torchtune's :ref:`checkpointing <checkpointing_label>`, we need to convert ``NF4Tensors`` back to their
original precision (generally fp32/bf16). This allows QLoRA-trained checkpoints to interoperate well with the rest of the ecosystem, within
torchtune and beyond (e.g. post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done
in a typical LoRA training flow.

To achieve this, when using torchtune's ``qlora_llama2_7b`` builder, we automatically register a hook, :code:`reparametrize_as_dtype_state_dict_post_hook`,
that runs after calling ``.state_dict()`` on the top-level model. This hook converts ``NF4Tensors`` back to their original precision, while also offloading the
converted tensors to the CPU. This offloading avoids a spike in peak GPU memory; without it, we would have to maintain an entire bf16/fp32 copy of the ``state_dict``
on GPU.
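
As a quick sanity check (a minimal sketch reusing the ``qlora_model`` from the snippet above; the key name simply follows the module path printed earlier), the converted tensors in the returned ``state_dict`` are no longer ``NF4Tensors`` and live on CPU:

.. code-block:: python

    sd = qlora_model.state_dict()
    # The post hook converted the NF4 base weight back to a regular higher-precision tensor...
    print(type(sd["layers.0.attn.q_proj.weight"]))   # no longer an NF4Tensor
    # ...and offloaded it to CPU so that a full-precision copy doesn't sit in GPU memory.
    print(sd["layers.0.attn.q_proj.weight"].device)  # cpu
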


Putting it all together: QLoRA finetune
-----------------------------------------

Putting it all together, we can now finetune a model using torchtune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/48626d19d2108f92c749411fbd5f0ff140023a25/recipes/lora_finetune.py>`_,
with a `QLoRA configuration <https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml>`_.

Make sure that you have first downloaded the Llama2 weights and tokenizer by following :ref:`these instructions<download_llama_label>`.
You can then run the following command to perform a QLoRA finetune of Llama2-7B using the Alpaca dataset on a single GPU.

.. code-block:: bash

    tune run lora_finetune_single_device --config llama2/7B_qlora_single_device

.. note::
    Make sure to correctly point to the location of your Llama2 weights and tokenizer. This can be done
    either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path`
    or by directly modifying the :code:`7B_qlora_single_device.yaml` file. See our :ref:`config_tutorial_label`
    for more details on how you can easily clone and modify torchtune configs.

By default, this run should log peak memory stats at model initialization time and every 100
iterations during training. To understand the memory savings QLoRA enables on top of LoRA, we can compare against an equivalent LoRA finetune, which
can be run as follows:

.. code-block:: bash

    tune run lora_finetune_single_device --config llama2/7B_lora_single_device

You should see the memory usage printed out during model initialization and training. An example log for LoRA model initialization is as follows:

.. code-block:: python

    Memory Stats after model init:
        GPU peak memory allocation: 13.96 GB
        GPU peak memory reserved: 13.98 GB
        GPU peak memory active: 13.96 GB

The following table compares QLoRA's peak memory reserved during model initialization and training against vanilla LoRA's.
We can see that QLoRA reduces peak memory by about 35% during model initialization, and by about 40% during model training:

=================  ================================  ==============================
Finetuning method  Peak memory reserved, model init  Peak memory reserved, training
=================  ================================  ==============================
LoRA               13.98 GB                          15.57 GB
QLoRA              9.13 GB                           9.29 GB
=================  ================================  ==============================

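
The ~35% and ~40% figures quoted above follow directly from the table entries; here's a quick check of the arithmetic:

.. code-block:: python

    # Peak memory reduction computed from the table above.
    init_savings = 1 - 9.13 / 13.98   # ~0.35 -> ~35% less peak memory at model init
    train_savings = 1 - 9.29 / 15.57  # ~0.40 -> ~40% less peak memory during training
    print(f"{init_savings:.1%}, {train_savings:.1%}")  # 34.7%, 40.3%
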
From the logs, one can see that the out-of-the-box training performance is quite slow, at only around one iteration per
second:

.. code-block:: python

    1|149|Loss: 0.9157477021217346: 1%| | 149/25880 [02:08<6:14:19, 1.15it/s]

To speed things up, we can leverage ``torch.compile`` to compile our model and run the compiled result. To work with
QLoRA training, a nightly build of PyTorch must be used. To update PyTorch to the latest nightly,
please see `the installation instructions <https://pytorch.org/get-started/locally/>`_. Once updated,
you can specify the compile flag as ``True`` via a config override:
|
.. code-block:: bash

    tune run lora_finetune_single_device --config llama2/7B_qlora_single_device compile=True

From the logs, we can see a speedup of roughly 3x (about 200%) after a few hundred iterations, once training has stabilized:

.. code-block:: python

    1|228|Loss: 0.8158286809921265: 1%| | 228/25880 [11:59<1:48:16, 3.95it/s]

A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below.

.. image:: /_static/img/qlora_exp.png

.. note::
    The above figure was generated with W&B. You can use torchtune's :class:`~torchtune.utils.metric_logging.WandBLogger`
    to generate similar loss curves, but you will need to install W&B and set up an account separately.
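
For example, assuming your config defines a ``metric_logger`` section (as the stock torchtune configs do), you can switch to W&B with a command-line override along these lines:

.. code-block:: bash

    tune run lora_finetune_single_device --config llama2/7B_qlora_single_device \
        metric_logger._component_=torchtune.utils.metric_logging.WandBLogger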
|
As an exercise, you can also try running some evaluation tasks or manually inspecting generations
output by your saved checkpoints (which can be found in :code:`output_dir`).
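
The exact recipe and config names for generation and evaluation vary across torchtune versions, so treat the following as a sketch and confirm the names with ``tune ls`` first:

.. code-block:: bash

    # List the recipes and configs shipped with your install.
    tune ls
    # A generation run typically looks something like this, pointed at the checkpoints in output_dir.
    tune run generate --config generation
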
|
In the final section, we'll take a deeper look at how a QLoRA component can be built from a LoRA component.
|
.. _qlora_deepdive_label:

Deep-dive: Building QLoRA from LoRA
-----------------------------------------

This section picks up from the :ref:`Using QLoRA to save memory<qlora_core_highlevel>` portion of the tutorial and looks at how quantization is done with the ``NF4Tensor`` and handled appropriately in the forward pass.

We'll begin with a minimal vanilla LoRA layer, taken from :ref:`the LoRA tutorial <lora_finetune_label>` and augmented to support quantization:
|
.. code-block:: python
    :emphasize-lines: 3, 13, 16, 21, 22, 41, 42, 43

    from torch import nn, Tensor
    import torch.nn.functional as F
    from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

    class LoRALinear(nn.Module):
        def __init__(
            self,
            in_dim: int,
            out_dim: int,
            rank: int,
            alpha: float,
            dropout: float,
            quantize_base: bool,
        ):
            super().__init__()
            self.quantize_base = quantize_base
            # These are the weights from the original pretrained model
            self.linear = nn.Linear(in_dim, out_dim, bias=False)
            self.linear_weight = self.linear.weight
            # Use torchao's to_nf4 API to quantize the base weight if needed.
            if quantize_base:
                self.linear_weight = to_nf4(self.linear_weight)
            # These are the new LoRA params. In general rank << in_dim, out_dim
            self.lora_a = nn.Linear(in_dim, rank, bias=False)
            self.lora_b = nn.Linear(rank, out_dim, bias=False)

            # Rank and alpha are commonly-tuned hyperparameters
            self.rank = rank
            self.alpha = alpha

            # Most implementations also include some dropout
            self.dropout = nn.Dropout(p=dropout)

            # The original params are frozen, and only LoRA params are trainable.
            self.linear.weight.requires_grad = False
            self.lora_a.weight.requires_grad = True
            self.lora_b.weight.requires_grad = True

        def forward(self, x: Tensor) -> Tensor:
            # frozen_out would be the output of the original model
            if self.quantize_base:
                # Call into torchao's linear_nf4 to run the linear forward pass w/the quantized weight.
                frozen_out = linear_nf4(x, self.linear_weight)
            else:
                frozen_out = F.linear(x, self.linear_weight)

            # lora_a projects inputs down to the much smaller self.rank,
            # then lora_b projects back up to the output dimension
            lora_out = self.lora_b(self.lora_a(self.dropout(x)))

            # Finally, scale by the alpha parameter (normalized by rank)
            # and add to the original model's outputs
            return frozen_out + (self.alpha / self.rank) * lora_out
|
As mentioned above, torchtune depends on the `torchao <https://github.com/pytorch-labs/ao>`_ library for some of the core components required for QLoRA. This includes the
``NF4Tensor``, as well as helpful utilities including ``to_nf4`` and ``linear_nf4``.

The key changes on top of the LoRA layer are the usage of the ``to_nf4`` and ``linear_nf4`` APIs.

``to_nf4`` accepts an unquantized (bf16 or fp32) tensor and produces an ``NF4`` representation of the weight. See the `implementation <https://github.com/pytorch-labs/ao/blob/c40358072f99b50cd7e58ec11e0e8d90440e3e25/torchao/dtypes/nf4tensor.py#L587>`_ of ``to_nf4`` for more details.
``linear_nf4`` handles the forward pass and autograd when running with quantized base model weights. It computes the forward pass as a regular
``F.linear`` between the incoming activation and the de-quantized weight. The quantized weight, rather than the de-quantized one, is saved for the backward pass, avoiding the extra
memory usage that would come from storing higher-precision tensors for gradient computation. See `linear_nf4 <https://github.com/pytorch-labs/ao/blob/main/torchao/dtypes/nf4tensor.py#L577>`_ for more details.
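
To round things out, here's a small smoke test of the layer defined above (a sketch only; the shapes and hyperparameters are arbitrary, and real memory savings are better measured as in the earlier ``LoRALinear`` example):

.. code-block:: python

    import torch

    # Build the quantized variant of the minimal layer defined above and run a forward pass.
    layer = LoRALinear(in_dim=512, out_dim=512, rank=8, alpha=0.1, dropout=0.0, quantize_base=True)
    x = torch.randn(2, 512)
    y = layer(x)
    print(type(layer.linear_weight))  # the NF4-quantized base weight
    print(y.shape)                    # torch.Size([2, 512])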