
Commit b1e0666

Added multinode capabilities for distributed training (#19)
* multinode fsdp finetuning
* full dpo
1 parent 0fb4536 commit b1e0666

5 files changed: +918 additions, -12 deletions
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
# Config for multi-device full DPO alignment in full_dpo_distributed.py
# using a Llama2 7B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --ignore-patterns "*.safetensors" --hf-token <HF_TOKEN>
#
# To launch on 2 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed --config llama2/7B_lora_dpo
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed --config llama2/7B_lora_dpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single device LoRA DPO alignment please use 7B_lora_dpo_single_device.yaml

# Model Arguments
model:
  _component_: torchtune.models.sarvam1.sarvam1

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /projects/data/rahul_sarvam_ai/nemo_models/sarvam-1-pt/tokenizer.model
  max_seq_len: 8192
output_dir: /projects/data/rahul_sarvam_ai/torchtune_models/dpo_test

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /projects/data/rahul_sarvam_ai/models/sarvam-1-torchtune-sft
  checkpoint_files:
    [model-00001-of-00002.safetensors, model-00002-of-00002.safetensors]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
  safe_serialization: True
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.preference_dataset
  source: allenai/llama-3.1-tulu-3-70b-preference-mixture
  split: train
seed: null
shuffle: True
batch_size: 1

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 1e-5
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.1
  label_smoothing: 0

# Training
epochs: 1
max_steps_per_epoch: 1_000_000
gradient_accumulation_steps: 8 # Use to increase virtual batch size
compile: False # pytorch compile, set to true for better perf/memory

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  # the W&B project to log to
  project: torchtune
log_every_n_steps: 10
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
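
The launch command in the comments above covers a single node (--nnodes 1). Since this commit adds multinode support, a two-node launch would look roughly like the sketch below: a minimal example assuming tune run forwards torchrun's rendezvous flags and that the same command is executed once on every node. The head-node address, port, GPU count, and the <CONFIG> placeholder are illustrative assumptions, not values taken from this commit.

  # run once on each participating node (2 nodes x 8 GPUs assumed)
  # MASTER_ADDR is a placeholder for the head node's reachable hostname or IP
  tune run --nnodes 2 --nproc_per_node 8 \
      --rdzv_backend c10d --rdzv_endpoint $MASTER_ADDR:29500 \
      full_dpo_distributed --config <CONFIG>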

recipes/configs/sarvam1/full_finetune.yaml

Lines changed: 12 additions & 11 deletions
@@ -23,7 +23,7 @@ model:
 
 # Tokenizer
 tokenizer:
-  _component_: torchtune.models.sarvam1.sarvam1_tokenizer
+  _component_: torchtune.models.llama2.llama2_tokenizer
   path: /projects/data/rahul_sarvam_ai/nemo_models/sarvam-1-pt/tokenizer.model
   max_seq_len: 8192
 
@@ -35,10 +35,10 @@ dataset:
   conversation_style: openai
   conversation_column: messages
   source: json
-  packs_cache_path: /projects/data/mohit_sarvam_ai/torchtune/data/microsoft_agent
+  packs_cache_path: /projects/data/rahul_sarvam_ai/torchtune_models/cache/ft_train_cache_phase_1
   data_files: [
-    # /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/sample.json,
-    /projects/data/mohit_sarvam_ai/torchtune/data/microsoft-agent.jsonl
+    /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/sample.json,
+    # /projects/data/mohit_sarvam_ai/torchtune/data/microsoft-agent.jsonl
     # /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_0.json,
     # /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_1.json,
     # /projects/data/mohit_sarvam_ai/torchtune/data/ft_train/data_2.json,
@@ -63,13 +63,14 @@ dataset:
   split: train
 seed: null
 shuffle: True
-output_dir: /projects/data/mohit_sarvam_ai/torchtune/output-sft/sarvam1-sft-microsft-agent
+output_dir: /projects/data/rahul_sarvam_ai/torchtune_models/sft-phase-1
 
 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
   checkpoint_dir: /projects/data/rahul_sarvam_ai/nemo_models/sarvam-1-pt
   checkpoint_files: [
-    pytorch_model.bin
+    # pytorch_model.bin
+    ft-model-00001-of-00001.safetensors
   ]
   recipe_checkpoint: null
   output_dir: ${output_dir}
@@ -78,20 +79,20 @@ save_interval: 2000
 resume_from_checkpoint: False
 
 # Fine-tuning arguments
-batch_size: 1
+batch_size: 8
 epochs: 1
 optimizer:
   _component_: torch.optim.AdamW
   fused: True
-  lr: 7e-6
+  lr: 3e-5
   weight_decay: 0.01
   betas: [0.9, 0.98]
 lr_scheduler:
   _component_: torchtune.training.lr_schedulers.get_linear_schedule_with_warmup
-  min_lr: 1e6
+  min_lr: 1e-8
   max_lr: ${optimizer.lr}
-  num_warmup_steps: 500
-  constant_steps: 500
+  num_warmup_steps: 1000
+  constant_steps: 0
 
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
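
For context on the scheduler change: min_lr: 1e6 was orders of magnitude above the 3e-5 peak, so correcting it to 1e-8 gives the decay a sensible floor, and, assuming the scheduler behaves like a standard linear warmup-then-decay schedule, the learning rate now ramps to 3e-5 over the first 1000 optimizer steps (roughly 1.5e-5 at step 500) before decaying toward 1e-8. As the config comments note, these values can also be overridden at launch instead of editing the YAML; a rough sketch, assuming the recipe is full_finetune_distributed, that the config name resolves to sarvam1/full_finetune from the path above, and an illustrative single-node GPU count:

  tune run --nnodes 1 --nproc_per_node 8 full_finetune_distributed --config sarvam1/full_finetune \
      optimizer.lr=3e-5 lr_scheduler.num_warmup_steps=1000 lr_scheduler.min_lr=1e-8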

recipes/configs/sarvam1/lora_dpo.yaml

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ loss:
 
 # Training
 epochs: 1
-max_steps_per_epoch: 1000_000
+max_steps_per_epoch: 1_000_000
 gradient_accumulation_steps: 4 # Use to increase virtual batch size
 compile: False # pytorch compile, set to true for better perf/memory
 
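
On the "virtual batch size" comment: the effective batch per optimizer step is per-device batch size x gradient_accumulation_steps x number of data-parallel ranks. Taking a per-device batch size of 1 (as in the full DPO config above; the value for this LoRA config is not shown in this hunk) and gradient_accumulation_steps: 4, a two-node run with 8 GPUs per node (an assumed topology, not stated in the commit) would process 1 x 4 x 16 = 64 preference pairs per optimizer step.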
