Merged

Changes from 23 commits (32 commits total):
a7da42f  Test ghstack  (rohan-varma, Mar 26, 2024)
61562e0  Merge branch 'main' of github.com:pytorch/torchtune  (rohan-varma, Mar 26, 2024)
c515f67  Merge branch 'main' of github.com:pytorch/torchtune  (rohan-varma, Mar 27, 2024)
45455f8  Merge branch 'main' of github.com:pytorch/torchtune  (rohan-varma, Mar 28, 2024)
a80f4d4  Merge branch 'main' of github.com:pytorch/torchtune  (rohan-varma, Apr 1, 2024)
64bb80c  Merge branch 'main' of github.com:pytorch/torchtune  (rohan-varma, Apr 2, 2024)
d1a2137  Merge branch 'main' of github.com:pytorch/torchtune  (rohan-varma, Apr 3, 2024)
c54b158  Merge branch 'main' of github.com:pytorch/torchtune  (rohan-varma, Apr 4, 2024)
8909ed6  Merge branch 'main' of github.com:pytorch/torchtune  (rohan-varma, Apr 4, 2024)
8c06d2b  Merge branch 'main' of github.com:pytorch/torchtune  (rohan-varma, Apr 5, 2024)
0ebd6f6  Merge branch 'main' of github.com:pytorch/torchtune  (rohan-varma, Apr 9, 2024)
3edf931  Merge branch 'main' of github.com:pytorch/torchtune  (rohan-varma, Apr 10, 2024)
1142f3a  Merge remote-tracking branch 'origin'  (rohan-varma, Apr 11, 2024)
c190177  QLoRA tutorial  (rohan-varma, Apr 11, 2024)
0064014  upd  (rohan-varma, Apr 11, 2024)
f5be68a  Upd  (rohan-varma, Apr 11, 2024)
df2b1a0  upd  (rohan-varma, Apr 12, 2024)
a0847cb  Merge branch 'main' of github.com:pytorch/torchtune into qlora_tutorial  (rohan-varma, Apr 12, 2024)
cc1d4af  QLoRA tutorial  (rohan-varma, Apr 12, 2024)
cba7177  upd  (rohan-varma, Apr 12, 2024)
f85f4a6  Upd  (rohan-varma, Apr 13, 2024)
a04984c  upd  (rohan-varma, Apr 14, 2024)
bcc2cbc  upd  (rohan-varma, Apr 15, 2024)
555023f  upd  (rohan-varma, Apr 15, 2024)
1333b0d  upd  (rohan-varma, Apr 15, 2024)
b68b2d4  upd  (rohan-varma, Apr 15, 2024)
c382cf6  upd  (rohan-varma, Apr 15, 2024)
b78d51a  upd  (rohan-varma, Apr 15, 2024)
11042e7  upd  (rohan-varma, Apr 15, 2024)
b1e5cc8  upd  (rohan-varma, Apr 15, 2024)
49d13d3  Upd  (rohan-varma, Apr 15, 2024)
fa6b0a3  upd  (rohan-varma, Apr 15, 2024)
Binary file added docs/source/_static/img/qlora_experiment.png
275 changes: 275 additions & 0 deletions docs/source/examples/qlora_finetune.rst
@@ -0,0 +1,275 @@
.. _qlora_finetune_label:

=============================
Finetuning Llama2 with QLoRA
=============================

In this tutorial, we'll learn about `QLoRA <https://arxiv.org/abs/2305.14314>`_, an enhancement on top of
`LoRA <https://arxiv.org/abs/2106.09685>`_ that maintains frozen model parameters in 4-bit quantized precision, thereby reducing memory usage. We'll
walk through how to use QLoRA in TorchTune to finetune a Llama2-7b model in less than 10 GB of memory.
Contributor:

nit: you should find and replace TorchTune -> torchtune as we have done it everywhere else since after you first opened this PR

It is highly recommended to first develop an understanding of :ref:`LoRA finetuning in TorchTune<lora_finetune_label>`.


.. grid:: 2

.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn

* How QLoRA saves memory over LoRA finetuning
* An overview of QLoRA in TorchTune
* How to run a QLoRA finetune in TorchTune

.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

* Be familiar with :ref:`TorchTune<overview_label>`
* Make sure to :ref:`install TorchTune<install_label>`
* Make sure you have downloaded the :ref:`Llama2-7B model weights<download_llama_label>`
* Be familiar with :ref:`LoRA in torchtune<lora_finetune_label>`

What is QLoRA?
---------------

`QLoRA <https://arxiv.org/abs/2305.14314>`_ builds on top of `LoRA <https://arxiv.org/abs/2106.09685>`_ to enable further
memory savings. In LoRA, model parameters can be thought of as existing in two partitions: adapters, which are
low-rank matrices added to different layers of a neural network, and base model parameters, which are parameters that are part of
the original model. In vanilla LoRA style training, both these parameters are held in the same precision (typically fp32 or bf16), and
Contributor:

nit

Suggested change:
- the original model. In vanilla LoRA style training, both these parameters are held in the same precision (typically fp32 or bf16), and
+ the original model. In vanilla LoRA-style training, both these parameters are held in the same precision (typically fp32 or bf16), and

therefore the activations and intermediate gradients computed during training are also in fp32/bf16.

QLoRA further quantizes the base model parameters into a bespoke 4-bit NormalFloat (NF4) data type, resulting in 4-8x less parameter memory usage while
largely retaining model accuracy. As a result, the vast majority of parameters take up only 4 bits (as opposed to the 16 or 32 bits used by the bf16/fp32 dtypes). This
quantization is done through the method highlighted in the original `QLoRA paper <https://arxiv.org/abs/2305.14314>`_. Adapter
parameters are still held in the original precision, and activations, gradients, and optimizer states still exist in the higher precision to preserve
accuracy.
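
To get a rough sense of the savings, here is a back-of-the-envelope calculation for a 7B-parameter base model. This is only a sketch: it ignores the small amount of
quantization metadata (such as per-block scales) that NF4 stores alongside the 4-bit values.

.. code-block:: python

num_params = 7e9                  # Llama2-7B base model parameters
fp32_gb = num_params * 4 / 1e9    # ~28 GB at 32 bits per parameter
bf16_gb = num_params * 2 / 1e9    # ~14 GB at 16 bits per parameter
nf4_gb = num_params * 0.5 / 1e9   # ~3.5 GB at 4 bits per parameter
print(fp32_gb / nf4_gb)           # 8.0 -> ~8x smaller than fp32
print(bf16_gb / nf4_gb)           # 4.0 -> ~4x smaller than bf16
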
Contributor:

One thing that'd be nice: basically take the diagram in the LoRA tutorial demonstrating full finetune -> LoRA and add one more for LoRA -> QLoRA. (The diagrams take a bit of time so I feel this is more of a nice-to-have at this point.) But if you're interested let me know and I can dig up the original

Contributor Author:

Definitely interested! Will punt this out to a follow up though.


The `QLoRA paper <https://arxiv.org/abs/2305.14314>`_ introduces two key abstractions to decrease memory usage and avoid accuracy degradation: the bespoke 4-bit NormalFloat
Contributor:

nit: I feel like you're linking to the paper too many times, usually I just do it once upon introduction

Collaborator:

nah, I think this is fine. you don't know which paragraph someone will start reading at if they're skimming or jumping around

Member:

I think you should link every other reference to the paper to keep people on their toes.

type, and a double quantization method that quantizes the quantization parameters themselves to save even more memory. TorchTune uses
Collaborator:

🤯

Contributor:

Is it worth giving some more details on either of these two optimizations, or do you think it's too in the weeds?

Contributor Author:

IMO its too in the weeds and not worth it to just re explain if its in the paper. Can add a line directing folks to the paper, but IMO its already clear enough to read the paper for this sort of detail

the `NF4Tensor <https://github.com/pytorch-labs/ao/blob/b9beaf351e27133d189b57d6fa725b1a7824a457/torchao/dtypes/nf4tensor.py#L153>`_ abstraction from the `TorchAO library <https://github.com/pytorch-labs/ao>`_ to build QLoRA components as specified in the paper.
The `TorchAO library <https://github.com/pytorch-labs/ao>`_ is a PyTorch-native library that allows you to quantize and prune your models.
Contributor:

nit: I think we want TorchAO -> torchao



.. _qlora_core_highlevel:

Using QLoRA to save memory
----------------------------------------

In this section, we'll cover how to apply QLoRA to a :class:`~torchtune.modules.peft.LoRALinear` layer in TorchTune. For a deep dive into the details of QLoRA in torchtune and its underlying abstractions,
please see the :ref:`QLoRA in torchtune deepdive <qlora_deepdive_label>` section of this tutorial.

A core idea of QLoRA is the distinction between compute and storage datatypes (dtypes). Specifically, QLoRA stores base model parameters in 4-bit precision (i.e. the storage dtype), and runs
computation in the original higher precision (the compute dtype), generally either fp32 or bf16. As a first step, QLoRA needs to quantize these base model parameters to 4-bit precision
and store them.

To quantize a :class:`~torchtune.modules.peft.LoRALinear` layer in the QLoRA style, simply pass ``quantize_base=True`` to :class:`~torchtune.modules.peft.LoRALinear`. This flag
causes the base model weights to be quantized and backed by the ``NF4Tensor`` dtype. Forward passes are also automatically handled to work with the ``NF4Tensor`` dtype:
the ``NF4`` base weight is de-quantized to the compute precision, the activation is computed, and only the 4-bit parameter is stored for gradient computation
in the backward pass, avoiding the extra memory usage that would be incurred by storing the higher-precision compute dtype.

Here's an example of creating a quantized ``LoRALinear`` layer in comparison to an unquantized ``LoRALinear`` layer. As we can see, the quantized layer consumes
roughly 6x less memory overall than its unquantized counterpart, with the 4-bit base weight alone taking about 8x less memory than its fp32 version.

.. code-block:: python

import torch
from torchtune.modules.peft import LoRALinear

torch.set_default_device("cuda")
qlora_linear = LoRALinear(512, 512, rank=8, alpha=0.1, quantize_base=True)
print(torch.cuda.memory_allocated()) # 177,152 bytes
del qlora_linear
torch.cuda.empty_cache()
lora_linear = LoRALinear(512, 512, rank=8, alpha=0.1, quantize_base=False)
print(torch.cuda.memory_allocated()) # 1,081,344 bytes
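
The numbers above line up with a quick sanity check. This is a rough sketch: it ignores CUDA allocator rounding and lumps the NF4 quantization metadata (per-block
scales and related bookkeeping) into the leftover bytes.

.. code-block:: python

base_fp32 = 512 * 512 * 4      # 1,048,576 bytes for the fp32 base weight
adapters = 2 * (512 * 8) * 4   # 32,768 bytes for the fp32 lora_a and lora_b weights
print(base_fp32 + adapters)    # 1,081,344 bytes: exactly the unquantized layer above

base_nf4 = 512 * 512 // 2      # 131,072 bytes at 4 bits per base parameter
print(base_nf4 + adapters)     # 163,840 bytes: close to the measured 177,152; the
                               # remainder is NF4 quantization metadata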


Using QLoRA in TorchTune
----------------------------

We'll now cover how you can initialize a QLoRA-enabled Llama2-7b model as well as some details around
checkpointing with QLoRA.

With TorchTune, you can use a simple builder similar to the LoRA builder (:code:`lora_llama2_7b`) to apply QLoRA to Llama2 models. Here's a simple example of
initializing a Llama2-7b model with QLoRA enabled:

.. code-block:: python

from torchtune.models.llama2 import qlora_llama2_7b

qlora_model = qlora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

Under the hood, this will apply LoRA to the ``q_proj`` and ``v_proj`` matrices in all attention layers, and further quantize the base parameters
in these matrices to the ``NF4`` dtype. Note that quantization of base model parameters is only applied to layers that are configured to have
LoRA adapters added. For example, in this case, ``k_proj`` and ``output_proj`` in the attention layers don't have LoRA applied to them, so their
base model parameters are not quantized. We can see this by printing the base model parameter dtypes for a particular attention layer:

.. code-block:: python

attn = qlora_model.layers[0].attn
print(type(attn.q_proj.weight)) # <class 'torchao.dtypes.nf4tensor.NF4Tensor'>
print(type(attn.k_proj.weight)) # <class 'torch.nn.parameter.Parameter'>


Next, there are a couple of details essential to checkpointing (i.e. ``state_dict``) of QLoRA-enabled models.
To integrate well with TorchTune's :ref:`checkpointing <checkpointing_label>`, we need to convert ``NF4Tensors`` back to their
original precision (generally fp32/bf16). This allows QLoRA-trained checkpoints to interoperate well with the rest of the ecosystem, within
TorchTune and beyond (i.e. checkpoint conversion, post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done
Contributor:

nit. Also might remove checkpoint conversion since idk that it's really an ecosystem thing in the same way the other items are

Suggested change:
- TorchTune and beyond (i.e. checkpoint conversion, post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done
+ TorchTune and beyond (e.g. checkpoint conversion, post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done

in a typical LoRA training flow.

To achieve this, when using TorchTune's ``qlora_llama2_7b`` builder, we automatically register a hook, :code:`reparametrize_as_dtype_state_dict_post_hook`,
that runs after calling ``.state_dict()`` on the top-level model. This hook converts ``NF4Tensors`` back to their original precision, while also offloading the
converted tensors to the CPU. This offloading avoids the memory spike that would otherwise come from holding an entire bf16/fp32 copy of the ``state_dict``
on GPU.
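
As a quick illustration, using the ``qlora_model`` created above (a sketch; the exact ``state_dict`` key below is inferred from the module structure we printed earlier):

.. code-block:: python

# Calling .state_dict() triggers the registered post hook, so quantized base weights
# come back as plain higher-precision tensors on CPU rather than NF4Tensors on GPU.
sd = qlora_model.state_dict()
q_proj_weight = sd["layers.0.attn.q_proj.weight"]
print(type(q_proj_weight))    # <class 'torch.Tensor'>, no longer an NF4Tensor
print(q_proj_weight.device)   # cpu, offloaded to avoid a full high-precision copy on GPU
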
Contributor:

nit: this sentence could be a little unclear, seemingly implying that the way we avoid peaking memory is by maintaining an entire bf16/fp32 copy on GPU.




Putting it all together: QLoRA finetune
-----------------------------------------

Putting it all together, we can now finetune a model using TorchTune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/48626d19d2108f92c749411fbd5f0ff140023a25/recipes/lora_finetune.py>`_,
with a `QLoRA configuration <https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml>`_.

Make sure that you have first downloaded the Llama2 weights and tokenizer by following :ref:`these instructions<download_llama_label>`.
You can then run the following command to perform a QLoRA finetune of Llama2-7B using the Alpaca dataset on a single GPU.

.. code-block:: bash

tune run lora_finetune_single_device --config llama2/7B_qlora_single_device.yaml

.. note::
Make sure to correctly point to the location of your Llama2 weights and tokenizer. This can be done
either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path`
or by directly modifying the :code:`7B_qlora_single_device.yaml` file. See our :ref:`config_tutorial_label`
for more details on how you can easily clone and modify TorchTune configs.

By default, this run should log peak memory stats at model initialization time and every 100
iterations during training. Let's understand the memory savings enabled by QLoRA on top of LoRA training. LoRA training
can be run as follows:

.. code-block:: bash

tune run lora_finetune_single_device --config llama2/7B_lora_single_device.yaml

You should see the memory usage printed out during model initialization and training. An example log for LoRA model initialization is as follows:

.. code-block:: python

Memory Stats after model init::
GPU peak memory allocation: 13.96 GB
GPU peak memory reserved: 13.98 GB
GPU peak memory active: 13.96 GB

The following table compares QLoRA's peak memory reserved during model initialization and training against vanilla LoRA's.
We can see that QLoRA reduces peak memory by about 35% during model initialization, and about 40% during model training:

================== ================================== ================================
Finetuning method  Peak memory reserved, model init   Peak memory reserved, training
================== ================================== ================================
LoRA               13.98 GB                           15.57 GB
QLoRA              9.13 GB                            9.29 GB
================== ================================== ================================

From the logs, one can see that the out-of-the-box training performance is quite slow, slower than 1 iteration per
second:

.. code-block:: python

1|149|Loss: 0.9157477021217346: 1%| | 149/25880 [02:08<6:14:19, 1.15it/s

To speed things up, we can leverage ``torch.compile`` to compile our model and run the compiled result. To work with
QLoRA training, a nightly build of PyTorch must be used. To update PyTorch to the latest nightly,
please see `the installation instructions <https://pytorch.org/get-started/locally/>`_. Once updated,
you can specify the compile flag as ``True`` via a config override:

.. code-block:: bash

tune run lora_finetune_single_device --config llama2/7B_lora_single_device.yaml compile=True
Contributor:

Suggested change:
- tune run lora_finetune_single_device --config llama2/7B_lora_single_device.yaml compile=True
+ tune run lora_finetune_single_device --config llama2/7B_qlora_single_device.yaml compile=True

Contributor Author:

oh shoot, .yaml is not correct we need to remove that as well.


From the logs, we can see about a 200% speedup (after a few hundred iterations, once the training has stabilized):

.. code-block:: python

1|228|Loss: 0.8158286809921265: 1%| | 228/25880 [11:59<1:48:16, 3.95it/s

A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below (purple being the QLoRA loss curve).
Contributor:

You can remove this now that you've added the legend

Suggested change:
- A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below (purple being the QLoRA loss curve).
+ A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below.


.. image:: /_static/img/qlora_experiment.png
Contributor:

Btw you can also add explicit labels to the two lines in the figure, e.g. iconic-pasma-57 -> LoRA and azure-bird-56 -> QLoRA

Contributor:

Sorry to be annoying, but can you filter the x-axis to [0, 1000] or something in wandb and reupload? Otherwise it looks weird that one is running longer.


.. note::
The above figure was generated with W&B. You can use TorchTune's :class:`~torchtune.utils.metric_logging.WandBLogger`
to generate similar loss curves, but you will need to install W&B and set up an account separately.

As an exercise, you can also try running some evaluation tasks or manually inspecting generations
output by your saved checkpoints (which can be found in :code:`output_dir`).

.. _qlora_deepdive_label:

Deep-dive: QLoRA in TorchTune
------------------------------

This deep-dive section resumes from the :ref:`Using QLoRA to save memory<qlora_core_highlevel>` portion of this tutorial and dives into how quantization is done with ``NF4Tensor`` and handled appropriately in the forward pass.

First, we'll begin with
a vanilla minimal LoRA layer, taken from :ref:`the LoRA tutorial <lora_finetune_label>` and augmented to support quantization:

.. code-block:: python
:emphasize-lines: 3, 13, 20, 21, 42, 43, 44

from torch import nn, Tensor
import torch.nn.functional as F
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

class LoRALinear(nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        rank: int,
        alpha: float,
        dropout: float,
        quantize_base: bool,
    ):
        super().__init__()
        # These are the weights from the original pretrained model
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        weight = self.linear.weight
        # Use TorchAO's to_nf4 API to quantize the base weight if needed.
        if quantize_base:
            weight = to_nf4(weight)
        self.weight = weight
        self._quantize_base = quantize_base
        # These are the new LoRA params. In general rank << in_dim, out_dim
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)

        # Rank and alpha are commonly-tuned hyperparameters
        self.rank = rank
        self.alpha = alpha

        # Most implementations also include some dropout
        self.dropout = nn.Dropout(p=dropout)

        # The original params are frozen, and only LoRA params are trainable.
        self.linear.weight.requires_grad = False
        self.lora_a.weight.requires_grad = True
        self.lora_b.weight.requires_grad = True

    def forward(self, x: Tensor) -> Tensor:
        # frozen_out would be the output of the original model
        if self._quantize_base:
            # Call into TorchAO's linear_nf4 to run linear forward pass w/quantized weight.
            frozen_out = linear_nf4(x, self.weight)
        else:
            frozen_out = F.linear(x, self.weight)

        # lora_a projects inputs down to the much smaller self.rank,
        # then lora_b projects back up to the output dimension
        lora_out = self.lora_b(self.lora_a(self.dropout(x)))

        # Finally, scale by the alpha parameter (normalized by rank)
        # and add to the original model's outputs
        return frozen_out + (self.alpha / self.rank) * lora_out

As mentioned above, TorchTune takes a dependency on the `TorchAO library <https://github.com/pytorch-labs/ao>`_ for some of the core components required for QLoRA. These include the
``NF4Tensor``, as well as helpful utilities such as ``to_nf4`` and ``linear_nf4``.

The key changes on top of the LoRA layer are the usage of the ``to_nf4`` and ``linear_nf4`` APIs.

``to_nf4`` accepts an unquantized (bf16 or fp32) tensor and produces an ``NF4`` representation of the weight. See the `implementation <https://github.com/pytorch-labs/ao/blob/c40358072f99b50cd7e58ec11e0e8d90440e3e25/torchao/dtypes/nf4tensor.py#L587>`_ of ``to_nf4`` for more details.
``linear_nf4`` handles the forward pass and autograd when running with quantized base model weights. It computes the forward pass as a regular
``F.linear`` on the incoming activation and the de-quantized weight. The quantized weight, rather than its de-quantized higher-precision version, is saved for the backward pass, avoiding the extra
memory usage that storing higher-precision variables for gradient computation would incur. See `linear_nf4 <https://github.com/pytorch-labs/ao/blob/main/torchao/dtypes/nf4tensor.py#L577>`_ for more details.
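
Putting the two APIs together, here is a minimal sketch of the quantize-once, de-quantize-on-the-fly flow. It assumes a CUDA device, a bf16 compute dtype, and leaves
the NF4 block sizes at their defaults:

.. code-block:: python

import torch
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

# Quantize the frozen base weight once, up front.
weight = torch.randn(512, 512, dtype=torch.bfloat16, device="cuda")
nf4_weight = to_nf4(weight)

# The forward pass de-quantizes to the compute dtype under the hood, and only the
# 4-bit weight is saved for the backward pass.
x = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda", requires_grad=True)
out = linear_nf4(x, nf4_weight)
out.sum().backward()  # gradients flow to the activation; the frozen base weight gets none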

With these building blocks in hand, you can revisit the earlier sections of this tutorial on using QLoRA in TorchTune to build a QLoRA-quantized Llama2-7b model, and on the nuances around
checkpointing that are important to be aware of to avoid spiking memory usage.
8 changes: 8 additions & 0 deletions docs/source/index.rst
@@ -65,6 +65,13 @@ torchtune tutorials.
:link: examples/lora_finetune.html
:tags: finetuning,llama2,lora

.. customcarditem::
:header: Understanding QLoRA in TorchTune
:card_description: Using QLoRA to quantize base model weights and maximize memory savings
:image: _static/img/generic-pytorch-logo.png
:link: examples/qlora_finetune.html
:tags: finetuning,llama2

.. customcardend::


@@ -92,6 +99,7 @@ torchtune tutorials.

examples/finetune_llm
examples/lora_finetune
examples/qlora_finetune

.. toctree::
:glob:
1 change: 1 addition & 0 deletions docs/source/overview.rst
@@ -30,6 +30,7 @@ Excited? To get started, checkout some of our tutorials, including:

- our :ref:`full finetuning tutorial <finetune_llama_label>` to get started and finetune your first LLM using torchtune.
- our :ref:`LoRA tutorial <lora_finetune_label>` to learn about parameter-efficient finetuning with torchtune.
- our :ref:`QLoRA tutorial <qlora_finetune_label>` to attain maximal memory efficiency with torchtune.

Key Concepts
------------