
Commit 2a25de8
enable tied embeddings
1 parent 49cef1d commit 2a25de8

File tree (3 files changed: +43 -13 lines)
  test_llama3.cu
  train_llama3.cu
  train_llama3.py

test_llama3.cu

Lines changed: 3 additions & 0 deletions
@@ -85,6 +85,9 @@ float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size
         *(ptrs[i]) = params_memory_iterator;
         params_memory_iterator += param_sizes[i];
     }
+    if(param_sizes[1] == 0) {
+        params->wlmhead = nullptr;
+    }
     return params_memory;
 }
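Note on the fix above: the pointer-bump loop still walks every slot, so a zero-sized lm_head entry would otherwise leave params->wlmhead pointing at the start of the next tensor rather than at nothing. Below is a minimal standalone sketch of that aliasing; the three-slot layout and variable names are made up for illustration, not the repo's actual structs:

    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        // hypothetical 3-slot layout: wte, lm_head (0 elements when tied), ln1w
        size_t sizes[3] = {8, 0, 4};
        float *wte, *lm_head, *ln1w;
        float **ptrs[3] = {&wte, &lm_head, &ln1w};
        float *memory = (float*)malloc((sizes[0] + sizes[1] + sizes[2]) * sizeof(float));
        float *it = memory;
        for (int i = 0; i < 3; i++) {      // same pointer-bump pattern as the loop above
            *(ptrs[i]) = it;
            it += sizes[i];
        }
        printf("lm_head aliases ln1w: %d\n", lm_head == ln1w); // prints 1: the empty slot points at the next tensor
        if (sizes[1] == 0) lm_head = NULL;                     // the commit's guard, applied to this sketch
        free(memory);
        return 0;
    }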

train_llama3.cu

Lines changed: 30 additions & 7 deletions
@@ -106,6 +106,7 @@ typedef struct {
     float norm_eps; // epsilon used in layernorm, e.g. 1e-5
     float rope_theta; // theta used in ROPE attention, e.g. 500000.0 (<-- new in Llama 3)
     bool use_biases; // we always allocate memory for biases; to match llama3 they are not used
+    bool tied_weights; // untied for large models (3.1 8B/70B/405B), tied for small (3.2 1B/3B)
 } LLama3Config;

 // the parameters of the model
@@ -153,7 +154,12 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, LLama3Co
     size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated
     // now populate the parameter sizes
     param_sizes[0] = Vp * C; // wte
-    param_sizes[1] = Vp * C; // (3) lm_head (final classifier layer weights)
+    if(config.tied_weights) {
+        param_sizes[1] = 0; // no lm_head with tied weights
+    } else {
+        param_sizes[1] = Vp * C; // (3) lm_head (final classifier layer weights)
+    }
+
     param_sizes[2] = L * C; // ln1w
     param_sizes[3] = L * C; // ln1b; (1) all biases are zero it's ok
     param_sizes[4] = L * (qkv_channels) * C; // qkvw
@@ -195,6 +201,10 @@ void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elemen
         *(ptrs[i]) = (floatX*)params_memory_iterator;
         params_memory_iterator += param_elements[i] * param_sizeof[i];
     }
+    // tied weights?
+    if(param_elements[1] == 0) {
+        params->wlmhead = nullptr;
+    }
     return params_memory;
 }

@@ -506,8 +516,9 @@ void llama3_write_to_checkpoint(LLama3 *model, const char* checkpoint_path) {
     model_header[7] = model->config.channels;
     model_header[8] = model->config.multiple_of;
     model_header[9] = model->config.use_scaled_rope;
-    model_header[10] = 3;
-    model_header[11] = 1;
+    model_header[10] = model->config.tied_weights;
+    model_header[11] = 3;
+    model_header[12] = model->config.tied_weights ? 2 : 1;
     fwriteCheck(model_header, sizeof(int), 256, model_file);
     float float_header[256];
     float_header[0] = model->config.ffn_dim_multiplier;
@@ -580,8 +591,9 @@ void llama3_build_from_checkpoint(LLama3 *model, const char* checkpoint_path, bo
     model->config.multiple_of = header_int[8];
     model->config.use_scaled_rope = header_int[9];
     model->config.use_biases = false;
-    int major_version = header_int[10]; // currently unused, e.g. 3
-    int minor_version = header_int[11]; // currently unused, e.g. 1 (so Llama 3.1)
+    model->config.tied_weights = header_int[10];
+    int major_version = header_int[11]; // currently unused, e.g. 3
+    int minor_version = header_int[12]; // 1 or 2
     // now the float section
     model->config.ffn_dim_multiplier = header_float[0];
     model->config.norm_eps = header_float[1];
@@ -740,7 +752,9 @@ void llama3_forward(LLama3 *model, const int* inputs, size_t B, size_t T) {
         }
     }

-    matmul_forward_cublaslt(acts.output, acts.lnf, params.wlmhead, NULL, B, T, C, Vp, main_stream);
+    floatX* lm_head = model->config.tied_weights ? params.wte : params.wlmhead;
+    matmul_forward_cublaslt(acts.output, acts.lnf, lm_head, NULL, B, T, C, Vp, main_stream);
+
     cudaCheck(cudaDeviceSynchronize());
 }

@@ -836,7 +850,10 @@ void llama3_backward_and_reduce(LLama3 *model, int* inputs, const int* targets,
     // technically that is a small, inline backward() pass of calculating
     // total, final loss as the mean over all losses over all (B,T) positions in the batch
     // next: backward the classifier matmul
-    matmul_backward(model->acts.scratch_bt4c, grads.wlmhead, NULL, acts.output, acts.lnf, params.wlmhead, NULL, B, T, C, Vp, main_stream);
+    floatX* w_lm_head = model->config.tied_weights ? params.wte : params.wlmhead;
+    floatX* g_lm_head = model->config.tied_weights ? grads.wte : grads.wlmhead;
+
+    matmul_backward(model->acts.scratch_bt4c, g_lm_head, NULL, acts.output, acts.lnf, w_lm_head, NULL, B, T, C, Vp, main_stream);
     // backward the final layernorm
     floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
     rmsnorm_backward(dresidual, grads.lnfw, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_rstd, B, T, C, main_stream);
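With tying, the classifier gradient is written into grads.wte rather than a separate grads.wlmhead, so the shared (Vp, C) matrix ends up carrying both the lm_head contribution and the token-embedding contribution (assuming, as this routing implies, that the embedding backward accumulates into the same buffer rather than overwriting it). A toy illustration of two contributions landing in one gradient buffer; all names and numbers below are invented for the sketch:

    #include <stdio.h>

    int main(void) {
        float grad_wte[4] = {0};                                      // one shared gradient buffer
        float grad_from_lm_head[4]   = {0.10f, -0.20f, 0.05f,  0.00f};
        float grad_from_embedding[4] = {0.01f,  0.02f, 0.00f, -0.03f};
        for (int i = 0; i < 4; i++) {
            grad_wte[i] += grad_from_lm_head[i];    // routed here because g_lm_head == grads.wte
            grad_wte[i] += grad_from_embedding[i];  // embedding backward (assumed to accumulate) adds here too
        }
        for (int i = 0; i < 4; i++) printf("%+.2f ", grad_wte[i]);
        printf("\n");
        return 0;
    }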
@@ -1076,6 +1093,8 @@ void llama3_update(LLama3 *model, float learning_rate, float beta1, float beta2,
         }

         ShardInfo tensor = llama3_get_tensor_at_layer(model, 0, i);
+        if(tensor.size == 0)
+            continue;
         ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1);
         ptrdiff_t local_offset_full = tensor.offset + shard.offset;
         ptrdiff_t local_offset_partial = tensor.offset / multi_gpu_config->num_processes;
@@ -1144,6 +1163,10 @@ float llama3_estimate_mfu(LLama3 *model, int num_tokens, float dt) {
     second is the attention matmul, which is also usually a small contribution.
     */
     size_t N = model->num_parameters;
+    if(!model->config.tied_weights) {
+        N -= model->param_elements[0]; // remove embedding parameters, which can be significant at 128k vocab
+    }
+
     int L = model->config.num_layers;
     int C = model->config.channels;
     int T = model->seq_len;

train_llama3.py

Lines changed: 10 additions & 6 deletions
@@ -312,6 +312,8 @@ def __init__(self, config):
             ln_f = RMSNorm(config.n_embd, config.norm_eps),
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        if config.tied_embeddings:
+            self.transformer.wte.weight = self.lm_head.weight

         # init all weights, use a torch rng object to be very careful
         self.init_rng = torch.Generator()
@@ -880,7 +882,7 @@ def write_bf16(tensor, file):
     b = t.numpy().tobytes()
     file.write(b)

-def write_tensors(model_tensors, L, file, dtype):
+def write_tensors(model_tensors, L, tied, file, dtype):
     # writes LLaMA 3 model's weights to a binary file
     # things get a bit more complicated though:
     # 1) We want to maintain the ability to finetune just the biases in the C code
@@ -898,7 +900,8 @@ def write_tensors(model_tensors, L, file, dtype):
     assert dtype in {"float32", "bfloat16"}
     write_fun = write_fp32 if dtype == "float32" else write_bf16
     write_fun(model_tensors["transformer.wte.weight"], file) # (V, C)
-    write_fun(model_tensors["lm_head.weight"], file) # (V, C) # <--- hack (3) here!
+    if not tied:
+        write_fun(model_tensors["lm_head.weight"], file) # (V, C) # <--- hack (3) here!
     for i in range(L): # (L, C)
         write_fun(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
     for i in range(L): # (L, C)
@@ -958,8 +961,9 @@ def write_model(model, filename, dtype):
     header_int[7] = model.config.n_embd
     header_int[8] = model.config.multiple_of
     header_int[9] = int(model.config.use_scaled_rope)
-    header_int[10] = int(model.config.version.split('.')[0]) # major version
-    header_int[11] = int(model.config.version.split('.')[1]) # minor version
+    header_int[10] = int(model.config.tied_embeddings)
+    header_int[11] = int(model.config.version.split('.')[0]) # major version
+    header_int[12] = int(model.config.version.split('.')[1]) # minor version
     # float section of the header
     header_float = torch.zeros(256, dtype=torch.float32)
     header_float[0] = model.config.ffn_dim_multiplier
@@ -971,7 +975,7 @@ def write_model(model, filename, dtype):
     with open(filename, "wb") as file:
         file.write(header_int.numpy().tobytes()) # int header
         file.write(header_float.numpy().tobytes()) # float header
-        write_tensors(params, model.config.n_layer, file, dtype) # params
+        write_tensors(params, model.config.n_layer, model.config.tied_embeddings, file, dtype) # params
     print(f"wrote {filename}")

 def write_state(model, x, y, logits, loss, filename):
@@ -996,7 +1000,7 @@ def write_state(model, x, y, logits, loss, filename):
         # loss (single float, result of the cross entropy loss)
         write_fp32(loss.cpu(), file)
         # gradients
-        write_tensors(grads, model.config.n_layer, file, "float32")
+        write_tensors(grads, model.config.n_layer, model.config.tied_embeddings, file, "float32")
     print(f"wrote {filename}")

 # -----------------------------------------------------------------------------
