Commit 08b0f91

command-line override to forcibly untie embeddings for llama3.2 models
1 parent 5b3bac2 commit 08b0f91

File tree: 3 files changed (+63, -7 lines)

.github/workflows/ci_gpu.yml
test_llama3.cu
train_llama3.py
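
For context (not part of the diff): Llama 3.2 checkpoints ship with tied embeddings, meaning the token-embedding matrix is reused as the LM-head projection. The new --untie switch lets the converted model train those as two independent parameters instead. A minimal PyTorch sketch of the distinction, using toy sizes (Llama-3.2-1B itself uses a 128256-token vocabulary and width 2048):

    # Illustration only, not llm.c code.
    import torch.nn as nn

    vocab_size, n_embd = 1000, 64
    wte = nn.Embedding(vocab_size, n_embd)               # token embedding table
    lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # projection to logits

    # Tied: the LM head reuses the embedding matrix, so there is one parameter tensor.
    lm_head.weight = wte.weight

    # Untied: the LM head starts as a copy and can drift away during training.
    lm_head.weight = nn.Parameter(wte.weight.detach().clone())

Untying adds a second vocab_size x n_embd matrix to the parameter count and the training dynamics differ slightly, which is why the test below carries separate expected-loss tables for the two modes.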

.github/workflows/ci_gpu.yml

Lines changed: 39 additions & 4 deletions
@@ -119,6 +119,7 @@ jobs:
        run: ./test_gpt2fp32cu

  build-and-test-llama3:
+    name: Build and test LLama3.2 1B
    runs-on: ubicloud-gpu-standard-1-latest
    env:
      HF_TOKEN: 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
@@ -154,18 +155,52 @@ jobs:
      - name: Build BF16 precision
        run: PRECISION=BF16 make train_llama3cu test_llama3cu

-      - name: Run default
+      - name: Run default (BF16)
        run: ./test_llama3cu

-      - name: Run no recompute GeLU
+      - name: Run no recompute GeLU (BF16)
        run: ./test_llama3cu -r 0

-      - name: Run no master weights
+      - name: Run no master weights (BF16)
        run: ./test_llama3cu -w 0

-      - name: Run recompute LN
+      - name: Run recompute LN (BF16)
        run: ./test_llama3cu -r 2

+  build-and-test-llama3-untied:
+    name: Build and test LLama3.2 1B with untie weights
+    runs-on: ubicloud-gpu-standard-1-latest
+    env:
+      HF_TOKEN: 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
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+      - run: echo "::add-mask::$HF_TOKEN"
+
+      - name: Install OpenMP
+        run: sudo apt-get update && sudo apt-get install -y libomp-dev
+
+      - name: Install dependencies
+        run: pip install -r requirements.txt
+
+      - name: Run preprocessing
+        run: python dev/data/tinyshakespeare.py --model_desc llama-3
+
+      - name: Train model
+        run: python train_llama3.py --write_tensors 1 --dtype float32 --untie 1
+
+      - name: Build FP32 precision
+        run: PRECISION=FP32 make test_llama3cu
+
+      - name: Run default
+        run: ./test_llama3cu
+
+      - name: Build BF16 precision
+        run: PRECISION=BF16 make train_llama3cu test_llama3cu
+
+      - name: Run default
+        run: ./test_llama3cu
+
  unit-tests-gpu:
    runs-on: ubicloud-gpu-standard-1-latest
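
If you want to approximate the new build-and-test-llama3-untied job locally, the steps boil down to the commands below (a rough sketch, not part of the commit; it assumes a CUDA GPU, the llm.c repository root as the working directory, and HF access to the gated meta-llama checkpoint):

    # Mirrors the workflow steps above; run from the llm.c repo root.
    import subprocess

    commands = [
        "pip install -r requirements.txt",
        "python dev/data/tinyshakespeare.py --model_desc llama-3",
        "python train_llama3.py --write_tensors 1 --dtype float32 --untie 1",
        "PRECISION=FP32 make test_llama3cu",
        "./test_llama3cu",
        "PRECISION=BF16 make train_llama3cu test_llama3cu",
        "./test_llama3cu",
    ]
    for cmd in commands:
        subprocess.run(cmd, shell=True, check=True)  # stop at the first failure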

test_llama3.cu

Lines changed: 15 additions & 1 deletion
@@ -301,7 +301,7 @@ int main(int argc, char *argv[]) {
    }

    // expected losses are as follows, from Python (without CPUOffload)
-    float expected_losses[10] = {
+    float expected_losses_untied[10] = {
        4.849688f,
        3.070303f,
        1.711614f,
@@ -313,6 +313,20 @@ int main(int argc, char *argv[]) {
        0.355562f,
        0.334824f
    };
+    float expected_losses_tied[10] = {
+        4.849688f,
+        3.072875f,
+        1.714160f,
+        1.060224f,
+        0.596433f,
+        0.431257f,
+        0.373330f,
+        0.361544f,
+        0.357920f,
+        0.336123f
+    };
+
+    float* expected_losses = model.config.tied_weights ? expected_losses_tied : expected_losses_untied;

    // compare
    for (int i = 0; i < 10; i++) {
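
The test now picks its reference table from the tied_weights flag stored in the model config, so the tied default and the new untied CI variant are covered by one binary. To check which mode a given HuggingFace checkpoint starts from, one option is the snippet below (a sketch, not part of this commit; the meta-llama repo is gated, so it needs an authenticated HF login):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
    print(model.config.tie_word_embeddings)    # True for Llama-3.2-1B
    emb = model.get_input_embeddings().weight
    head = model.get_output_embeddings().weight
    print(emb.data_ptr() == head.data_ptr())   # tied weights share one storage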

train_llama3.py

Lines changed: 9 additions & 2 deletions
@@ -435,10 +435,16 @@ def unpermute(w, n_heads, dim1, dim2):
        return checkpoint

    @classmethod
-    def from_pretrained_llama3_hf(cls, model_id):
+    def from_pretrained_llama3_hf(cls, model_id, untie):
        """Loads pretrained LLaMA model weights from HuggingFace"""
        from transformers import AutoModelForCausalLM, AutoTokenizer
        model_args = MODEL_DICT[model_id]
+        if untie:
+            if not model_args.tied_embeddings:
+                print("Model embeddings are not tied, --untie has no effect.")
+            else:
+                print("Untying token embeddings and LM head.")
+                model_args.tied_embeddings = False

        model = AutoModelForCausalLM.from_pretrained(model_id)
        checkpoint = LLaMA.adapt_llama_state_dict_keys_hf(model.state_dict(), model_args)
@@ -1026,6 +1032,7 @@ def print0(*args, **kwargs):
    parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on")
    parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints")
    parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-1B", help="chose the llama model")
+    parser.add_argument("--untie", type=int, default=False, help="Untie token embeddings and LM-head, even if they are tied in the checkpoint.")
    # token layout for each step of the optimization
    parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions")
    parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
@@ -1131,7 +1138,7 @@ def print0(*args, **kwargs):

    # init the model
    if args.use_hf:
-        model = LLaMA.from_pretrained_llama3_hf(args.model)
+        model = LLaMA.from_pretrained_llama3_hf(args.model, args.untie)
    else: # use Meta's checkpoint
        assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist"
        assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"
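
Note that from_pretrained_llama3_hf only flips model_args.tied_embeddings; since the HuggingFace checkpoint stores a single shared matrix, the now-independent LM head presumably has to be seeded with a copy of the token embedding somewhere downstream (e.g. in adapt_llama_state_dict_keys_hf or the model constructor). A hedged sketch of that step, with hypothetical key names:

    import torch

    def untie_lm_head(state_dict, tied_embeddings):
        # Hypothetical keys for illustration; not the commit's actual code.
        if not tied_embeddings and "lm_head.weight" not in state_dict:
            # Seed the independent head with a copy of the embedding matrix.
            state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"].clone()
        return state_dict

    # Toy usage with a dummy tensor:
    sd = {"model.embed_tokens.weight": torch.zeros(8, 4)}
    sd = untie_lm_head(sd, tied_embeddings=False)
    assert sd["lm_head.weight"] is not sd["model.embed_tokens.weight"]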
