QLoRA tutorial #693
@@ -0,0 +1,275 @@
.. _qlora_finetune_label:

=============================
 Finetuning Llama2 with QLoRA
=============================

In this tutorial, we'll learn about `QLoRA <https://arxiv.org/abs/2305.14314>`_, an enhancement on top of
`LoRA <https://arxiv.org/abs/2106.09685>`_ that maintains frozen model parameters in 4-bit quantized precision, thereby reducing memory usage. We'll
walk through how QLoRA can be utilized within TorchTune to finetune a Llama2-7b model in < 10 GB of memory.
It is highly recommended to first develop an understanding of :ref:`LoRA finetuning in TorchTune<lora_finetune_label>`.

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn

      * How QLoRA saves memory over LoRA finetuning
      * An overview of QLoRA in TorchTune
      * How to run a QLoRA finetune in TorchTune

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

      * Be familiar with :ref:`TorchTune<overview_label>`
      * Make sure to :ref:`install TorchTune<install_label>`
      * Make sure you have downloaded the :ref:`Llama2-7B model weights<download_llama_label>`
      * Be familiar with :ref:`LoRA in torchtune<lora_finetune_label>`

What is QLoRA?
---------------

`QLoRA <https://arxiv.org/abs/2305.14314>`_ builds on top of `LoRA <https://arxiv.org/abs/2106.09685>`_ to enable further
memory savings. In LoRA, model parameters can be thought of as existing in two partitions: adapters, which are
low-rank matrices added to different layers of a neural network, and base model parameters, which are parameters that are part of
the original model. In vanilla LoRA style training, both these parameters are held in the same precision (typically fp32 or bf16), and
Suggested change:
- the original model. In vanilla LoRA style training, both these parameters are held in the same precision (typically fp32 or bf16), and
+ the original model. In vanilla LoRA-style training, both these parameters are held in the same precision (typically fp32 or bf16), and
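To make the two parameter partitions described in the excerpt concrete, here is a minimal sketch of constructing a LoRA and a QLoRA Llama2-7b model; the builder and argument names are assumptions based on torchtune's LoRA/QLoRA components, so treat them as illustrative rather than authoritative.

```python
# Rough sketch: both models add trainable low-rank adapters to the attention
# projections, but the QLoRA variant stores the frozen base weights in 4-bit NF4.
from torchtune.models.llama2 import lora_llama2_7b, qlora_llama2_7b

# LoRA: frozen base weights and trainable adapters both kept in bf16/fp32
lora_model = lora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

# QLoRA: identical adapters, but frozen base weights quantized to 4-bit NF4
qlora_model = qlora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

# Roughly equivalent: the LoRA builder with base-weight quantization enabled
also_qlora = lora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"], quantize_base=True)
```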
One thing that'd be nice: basically take the diagram in the LoRA tutorial demonstrating full finetune -> LoRA and add one more for LoRA -> QLoRA. (The diagrams take a bit of time so I feel this is more of a nice-to-have at this point.) But if you're interested let me know and I can dig up the original
Definitely interested! Will punt this out to a follow up though.
nit: I feel like you're linking to the paper too many times, usually I just do it once upon introduction
nah, I think this is fine. you don't know which paragraph someone will start reading at if they're skimming or jumping around
I think you should link every other reference to the paper to keep people on their toes.
🤯
Is it worth giving some more details on either of these two optimizations, or do you think it's too in the weeds?
IMO it's too in the weeds and not worth it to just re-explain if it's in the paper. Can add a line directing folks to the paper, but IMO it's already clear enough to read the paper for this sort of detail.
nit: I think we want TorchAO -> torchao
nit. Also might remove checkpoint conversion since idk that it's really an ecosystem thing in the same way the other items are
Suggested change:
- TorchTune and beyond (i.e. checkpoint conversion, post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done
+ TorchTune and beyond (e.g. checkpoint conversion, post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done
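For reference, merging adapter weights back into the base model amounts to folding the low-rank update into the original weight matrix. The snippet below is a generic illustration with made-up shapes, not the torchtune checkpointer implementation.

```python
import torch

# Hypothetical shapes for a single linear layer
out_dim, in_dim, rank, alpha = 4096, 4096, 8, 16

base_weight = torch.randn(out_dim, in_dim)  # frozen base weight (dequantize first if stored in NF4)
lora_a = torch.randn(rank, in_dim)          # LoRA "A" matrix
lora_b = torch.zeros(out_dim, rank)         # LoRA "B" matrix (zero-initialized)

# Fold the low-rank update into the base weight so inference needs no adapter modules
merged_weight = base_weight + (alpha / rank) * (lora_b @ lora_a)
```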
nit: this sentence could be a little unclear, seemingly implying that the way we avoid peaking memory is by maintaining an entire bf16/fp32 copy on GPU.
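The usual picture is that each layer's NF4 weight is dequantized to bf16 on the fly inside that layer's forward pass, so only one layer's high-precision copy exists at a time. A rough sketch follows, with `dequantize` standing in for whatever NF4 dequantization routine is used (e.g. torchao's), and assuming bf16 activations and adapters.

```python
import torch
import torch.nn.functional as F

def qlora_linear_forward(x, nf4_weight, lora_a, lora_b, alpha, rank, dequantize):
    # Dequantize only this layer's frozen NF4 weight for the matmul; the bf16
    # copy is temporary, so all base weights are never held in bf16 at once.
    w = dequantize(nf4_weight).to(torch.bfloat16)
    out = F.linear(x, w)
    # The trainable low-rank adapter path runs alongside in higher precision.
    out = out + (alpha / rank) * F.linear(F.linear(x, lora_a), lora_b)
    return out
```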
Suggested change:
- tune run lora_finetune_single_device --config llama2/7B_lora_single_device.yaml compile=True
+ tune run lora_finetune_single_device --config llama2/7B_qlora_single_device.yaml compile=True
oh shoot, .yaml is not correct; we need to remove that as well.
You can remove this now that you've added the legend
Suggested change:
- A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below (purple being the QLoRA loss curve).
+ A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below.
Btw you can also add explicit labels to the two lines in the figure, e.g. iconic-pasma-57 -> LoRA and azure-bird-56 -> QLoRA
Sorry to be annoying, but can you filter the x-axis to [0, 1000] or something in wandb and reupload? Otherwise it looks weird that one is running longer.
nit: you should find and replace TorchTune -> torchtune as we have done it everywhere else since after you first opened this PR