diff --git a/docs/source/_static/img/qlora_exp.png b/docs/source/_static/img/qlora_exp.png
new file mode 100644
index 0000000000..f75828fac4
Binary files /dev/null and b/docs/source/_static/img/qlora_exp.png differ
diff --git a/docs/source/index.rst b/docs/source/index.rst
index ae80855a47..927ad5b9f5 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -58,6 +58,13 @@ torchtune tutorials.
    :link: tutorials/lora_finetune.html
    :tags: finetuning,llama2,lora
 
+.. customcarditem::
+   :header: Understanding QLoRA in torchtune
+   :card_description: Using QLoRA to quantize base model weights and maximize memory savings
+   :image: _static/img/generic-pytorch-logo.png
+   :link: tutorials/qlora_finetune.html
+   :tags: finetuning,llama2
+
 .. customcarditem::
    :header: End-to-End Workflow with torchtune
    :card_description: Train, Evaluate, Quantize and then Generate with your LLM.
@@ -91,6 +98,7 @@ torchtune tutorials.
    :hidden:
 
    tutorials/lora_finetune
+   tutorials/qlora_finetune
    tutorials/first_finetune_tutorial
    tutorials/e2e_flow
 
diff --git a/docs/source/overview.rst b/docs/source/overview.rst
index 5cf8ae9a34..4d483cb2af 100644
--- a/docs/source/overview.rst
+++ b/docs/source/overview.rst
@@ -30,6 +30,7 @@ Excited? To get started, checkout some of our tutorials, including:
 
 - our :ref:`full finetuning tutorial ` to get started and finetune your first LLM using torchtune.
 - our :ref:`LoRA tutorial ` to learn about parameter-efficient finetuning with torchtune.
+- our :ref:`QLoRA tutorial <qlora_finetune_label>` to attain maximal memory efficiency with torchtune.
 
 Key Concepts
 ------------
diff --git a/docs/source/tutorials/qlora_finetune.rst b/docs/source/tutorials/qlora_finetune.rst
new file mode 100644
index 0000000000..963b12fa51
--- /dev/null
+++ b/docs/source/tutorials/qlora_finetune.rst
@@ -0,0 +1,277 @@
.. _qlora_finetune_label:

=============================
Finetuning Llama2 with QLoRA
=============================

In this tutorial, we'll learn about `QLoRA <https://arxiv.org/abs/2305.14314>`_, an enhancement on top of
`LoRA <https://arxiv.org/abs/2106.09685>`_ that maintains frozen model parameters in 4-bit quantized precision, thereby reducing memory usage. We'll
walk through how QLoRA can be utilized within torchtune to finetune a Llama2-7b model in < 10 GB of memory.
It is highly recommended to first develop an understanding of :ref:`LoRA finetuning in torchtune <lora_finetune_label>`.


.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn

       * How QLoRA saves memory over LoRA finetuning
       * An overview of QLoRA in torchtune
       * How to run a QLoRA finetune in torchtune

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

       * Be familiar with :ref:`torchtune`
       * Make sure to :ref:`install torchtune`
       * Make sure you have downloaded the :ref:`Llama2-7B model weights`
       * Be familiar with :ref:`LoRA in torchtune <lora_finetune_label>`

What is QLoRA?
---------------

`QLoRA <https://arxiv.org/abs/2305.14314>`__ builds on top of `LoRA <https://arxiv.org/abs/2106.09685>`__ to enable further
memory savings. In LoRA, model parameters can be thought of as existing in two partitions: adapters, which are
low-rank matrices added to different layers of a neural network, and base model parameters, which are the parameters of
the original model. In vanilla LoRA-style training, both partitions are held in the same precision (typically fp32 or bf16),
so activations and intermediate gradients are also computed in fp32/bf16.

QLoRA further quantizes the base model parameters into a bespoke 4-bit NormalFloat (NF4) data type, resulting in 4-8x less parameter memory usage while
largely retaining model accuracy. As a result, the vast majority of parameters take up only 4 bits (as opposed to the 16 or 32 bits of the bf16/fp32 dtypes). This
quantization is done through the method introduced in the original `QLoRA paper <https://arxiv.org/abs/2305.14314>`__. Adapter
parameters are still held in the original precision, and activations, gradients, and optimizer states also remain in the higher precision to preserve
accuracy.
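To make these savings concrete, here's a back-of-envelope sketch (an illustration only: it counts base-weight storage alone
for a 7B-parameter model and ignores the comparatively small quantization metadata):

.. code-block:: python

    params = 7e9  # approximate Llama2-7b base parameter count

    print(f"fp32 base weights: ~{params * 4 / 1e9:.0f} GB")  # ~28 GB
    print(f"bf16 base weights: ~{params * 2 / 1e9:.0f} GB")  # ~14 GB
    print(f"NF4 base weights:  ~{params / 2 / 1e9:.1f} GB")  # ~3.5 GB, i.e. the 4-8x savings above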
The `QLoRA paper <https://arxiv.org/abs/2305.14314>`__ introduces two key abstractions to decrease memory usage and avoid accuracy degradation: the bespoke 4-bit NormalFloat
type, and a double quantization method that quantizes the quantization parameters themselves to save even more memory. torchtune uses
the `NF4Tensor <https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py>`_ abstraction from the `torchao <https://github.com/pytorch/ao>`_ library, a PyTorch-native
library for quantizing and pruning models, to build the QLoRA components as specified in the paper.


.. _qlora_core_highlevel:

Using QLoRA to save memory
----------------------------------------

In this section, we'll give an overview of how to apply QLoRA to a :class:`~torchtune.modules.peft.LoRALinear` layer in torchtune. For a deep dive into the details of QLoRA in torchtune and its underlying abstractions,
please see the :ref:`QLoRA in torchtune deepdive <qlora_deepdive_label>` section of this tutorial.

A core idea of QLoRA is the distinction between compute and storage datatypes (dtypes). Specifically, QLoRA stores base model parameters in 4-bit precision (the storage dtype), and runs
computation in a higher precision (the compute dtype), generally either fp32 or bf16. As a first step, QLoRA needs to quantize these base model parameters to 4-bit precision
and store them.

To quantize a :class:`~torchtune.modules.peft.LoRALinear` layer in the QLoRA style, simply pass ``quantize_base=True`` when constructing the :class:`~torchtune.modules.peft.LoRALinear` layer. This flag
results in the base model weights being quantized and backed by the ``NF4Tensor`` dtype. Forward passes are also automatically handled to work with the ``NF4Tensor`` dtype:
the ``NF4`` base weight is de-quantized to the compute precision to compute the activation, and only the 4-bit parameter is stored for gradient computation
in the backward pass, avoiding the extra memory usage that storing the higher-precision compute dtype would incur.

Here's an example of creating a quantized ``LoRALinear`` layer in comparison to an unquantized ``LoRALinear`` layer. As we can see, the quantized layer consumes
about 6x less memory than its unquantized counterpart.

.. code-block:: python

    import torch
    from torchtune.modules.peft import LoRALinear

    torch.set_default_device("cuda")
    qlora_linear = LoRALinear(512, 512, rank=8, alpha=0.1, quantize_base=True)
    print(torch.cuda.memory_allocated())  # 177,152 bytes
    del qlora_linear
    torch.cuda.empty_cache()
    lora_linear = LoRALinear(512, 512, rank=8, alpha=0.1, quantize_base=False)
    print(torch.cuda.memory_allocated())  # 1,081,344 bytes
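Where do these numbers come from? Here's a rough accounting of the example above (a sketch, assuming the default fp32
parameter dtype, since ``torch.set_default_device`` only changes the device):

.. code-block:: python

    # Unquantized layer: one fp32 base weight plus two fp32 LoRA matrices.
    base = 512 * 512 * 4      # base weight: 1,048,576 bytes
    lora = 2 * (8 * 512) * 4  # lora_a + lora_b: 32,768 bytes
    print(base + lora)        # 1,081,344 -> exactly the unquantized number above

    # Quantized layer: the base weight is packed at 4 bits (half a byte) per element.
    nf4_base = 512 * 512 // 2  # 131,072 bytes of packed 4-bit data
    print(nf4_base + lora)     # 163,840; the measured 177,152 bytes additionally
                               # includes per-block quantization metadata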
Using QLoRA in torchtune
----------------------------

We'll now cover how you can initialize a QLoRA-enabled Llama2-7b model, as well as some details around
checkpointing with QLoRA.

With torchtune, you can use a simple builder, analogous to the LoRA builder (:code:`lora_llama2_7b`), to apply QLoRA to Llama2 models. Here's a simple example of
initializing a Llama2-7b model with QLoRA enabled:

.. code-block:: python

    from torchtune.models.llama2 import qlora_llama2_7b

    qlora_model = qlora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

Under the hood, this will apply LoRA to the ``q_proj`` and ``v_proj`` matrices in all attention layers, and further quantize the base parameters
in these matrices to the ``NF4`` dtype. Note that quantization of base model parameters is applied only to layers that are configured to have
LoRA adapters added. For example, in this case, ``k_proj`` and ``output_proj`` in the attention layers don't have LoRA applied to them, so their
base model parameters are not quantized. We can see this by printing the base model parameter types for a particular attention layer:

.. code-block:: python

    attn = qlora_model.layers[0].attn
    print(type(attn.q_proj.weight))  # <class 'torchao.dtypes.nf4tensor.NF4Tensor'>
    print(type(attn.k_proj.weight))  # <class 'torch.nn.parameter.Parameter'>

Next, there are a couple of details essential to checkpointing (i.e. the ``state_dict``) of QLoRA-enabled models.
To integrate well with torchtune's :ref:`checkpointing `, we need to convert ``NF4Tensors`` back to their
original precision (generally fp32/bf16). This allows QLoRA-trained checkpoints to interoperate with the rest of the ecosystem, within
torchtune and beyond (e.g. post-training quantization, evaluation, inference). This conversion also allows LoRA adapter weights to be merged back into the base model, as is done
in a typical LoRA training flow.

To achieve this, torchtune's ``qlora_llama2_7b`` builder automatically registers a hook, :code:`reparametrize_as_dtype_state_dict_post_hook`,
that runs after calling ``.state_dict()`` on the top-level model. This hook converts ``NF4Tensors`` back to their original precision, while also offloading the
converted tensors to the CPU. Without this offloading, we would have to maintain an entire bf16/fp32 copy of the ``state_dict`` on GPU, causing a
spike in peak memory.
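Continuing with the ``qlora_model`` created above, here's a sketch of the hook's effect (the parameter key mirrors the
``layers[0].attn.q_proj`` access used earlier; the exact dtype you see depends on how the model was initialized):

.. code-block:: python

    # Calling state_dict() triggers reparametrize_as_dtype_state_dict_post_hook.
    sd = qlora_model.state_dict()

    # The live module still holds its base weight as an NF4Tensor...
    print(type(qlora_model.layers[0].attn.q_proj.weight))

    # ...while the state_dict entry has been converted back to the original
    # precision (e.g. fp32/bf16) and offloaded to CPU.
    converted = sd["layers.0.attn.q_proj.weight"]
    print(converted.dtype, converted.device)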
Putting it all together: QLoRA finetune
-----------------------------------------

We can now finetune a model using torchtune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/main/recipes/lora_finetune_single_device.py>`_,
with a `QLoRA configuration <https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml>`_.

Make sure that you have first downloaded the Llama2 weights and tokenizer by following :ref:`these instructions`.
You can then run the following command to perform a QLoRA finetune of Llama2-7B with the Alpaca dataset on a single GPU.

.. code-block:: bash

    tune run lora_finetune_single_device --config llama2/7B_qlora_single_device

.. note::
    Make sure to correctly point to the location of your Llama2 weights and tokenizer. This can be done
    either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path`
    or by directly modifying the :code:`7B_qlora_single_device.yaml` file. See our :ref:`config_tutorial_label`
    for more details on how you can easily clone and modify torchtune configs.

By default, this run should log peak memory stats at model initialization time and every 100
iterations during training. To understand the memory savings QLoRA enables on top of LoRA training, we can first run
vanilla LoRA training as a baseline:

.. code-block:: bash

    tune run lora_finetune_single_device --config llama2/7B_lora_single_device

You should see the memory usage printed out during model initialization and training. An example log for LoRA model initialization is as follows:

.. code-block:: python

    Memory Stats after model init:
        GPU peak memory allocation: 13.96 GB
        GPU peak memory reserved: 13.98 GB
        GPU peak memory active: 13.96 GB

The following table compares QLoRA's peak memory reserved during model initialization and training against vanilla LoRA's.
We can see that QLoRA reduces peak memory by about 35% during model initialization, and by about 40% during training:

================== ================================== ================================
Finetuning method  Peak memory reserved, model init   Peak memory reserved, training
================== ================================== ================================
LoRA               13.98 GB                           15.57 GB
QLoRA              9.13 GB                            9.29 GB
================== ================================== ================================

From the logs, one can see that out-of-the-box training performance is quite slow, at only around one iteration
per second:

.. code-block:: python

    1|149|Loss: 0.9157477021217346:   1%|          | 149/25880 [02:08<6:14:19, 1.15it/s]

To speed things up, we can leverage ``torch.compile`` to compile our model and run the compiled result. To use ``torch.compile``
with QLoRA training, a nightly build of PyTorch must be used. To update PyTorch to the latest nightly,
please see `the installation instructions <https://pytorch.org/get-started/locally/>`_. Once updated,
you can set the compile flag to ``True`` via a config override:

.. code-block:: bash

    tune run lora_finetune_single_device --config llama2/7B_qlora_single_device compile=True

From the logs, we can see a roughly 3.4x speedup, from about 1.15 to 3.95 iterations per second (measured after a few
hundred iterations, once training has stabilized):

.. code-block:: python

    1|228|Loss: 0.8158286809921265:   1%|          | 228/25880 [11:59<1:48:16, 3.95it/s]

A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below.

.. image:: /_static/img/qlora_exp.png

.. note::
    The above figure was generated with W&B. You can use torchtune's :class:`~torchtune.utils.metric_logging.WandBLogger`
    to generate similar loss curves, but you will need to install W&B and set up an account separately.

As an exercise, you can also try running some evaluation tasks or manually inspecting generations
output by your saved checkpoints (which can be found in :code:`output_dir`).

In the final section, we'll take a deep dive into how a QLoRA component can be built from a LoRA component.

.. _qlora_deepdive_label:

Deep-dive: Building QLoRA from LoRA
-----------------------------------------

This deep-dive section resumes from the :ref:`Using QLoRA to save memory <qlora_core_highlevel>` portion of this tutorial and dives into how quantization is done with ``NF4Tensor`` and handled appropriately in the forward pass.

First, we'll begin with
a vanilla minimal LoRA layer, taken from :ref:`the LoRA tutorial <lora_finetune_label>` and augmented to support quantization:

.. code-block:: python
    :emphasize-lines: 3, 13, 16, 22, 23, 43, 44, 45, 46

    from torch import nn, Tensor
    import torch.nn.functional as F
    from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

    class LoRALinear(nn.Module):
        def __init__(
            self,
            in_dim: int,
            out_dim: int,
            rank: int,
            alpha: float,
            dropout: float,
            quantize_base: bool,
        ):
            super().__init__()
            self.quantize_base = quantize_base

            # These are the weights from the original pretrained model
            self.linear = nn.Linear(in_dim, out_dim, bias=False)
            self.linear_weight = self.linear.weight
            # Use torchao's to_nf4 API to quantize the base weight if needed.
            if quantize_base:
                self.linear_weight = to_nf4(self.linear_weight)

            # These are the new LoRA params. In general rank << in_dim, out_dim
            self.lora_a = nn.Linear(in_dim, rank, bias=False)
            self.lora_b = nn.Linear(rank, out_dim, bias=False)

            # Rank and alpha are commonly-tuned hyperparameters
            self.rank = rank
            self.alpha = alpha

            # Most implementations also include some dropout
            self.dropout = nn.Dropout(p=dropout)

            # The original params are frozen, and only LoRA params are trainable.
            self.linear.weight.requires_grad = False
            self.lora_a.weight.requires_grad = True
            self.lora_b.weight.requires_grad = True

        def forward(self, x: Tensor) -> Tensor:
            # frozen_out is the output of the original (frozen) base weight
            if self.quantize_base:
                # Call into torchao's linear_nf4 to run the linear forward
                # pass with the quantized weight.
                frozen_out = linear_nf4(x, self.linear_weight)
            else:
                frozen_out = F.linear(x, self.linear_weight)

            # lora_a projects inputs down to the much smaller self.rank,
            # then lora_b projects back up to the output dimension
            lora_out = self.lora_b(self.lora_a(self.dropout(x)))

            # Finally, scale by the alpha parameter (normalized by rank)
            # and add to the original model's outputs
            return frozen_out + (self.alpha / self.rank) * lora_out
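As a quick smoke test of this layer (using the minimal class defined above, not torchtune's built-in
:class:`~torchtune.modules.peft.LoRALinear`; the hyperparameter values here are arbitrary):

.. code-block:: python

    import torch

    torch.set_default_device("cuda")  # mirror the CUDA setup from the earlier example
    layer = LoRALinear(
        in_dim=512, out_dim=512, rank=8, alpha=16.0, dropout=0.0, quantize_base=True
    )

    x = torch.randn(2, 512)
    print(layer(x).shape)  # torch.Size([2, 512])

    # Only the adapter matrices will receive gradients.
    print([name for name, p in layer.named_parameters() if p.requires_grad])
    # ['lora_a.weight', 'lora_b.weight']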
As mentioned above, torchtune depends on the `torchao <https://github.com/pytorch/ao>`__ library for some of the core components required for QLoRA. This includes the
``NF4Tensor``, as well as helpful utilities including ``to_nf4`` and ``linear_nf4``.

The key changes on top of the LoRA layer are the usage of the ``to_nf4`` and ``linear_nf4`` APIs.

``to_nf4`` accepts an unquantized (bf16 or fp32) tensor and produces an ``NF4`` representation of the weight. See the `implementation <https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py>`_ of ``to_nf4`` for more details.
``linear_nf4`` handles the forward pass and autograd when running with quantized base model weights. It computes the forward pass as a regular
``F.linear`` with the incoming activation and the unquantized weight. The quantized weight, rather than the unquantized version, is saved for the backward pass,
avoiding the extra memory usage that storing higher-precision tensors for gradient computation would incur. See `linear_nf4 <https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py>`_ for more details.
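To close, here's a minimal sketch of these two APIs used in isolation (assuming a CUDA device and bf16 weights;
``get_original_weight`` is the ``NF4Tensor`` method that dequantizes back to the original dtype):

.. code-block:: python

    import torch
    from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

    weight = torch.randn(512, 512, dtype=torch.bfloat16, device="cuda")
    nf4_weight = to_nf4(weight)  # 4-bit NF4 data plus (double-quantized) per-block scalers

    x = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda")
    out = linear_nf4(x, nf4_weight)  # F.linear against the de-quantized weight
    print(out.shape, out.dtype)      # torch.Size([2, 512]) torch.bfloat16

    # Quantization is lossy: the round-trip error is small but nonzero.
    print((nf4_weight.get_original_weight() - weight).abs().max())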