
Commit efbab6f

andrewor14 authored and maximegmd committed
Add QAT support for distributed finetuning (meta-pytorch#980)
1 parent: 90fe7ec

9 files changed: +1000 −6 lines changed

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
# Config for multi-device QAT finetuning in qat_distributed.py
# using a Llama2 7B model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
#   tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama2/7B_qat_full
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama2/7B_qat_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>


# Tokenizer
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/Llama-2-7b-hf/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama2.llama2_7b

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-2-7b-hf
  checkpoint_files: [
    pytorch_model-00001-of-00002.bin,
    pytorch_model-00002-of-00002.bin
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Llama-2-7b-hf
  model_type: LLAMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3
optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# QAT arguments
quantizer:
  _component_: torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
memory_efficient_fsdp_wrap: False

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama2-finetune
log_every_n_steps: 1
log_peak_memory_stats: False
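
The `quantizer` block above is what distinguishes this config from the plain full-finetune one: it selects an int8 dynamic-activation / int4 weight QAT scheme with a weight group size of 256. A minimal sketch of the flow that setting drives, assuming the quantizer follows torchao's prepare/convert QAT API (the recipe internals are not part of this file):

# Hypothetical sketch, not code from this PR: how a `quantizer` entry like the
# one above is typically used in a QAT flow (prepare before training, convert after).
from torchtune.models.llama2 import llama2_7b
from torchtune.utils.quantization import Int8DynActInt4WeightQATQuantizer

model = llama2_7b()

# groupsize matches the config above; prepare() is assumed to swap nn.Linear
# layers for fake-quantized versions (int8 dynamic activations, int4 grouped weights).
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
model = quantizer.prepare(model)

# ... fine-tune `model` as usual with the optimizer and loss defined above ...

# After training, convert the fake-quantized modules into actually quantized ones.
model = quantizer.convert(model)

In the distributed recipe the prepared model would then be sharded with FSDP and trained like an ordinary full finetune.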
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# Config for multi-device QAT finetuning in qat_distributed.py
# using a Llama3 8B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
#   tune run --nproc_per_node 4 qat_distributed --config llama3/8B_qat_full
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune run --nproc_per_node 4 qat_distributed --config llama3/8B_qat_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
  checkpoint_files: [
    consolidated.00.pth
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3-8B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3

# QAT arguments
quantizer:
  _component_: torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  foreach: False

loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
memory_efficient_fsdp_wrap: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama3-finetune
log_every_n_steps: 1
log_peak_memory_stats: False
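
Both configs rely on torchtune's `_component_` convention: the recipe materializes each block into a Python object at startup. A small illustration of that mapping, assuming the standard OmegaConf + `torchtune.config.instantiate` pattern used by other torchtune recipes (illustrative only, not code from this commit):

# Illustrative only: how the `_component_` blocks above become objects,
# assuming torchtune's usual OmegaConf + config.instantiate pattern.
from omegaconf import OmegaConf
from torchtune import config

cfg = OmegaConf.load("8B_qat_full.yaml")  # hypothetical local copy of the config above

model = config.instantiate(cfg.model)          # torchtune.models.llama3.llama3_8b()
quantizer = config.instantiate(cfg.quantizer)  # Int8DynActInt4WeightQATQuantizer(groupsize=256)
optimizer = config.instantiate(cfg.optimizer, model.parameters())  # AdamW, lr=2e-5, foreach=False
loss_fn = config.instantiate(cfg.loss)         # torch.nn.CrossEntropyLoss()

Command-line overrides such as `checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>` in the launch examples above replace the corresponding config keys before instantiation.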
