Commit 0914d5c

QLoRA tutorial (#693)
1 parent 5b0dc57 commit 0914d5c

File tree

4 files changed: +286 lines, -0 lines
docs/source/_static/img/qlora_exp.png: 90.1 KB (binary image file, no text diff)

docs/source/index.rst

Lines changed: 8 additions & 0 deletions
@@ -58,6 +58,13 @@ torchtune tutorials.
    :link: tutorials/lora_finetune.html
    :tags: finetuning,llama2,lora

+.. customcarditem::
+   :header: Understanding QLoRA in TorchTune
+   :card_description: Using QLoRA to quantize base model weights and maximize memory savings
+   :image: _static/img/generic-pytorch-logo.png
+   :link: tutorials/qlora_finetune.html
+   :tags: finetuning,llama2
+
 .. customcarditem::
    :header: End-to-End Workflow with torchtune
    :card_description: Train, Evaluate, Quantize and then Generate with your LLM.
@@ -91,6 +98,7 @@ torchtune tutorials.
    :hidden:

    tutorials/lora_finetune
+   tutorials/qlora_finetune
    tutorials/first_finetune_tutorial
    tutorials/e2e_flow

docs/source/overview.rst

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ Excited? To get started, checkout some of our tutorials, including:

 - our :ref:`full finetuning tutorial <finetune_llama_label>` to get started and finetune your first LLM using torchtune.
 - our :ref:`LoRA tutorial <lora_finetune_label>` to learn about parameter-efficient finetuning with torchtune.
+- our :ref:`QLoRA tutorial <qlora_finetune_label>` to attain maximal memory efficiency with torchtune.

 Key Concepts
 ------------

docs/source/tutorials/qlora_finetune.rst

Lines changed: 277 additions & 0 deletions
@@ -0,0 +1,277 @@
.. _qlora_finetune_label:

=============================
Finetuning Llama2 with QLoRA
=============================

In this tutorial, we'll learn about `QLoRA <https://arxiv.org/abs/2305.14314>`_, an enhancement on top of
`LoRA <https://arxiv.org/abs/2106.09685>`_ that maintains frozen model parameters in 4-bit quantized precision, thereby reducing memory usage. We'll
walk through how QLoRA can be utilized within torchtune to finetune a Llama2-7b model in < 10 GB of memory.
It is highly recommended to first develop an understanding of :ref:`LoRA finetuning in torchtune<lora_finetune_label>`.


.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn

      * How QLoRA saves memory over LoRA finetuning
      * An overview of QLoRA in torchtune
      * How to run a QLoRA finetune in torchtune

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

      * Be familiar with :ref:`torchtune<overview_label>`
      * Make sure to :ref:`install torchtune<install_label>`
      * Make sure you have downloaded the :ref:`Llama2-7B model weights<download_llama_label>`
      * Be familiar with :ref:`LoRA in torchtune<lora_finetune_label>`

What is QLoRA?
---------------

`QLoRA <https://arxiv.org/abs/2305.14314>`_ builds on top of `LoRA <https://arxiv.org/abs/2106.09685>`_ to enable further
memory savings. In LoRA, model parameters can be thought of as existing in two partitions: adapters, which are
low-rank matrices added to different layers of a neural network, and base model parameters, which are part of
the original model. In vanilla LoRA-style training, both sets of parameters are held in the same precision (typically fp32 or bf16), and
therefore activations and intermediate gradients are also computed in fp32/bf16.

QLoRA further quantizes the base model parameters into a bespoke 4-bit NormalFloat (NF4) data type, resulting in 4-8x less parameter memory usage while
largely retaining model accuracy. As a result, the vast majority of parameters only take up 4 bits (as opposed to the 16 or 32 bits used by the bf16/fp32 dtypes). This
quantization is done through the method highlighted in the original `QLoRA paper <https://arxiv.org/abs/2305.14314>`_. Adapter
parameters are still held in the original precision, and activations, gradients, and optimizer states are kept in the higher precision to preserve
accuracy.

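To get a rough sense of scale, here is a back-of-the-envelope estimate of base parameter memory for a 7B-parameter model. It ignores quantization metadata such as per-block scales (and the comparatively tiny adapter weights), so treat it as an approximation rather than a measurement:

.. code-block:: python

  # Back-of-the-envelope parameter-memory estimate for a 7B-parameter base model.
  # This ignores quantization metadata (e.g. per-block scales) and the small LoRA adapters.
  num_params = 7e9
  bf16_gb = num_params * 2 / 1e9    # 16 bits = 2 bytes per parameter -> ~14 GB
  nf4_gb = num_params * 0.5 / 1e9   # 4 bits = 0.5 bytes per parameter -> ~3.5 GB
  print(f"bf16 base weights: ~{bf16_gb:.1f} GB, NF4 base weights: ~{nf4_gb:.1f} GB")
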
The `QLoRA paper <https://arxiv.org/abs/2305.14314>`_ introduces two key abstractions to decrease memory usage and avoid accuracy degradation: the bespoke 4-bit NormalFloat
type, and a double quantization method that quantizes the quantization parameters themselves to save even more memory. torchtune uses
the `NF4Tensor <https://github.com/pytorch-labs/ao/blob/b9beaf351e27133d189b57d6fa725b1a7824a457/torchao/dtypes/nf4tensor.py#L153>`_ abstraction from the `torchao library <https://github.com/pytorch-labs/ao>`_ to build QLoRA components as specified in the paper.
The `torchao library <https://github.com/pytorch-labs/ao>`_ is a PyTorch-native library that lets you quantize and prune your models.


.. _qlora_core_highlevel:

Using QLoRA to save memory
----------------------------------------

In this section, we'll go over how to apply QLoRA to a :class:`~torchtune.modules.peft.LoRALinear` layer in torchtune. For a deeper dive into QLoRA in torchtune and its underlying abstractions,
please see the :ref:`QLoRA in torchtune deepdive <qlora_deepdive_label>` section of this tutorial.

A core idea of QLoRA is the distinction between compute and storage datatypes (dtypes). Specifically, QLoRA stores base model parameters in 4-bit precision (the storage dtype), and runs
computation in a higher precision (the compute dtype), generally either fp32 or bf16. As a first step, QLoRA needs to quantize these base model parameters to 4-bit precision
and store them.

To quantize a :class:`~torchtune.modules.peft.LoRALinear` layer in the QLoRA style, simply pass ``quantize_base=True`` when constructing the :class:`~torchtune.modules.peft.LoRALinear` layer. This flag
will result in base model weights being quantized and backed by the ``NF4Tensor`` dtype. Forward passes are also automatically handled to work with the ``NF4Tensor`` dtype:
the ``NF4`` base weight is de-quantized to the compute precision, the activation is computed, and only the 4-bit parameter is stored for gradient computation
in the backward pass, avoiding the extra memory usage that storing the higher-precision compute dtype would incur.

Here's an example of creating a quantized ``LoRALinear`` layer in comparison to an unquantized ``LoRALinear`` layer. As we can see, the quantized layer consumes
roughly 6x less memory than the unquantized counterpart.

.. code-block:: python

  import torch
  from torchtune.modules.peft import LoRALinear

  torch.set_default_device("cuda")
  qlora_linear = LoRALinear(512, 512, rank=8, alpha=0.1, quantize_base=True)
  print(torch.cuda.memory_allocated())  # 177,152 bytes
  del qlora_linear
  torch.cuda.empty_cache()
  lora_linear = LoRALinear(512, 512, rank=8, alpha=0.1, quantize_base=False)
  print(torch.cuda.memory_allocated())  # 1,081,344 bytes


Using QLoRA in torchtune
----------------------------

We'll now cover how you can initialize a QLoRA-enabled Llama2-7b model as well as some details around
checkpointing with QLoRA.

With torchtune, you can use a simple builder similar to the LoRA builder (:code:`lora_llama2_7b`) to apply QLoRA to Llama2 models. Here's a simple example of
initializing a Llama2-7b model with QLoRA enabled:

.. code-block:: python

  from torchtune.models.llama2 import qlora_llama2_7b

  qlora_model = qlora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

Under the hood, this will apply LoRA to the ``q_proj`` and ``v_proj`` matrices in all attention layers, and further quantize the base parameters
in these matrices to the ``NF4`` dtype. Note that quantization of base model parameters is only applied to layers that are configured to have
LoRA adapters added. For example, in this case, ``k_proj`` and ``output_proj`` in the attention layers don't have LoRA applied to them, so their
base model parameters are not quantized. We can see this by printing the base model parameter dtypes for a particular attention layer:

.. code-block:: python

  attn = qlora_model.layers[0].attn
  print(type(attn.q_proj.weight))  # <class 'torchao.dtypes.nf4tensor.NF4Tensor'>
  print(type(attn.k_proj.weight))  # <class 'torch.nn.parameter.Parameter'>


Next, there are a couple of details essential to checkpointing (i.e. saving and loading the ``state_dict``) of QLoRA-enabled models.
To integrate well with torchtune's :ref:`checkpointing <checkpointing_label>`, we need to convert ``NF4Tensors`` back to their
original precision (generally fp32/bf16). This allows QLoRA-trained checkpoints to interoperate well with the rest of the ecosystem, within
torchtune and beyond (e.g. post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done
in a typical LoRA training flow.

To achieve this, when using torchtune's ``qlora_llama2_7b`` builder, we automatically register a hook, :code:`reparametrize_as_dtype_state_dict_post_hook`,
that runs after calling ``.state_dict()`` on the top-level model. This hook converts ``NF4Tensors`` back to their original precision, while also offloading the
converted tensors to the CPU. This offloading avoids a spike in peak GPU memory; without it, we would have to maintain an entire bf16/fp32 copy of the ``state_dict``
on GPU.

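As a quick sanity check, here's a minimal sketch of what this looks like in practice. The exact parameter key used below is an assumption for illustration; the precise key names depend on the model definition:

.. code-block:: python

  # Continuing from the qlora_model created above.
  sd = qlora_model.state_dict()
  key = "layers.0.attn.q_proj.weight"  # hypothetical key for a LoRA-adapted projection
  # On the live model, the base weight is an NF4Tensor on GPU...
  print(type(qlora_model.layers[0].attn.q_proj.weight))
  # ...but in the state_dict, the post hook has converted it back to a regular
  # higher-precision tensor and offloaded it to CPU.
  print(type(sd[key]), sd[key].device)
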
Putting it all together: QLoRA finetune
-----------------------------------------

Putting it all together, we can now finetune a model using torchtune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/48626d19d2108f92c749411fbd5f0ff140023a25/recipes/lora_finetune.py>`_,
with a `QLoRA configuration <https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml>`_.

Make sure that you have first downloaded the Llama2 weights and tokenizer by following :ref:`these instructions<download_llama_label>`.
You can then run the following command to perform a QLoRA finetune of Llama2-7B using the Alpaca dataset on a single GPU.

.. code-block:: bash

  tune run lora_finetune_single_device --config llama2/7B_qlora_single_device

.. note::
  Make sure to correctly point to the location of your Llama2 weights and tokenizer. This can be done
  either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path`
  or by directly modifying the :code:`7B_qlora_single_device.yaml` file. See our :ref:`config_tutorial_label`
  for more details on how you can easily clone and modify torchtune configs.

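For example, a full invocation with explicit overrides might look like the following, where the bracketed paths are placeholders for wherever you stored the downloaded checkpoint and tokenizer files:

.. code-block:: bash

  tune run lora_finetune_single_device --config llama2/7B_qlora_single_device \
    checkpointer.checkpoint_files=[<path/to/llama2-7b-checkpoint>] \
    tokenizer_checkpoint=<path/to/tokenizer.model>
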
By default, this run should log peak memory stats at model initialization time and every 100
iterations during training. To understand the memory savings QLoRA provides on top of LoRA, we can run an equivalent LoRA finetune
as a baseline:

.. code-block:: bash

  tune run lora_finetune_single_device --config llama2/7B_lora_single_device

You should see the memory usage printed out during model initialization and training. An example log for LoRA model initialization is as follows:

.. code-block:: python

  Memory Stats after model init::
      GPU peak memory allocation: 13.96 GB
      GPU peak memory reserved: 13.98 GB
      GPU peak memory active: 13.96 GB

The following table compares QLoRA's peak memory reserved during model initialization and training against vanilla LoRA's.
We can see that QLoRA reduces peak memory by about 35% during model initialization, and about 40% during model training:

==================  ================================  ==============================
Finetuning method   Peak memory reserved, model init  Peak memory reserved, training
==================  ================================  ==============================
LoRA                13.98 GB                          15.57 GB
QLoRA               9.13 GB                           9.29 GB
==================  ================================  ==============================

From the logs, one can see that the out-of-the-box training performance is quite slow, at only around one iteration per
second:

.. code-block:: python

  1|149|Loss: 0.9157477021217346: 1%| | 149/25880 [02:08<6:14:19, 1.15it/s]

To speed things up, we can leverage ``torch.compile`` to compile our model and run the compiled result. To use ``torch.compile`` with
QLoRA training, a nightly build of PyTorch must be used. To update PyTorch to the latest nightly,
please see `the installation instructions <https://pytorch.org/get-started/locally/>`_. Once updated,
you can specify the compile flag as ``True`` via a config override:

.. code-block:: bash

  tune run lora_finetune_single_device --config llama2/7B_qlora_single_device compile=True

From the logs, we can see a speedup of roughly 3x, from about 1.15 it/s to about 3.95 it/s (after a few hundred iterations, once training has stabilized):

.. code-block:: python

  1|228|Loss: 0.8158286809921265: 1%| | 228/25880 [11:59<1:48:16, 3.95it/s]

A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below.

.. image:: /_static/img/qlora_exp.png

.. note::
  The above figure was generated with W&B. You can use torchtune's :class:`~torchtune.utils.metric_logging.WandBLogger`
  to generate similar loss curves, but you will need to install W&B and set up an account separately.

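For instance, assuming your config exposes a ``metric_logger`` component (as the default torchtune configs do), one possible way to switch to W&B logging is via a config override:

.. code-block:: bash

  pip install wandb
  tune run lora_finetune_single_device --config llama2/7B_qlora_single_device \
    metric_logger._component_=torchtune.utils.metric_logging.WandBLogger
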
As an exercise, you can also try running some evaluation tasks or manually inspecting generations
output by your saved checkpoints (which can be found in :code:`output_dir`).

In the final section, we'll take a deep dive into how a QLoRA component can be built from a LoRA component.

.. _qlora_deepdive_label:

Deep-dive: Building QLoRA from LoRA
-----------------------------------------

This deep-dive section resumes from the :ref:`Using QLoRA to save memory<qlora_core_highlevel>` portion of this tutorial and dives into how quantization is done with ``NF4Tensor`` and handled appropriately in the forward pass.

First, we'll begin with
a vanilla minimal LoRA layer, taken from :ref:`the LoRA tutorial <lora_finetune_label>` and augmented to support quantization:

.. code-block:: python
  :emphasize-lines: 3, 13, 17, 22, 23, 42, 43, 44

  from torch import nn, Tensor
  import torch.nn.functional as F
  from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

  class LoRALinear(nn.Module):
      def __init__(
          self,
          in_dim: int,
          out_dim: int,
          rank: int,
          alpha: float,
          dropout: float,
          quantize_base: bool
      ):
          super().__init__()
          # Remember whether the base weight is quantized so forward() can dispatch correctly.
          self.quantize_base = quantize_base
          # These are the weights from the original pretrained model
          self.linear = nn.Linear(in_dim, out_dim, bias=False)
          self.linear_weight = self.linear.weight
          # Use torchao's to_nf4 API to quantize the base weight if needed.
          if quantize_base:
              self.linear_weight = to_nf4(self.linear_weight)
          # These are the new LoRA params. In general rank << in_dim, out_dim
          self.lora_a = nn.Linear(in_dim, rank, bias=False)
          self.lora_b = nn.Linear(rank, out_dim, bias=False)

          # Rank and alpha are commonly-tuned hyperparameters
          self.rank = rank
          self.alpha = alpha

          # Most implementations also include some dropout
          self.dropout = nn.Dropout(p=dropout)

          # The original params are frozen, and only LoRA params are trainable.
          self.linear.weight.requires_grad = False
          self.lora_a.weight.requires_grad = True
          self.lora_b.weight.requires_grad = True

      def forward(self, x: Tensor) -> Tensor:
          # frozen_out would be the output of the original model
          if self.quantize_base:
              # Call into torchao's linear_nf4 to run linear forward pass w/quantized weight.
              frozen_out = linear_nf4(x, self.linear_weight)
          else:
              frozen_out = F.linear(x, self.linear_weight)

          # lora_a projects inputs down to the much smaller self.rank,
          # then lora_b projects back up to the output dimension
          lora_out = self.lora_b(self.lora_a(self.dropout(x)))

          # Finally, scale by the alpha parameter (normalized by rank)
          # and add to the original model's outputs
          return frozen_out + (self.alpha / self.rank) * lora_out

As mentioned above, torchtune takes a dependency on the `torchao library <https://github.com/pytorch-labs/ao>`_ for some of the core components required for QLoRA. This includes the
``NF4Tensor``, as well as helpful utilities including ``to_nf4`` and ``linear_nf4``.

The key changes on top of the LoRA layer are the use of the ``to_nf4`` and ``linear_nf4`` APIs.

``to_nf4`` accepts an unquantized (bf16 or fp32) tensor and produces an ``NF4`` representation of the weight. See the `implementation <https://github.com/pytorch-labs/ao/blob/c40358072f99b50cd7e58ec11e0e8d90440e3e25/torchao/dtypes/nf4tensor.py#L587>`_ of ``to_nf4`` for more details.
``linear_nf4`` handles the forward pass and autograd when running with quantized base model weights. It computes the forward pass as a regular
``F.linear`` with the incoming activation and the weight dequantized to the compute precision. The quantized weight, rather than its higher-precision dequantized copy, is saved for the backward pass, avoiding the extra
memory usage that storing higher-precision tensors for gradient computation would incur. See `linear_nf4 <https://github.com/pytorch-labs/ao/blob/main/torchao/dtypes/nf4tensor.py#L577>`_ for more details.
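To get a feel for these two APIs in isolation, here is a small standalone sketch. The shapes are arbitrary and chosen purely for illustration, and a CUDA device is assumed to be available:

.. code-block:: python

  import torch
  from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

  # A frozen "base" weight in bf16, and its compact 4-bit NF4 representation.
  weight = torch.randn(512, 512, dtype=torch.bfloat16, device="cuda")
  weight_nf4 = to_nf4(weight)

  # linear_nf4 dequantizes the weight to the compute dtype for the matmul,
  # while saving the compact NF4 weight (not a bf16 copy) for the backward pass.
  x = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda", requires_grad=True)
  out = linear_nf4(x, weight_nf4)
  out.sum().backward()
  print(out.shape, x.grad.shape)  # torch.Size([2, 512]) torch.Size([2, 512])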
