
Commit c69fba1

Llama3-70B LoRA multi GPU (#802)

1 parent a9180b5 commit c69fba1

File tree

6 files changed: +194 -8 lines changed

README.md

Lines changed: 14 additions & 8 deletions

@@ -6,7 +6,7 @@
-torchtune now officially supports Meta Llama3! Check out our recipes for Llama3-8B with LoRA, QLoRA and Full fine-tune in the [Llama3](#llama3) section! 🚀 🦙
+torchtune now officially supports Meta Llama3! Check out our recipes for Llama3-8B with LoRA, QLoRA and Full fine-tune in the [Llama3](#llama3) section! We also support 70B fine-tuning with LoRA! 🚀 🦙
 
 # torchtune

@@ -44,7 +44,7 @@ torchtune currently supports the following models.
 
 | Model | Sizes |
 |-----------------------------------------------|-----------|
-| [Llama3](https://llama.meta.com/llama3) | 8B [[models](torchtune/models/llama3/_model_builders.py), [configs](recipes/configs/llama3/)] |
+| [Llama3](https://llama.meta.com/llama3) | 8B, 70B [[models](torchtune/models/llama3/_model_builders.py), [configs](recipes/configs/llama3/)] |
 | [Llama2](https://llama.meta.com/llama2/) | 7B, 13B, 70B [[models](torchtune/models/llama2/_model_builders.py), [configs](recipes/configs/llama2/)] |
 | [Mistral](https://huggingface.co/mistralai) | 7B [[model](torchtune/models/mistral/_model_builders.py), [configs](recipes/configs/mistral/)] |
 | [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) | 2B [[model](torchtune/models/gemma/_model_builders.py), [configs](recipes/configs/gemma/)] |

@@ -86,35 +86,41 @@ This table captures the minimum memory requirements for our different recipes us
 
 ## Llama3
 
-torchtune supports fine-tuning for the Llama3 8B models with support for 70B on its way. We currently support LoRA, QLoRA and Full-finetune on a single GPU as well as LoRA and Full fine-tune on multiple devices. For all the details, take a look at our [tutorial](https://pytorch.org/torchtune/stable/tutorials/llama3.html).
+torchtune supports fine-tuning for the Llama3 8B and 70B models. We currently support LoRA, QLoRA and full fine-tuning on a single GPU as well as LoRA and full fine-tuning on multiple devices for the 8B model, and LoRA on multiple devices for the 70B model. For all the details, take a look at our [tutorial](https://pytorch.org/torchtune/stable/tutorials/llama3.html).
 
-In our initial experiments, QLoRA has a peak allocated memory of ``~9GB`` while LoRA on a single GPU has a peak allocated memory of ``~19GB``. To get started, you can use our default configs to kick off training.
+In our initial experiments for Llama3-8B, QLoRA has a peak allocated memory of ``~9GB`` while LoRA on a single GPU has a peak allocated memory of ``~19GB``. To get started, you can use our default configs to kick off training.
 
-- LoRA on a single GPU.
+- 8B LoRA on a single GPU.
 
 ```bash
 tune run lora_finetune_single_device --config llama3/8B_lora_single_device
 ```
 
-- QLoRA on a single GPU
+- 8B QLoRA on a single GPU
 
 ```bash
 tune run lora_finetune_single_device --config llama3/8B_qlora_single_device
 ```
 
-- LoRA on 2 GPUs
+- 8B LoRA on 2 GPUs
 
 ```bash
 tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora
 ```
 
-- Full fine-tune on 2 GPUs
+- 8B Full fine-tune on 2 GPUs
 
 ```bash
 tune run --nproc_per_node 2 full_finetune_distributed --config llama3/8B_full
 ```
 
+- 70B LoRA fine-tune on 8 GPUs
+
+```bash
+tune run --nproc_per_node 8 lora_finetune_distributed --config recipes/configs/llama3/70B_lora.yaml
+```

docs/source/api_ref_models.rst

Lines changed: 2 additions & 0 deletions

@@ -19,8 +19,10 @@ All models from the `Llama3 family <https://llama.meta.com/llama3/>`_.
     :nosignatures:
 
     llama3.llama3_8b
+    llama3.llama3_70b
     llama3.lora_llama3_8b
     llama3.qlora_llama3_8b
+    llama3.lora_llama3_70b
 
 
 llama2
recipes/configs/llama3/70B_lora.yaml

Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
+# Config for multi-device LoRA in lora_finetune_distributed.py
+# using a Llama3 70B model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+#   tune download meta-llama/Meta-Llama-3-70b --hf-token <TOKEN> --output-dir /tmp/Meta-Llama-3-70b --ignore-patterns "original/consolidated*"
+#
+# This config needs 8 GPUs to run:
+#   tune run --nproc_per_node 8 lora_finetune_distributed --config recipes/configs/llama3/70B_lora.yaml
+#
+
+# Model Arguments
+model:
+  _component_: torchtune.models.llama3.lora_llama3_70b
+  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
+  apply_lora_to_mlp: False
+  apply_lora_to_output: False
+  lora_rank: 16
+  lora_alpha: 32
+
+tokenizer:
+  _component_: torchtune.models.llama3.llama3_tokenizer
+  path: /tmp/Meta-Llama-3-70b/original/tokenizer.model
+
+checkpointer:
+  _component_: torchtune.utils.FullModelHFCheckpointer
+  checkpoint_dir: /tmp/Meta-Llama-3-70b
+  checkpoint_files: [
+    model-00001-of-00030.safetensors,
+    model-00002-of-00030.safetensors,
+    model-00003-of-00030.safetensors,
+    model-00004-of-00030.safetensors,
+    model-00005-of-00030.safetensors,
+    model-00006-of-00030.safetensors,
+    model-00007-of-00030.safetensors,
+    model-00008-of-00030.safetensors,
+    model-00009-of-00030.safetensors,
+    model-00010-of-00030.safetensors,
+    model-00011-of-00030.safetensors,
+    model-00012-of-00030.safetensors,
+    model-00013-of-00030.safetensors,
+    model-00014-of-00030.safetensors,
+    model-00015-of-00030.safetensors,
+    model-00016-of-00030.safetensors,
+    model-00017-of-00030.safetensors,
+    model-00018-of-00030.safetensors,
+    model-00019-of-00030.safetensors,
+    model-00020-of-00030.safetensors,
+    model-00021-of-00030.safetensors,
+    model-00022-of-00030.safetensors,
+    model-00023-of-00030.safetensors,
+    model-00024-of-00030.safetensors,
+    model-00025-of-00030.safetensors,
+    model-00026-of-00030.safetensors,
+    model-00027-of-00030.safetensors,
+    model-00028-of-00030.safetensors,
+    model-00029-of-00030.safetensors,
+    model-00030-of-00030.safetensors,
+  ]
+  recipe_checkpoint: null
+  output_dir: /tmp/Meta-Llama-3-70b
+  model_type: LLAMA3
+resume_from_checkpoint: False
+
+# Dataset and Sampler
+dataset:
+  _component_: torchtune.datasets.alpaca_dataset
+  train_on_input: True
+seed: null
+shuffle: True
+batch_size: 2
+
+# Optimizer and Scheduler
+optimizer:
+  _component_: torch.optim.AdamW
+  weight_decay: 0.01
+  lr: 3e-4
+lr_scheduler:
+  _component_: torchtune.modules.get_cosine_schedule_with_warmup
+  num_warmup_steps: 100
+
+loss:
+  _component_: torch.nn.CrossEntropyLoss
+
+# Training
+epochs: 1
+max_steps_per_epoch: null
+gradient_accumulation_steps: 1
+
+# Logging
+output_dir: /tmp/lora_finetune_output
+metric_logger:
+  _component_: torchtune.utils.metric_logging.DiskLogger
+  log_dir: ${output_dir}
+log_every_n_steps: null
+
+# Environment
+device: cuda
+dtype: bf16
+enable_activation_checkpointing: True
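
For orientation, the `model` section of this config maps one-to-one onto the keyword arguments of the new `lora_llama3_70b` builder. Below is a minimal sketch, not part of the commit, built on the meta device so that no 70B-sized tensors are actually allocated; in the real recipe the component is instantiated from the config and weights are loaded from the checkpoint files listed above.

```python
# Sketch only: mirrors the `model` section of 70B_lora.yaml.
import torch

from torchtune.models.llama3 import lora_llama3_70b

with torch.device("meta"):  # avoid materializing real 70B-sized weights
    model = lora_llama3_70b(
        lora_attn_modules=["q_proj", "k_proj", "v_proj"],
        apply_lora_to_mlp=False,
        apply_lora_to_output=False,
        lora_rank=16,
        lora_alpha=32,
    )
```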

torchtune/_recipe_registry.py

Lines changed: 1 addition & 0 deletions

@@ -107,6 +107,7 @@ class Recipe:
         Config(name="llama2/7B_lora", file_path="llama2/7B_lora.yaml"),
         Config(name="llama2/13B_lora", file_path="llama2/13B_lora.yaml"),
         Config(name="llama2/70B_lora", file_path="llama2/70B_lora.yaml"),
+        Config(name="llama3/70B_lora", file_path="llama3/70B_lora.yaml"),
         Config(name="llama3/8B_lora", file_path="llama3/8B_lora.yaml"),
         Config(name="mistral/7B_lora", file_path="mistral/7B_lora.yaml"),
     ],
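
Registering the config here is what allows the shorthand name `llama3/70B_lora` to be resolved to a YAML file under recipes/configs/, the same way the 8B configs are referenced by name in the README above. A small illustrative sketch, using only the `Config` dataclass edited in this file; the path join is shown for illustration, not taken from the commit:

```python
# Illustrative sketch: a registered Config is just a name plus a YAML path
# relative to recipes/configs/.
from torchtune._recipe_registry import Config

cfg = Config(name="llama3/70B_lora", file_path="llama3/70B_lora.yaml")
print(f"recipes/configs/{cfg.file_path}")  # recipes/configs/llama3/70B_lora.yaml
```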

torchtune/models/llama3/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -7,8 +7,10 @@
 from ._component_builders import llama3, lora_llama3
 
 from ._model_builders import (  # noqa
+    llama3_70b,
     llama3_8b,
     llama3_tokenizer,
+    lora_llama3_70b,
     lora_llama3_8b,
     qlora_llama3_8b,
 )

@@ -17,9 +19,11 @@
 __all__ = [
     "llama3",
     "llama3_8b",
+    "llama3_70b",
     "llama3_tokenizer",
     "lora_llama3",
     "lora_llama3_8b",
+    "lora_llama3_70b",
     "qlora_llama3_8b",
     "scale_hidden_dim_for_mlp",
 ]

torchtune/models/llama3/_model_builders.py

Lines changed: 73 additions & 0 deletions

@@ -44,6 +44,27 @@ def llama3_8b() -> TransformerDecoder:
     )
 
 
+def llama3_70b() -> TransformerDecoder:
+    """
+    Builder for creating a Llama3 model initialized w/ the default 70B parameter values.
+
+    Returns:
+        TransformerDecoder: Instantiation of Llama3 70B model
+    """
+    return llama3(
+        vocab_size=128_256,
+        num_layers=80,
+        num_heads=64,
+        num_kv_heads=8,
+        embed_dim=8192,
+        max_seq_len=4096,
+        intermediate_dim=28672,
+        attn_dropout=0.0,
+        norm_eps=1e-5,
+        rope_base=500000.0,
+    )
+
+
 def llama3_tokenizer(path: str) -> TikTokenTokenizer:
     tiktoken = TikTokenTokenizer(path)
     tiktoken.pad_id = 0

@@ -100,6 +121,58 @@ def lora_llama3_8b(
         quantize_base=quantize_base,
     )
 
+
+def lora_llama3_70b(
+    lora_attn_modules: List[LORA_ATTN_MODULES],
+    apply_lora_to_mlp: bool = False,
+    apply_lora_to_output: bool = False,
+    lora_rank: int = 8,
+    lora_alpha: float = 16,
+    quantize_base: bool = False,
+) -> TransformerDecoder:
+    """
+    Builder for creating a Llama3 70B model with LoRA enabled.
+
+    The Llama3 defaults are the same as in :func:`~torchtune.models.llama3.llama3_70b`,
+    while LoRA default params are based on
+    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.
+
+    Args:
+        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
+            LoRA should be applied to in each self-attention block. Options are
+            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
+        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
+            Default: False
+        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
+            Default: False
+        lora_rank (int): rank of each low-rank approximation
+        lora_alpha (float): scaling factor for the low-rank approximation
+        quantize_base (bool): Whether to quantize base model weights
+
+    Returns:
+        TransformerDecoder: Instantiation of Llama3 70B model with LoRA applied
+    """
+    return lora_llama3(
+        lora_attn_modules=lora_attn_modules,
+        apply_lora_to_mlp=apply_lora_to_mlp,
+        apply_lora_to_output=apply_lora_to_output,
+        vocab_size=128_256,
+        num_layers=80,
+        num_heads=64,
+        num_kv_heads=8,
+        embed_dim=8192,
+        max_seq_len=4096,
+        intermediate_dim=28672,
+        attn_dropout=0.0,
+        norm_eps=1e-5,
+        rope_base=500000.0,
+        lora_rank=lora_rank,
+        lora_alpha=lora_alpha,
+        lora_dropout=0.05,
+        quantize_base=quantize_base,
+    )
+
+
 qlora_llama3_8b = partial(lora_llama3_8b, quantize_base=True)
 
 qlora_llama3_8b.__doc__ = """
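
As a quick sanity check on the new builders, here is a sketch (not from the commit) that instantiates both variants on the meta device and compares adapter-parameter count to total parameter count. It assumes torchtune's LoRA adapter parameters carry `lora_` in their names, as with its `LoRALinear` modules; treat the naming filter as an assumption rather than a guaranteed API.

```python
# Sketch under the assumptions stated above: inspect the new 70B builders
# without allocating real weights.
import torch

from torchtune.models.llama3 import llama3_70b, lora_llama3_70b

with torch.device("meta"):
    base = llama3_70b()
    lora = lora_llama3_70b(lora_attn_modules=["q_proj", "k_proj", "v_proj"])

total = sum(p.numel() for p in lora.parameters())
# Assumption: adapter weights are named `...lora_a...` / `...lora_b...`.
adapter = sum(p.numel() for n, p in lora.named_parameters() if "lora_" in n)
print(f"adapter params: {adapter:,} of {total:,} total")
```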
