@@ -106,6 +106,7 @@ typedef struct {
     float norm_eps; // epsilon used in layernorm, e.g. 1e-5
     float rope_theta; // theta used in ROPE attention, e.g. 500000.0 (<-- new in Llama 3)
     bool use_biases; // we always allocate memory for biases; to match llama3 they are not used
+    bool tied_weights; // untied for large models (3.1 8B/70B/405B), tied for small (3.2 1B/3B)
 } LLama3Config;

 // the parameters of the model
@@ -153,7 +154,12 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, LLama3Co
     size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated
     // now populate the parameter sizes
     param_sizes[0] = Vp * C; // wte
-    param_sizes[1] = Vp * C; // (3) lm_head (final classifier layer weights)
+    if (config.tied_weights) {
+        param_sizes[1] = 0; // no lm_head with tied weights
+    } else {
+        param_sizes[1] = Vp * C; // (3) lm_head (final classifier layer weights)
+    }
+
     param_sizes[2] = L * C; // ln1w
     param_sizes[3] = L * C; // ln1b; (1) all biases are zero it's ok
     param_sizes[4] = L * (qkv_channels) * C; // qkvw
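
A quick sense of scale for the branch above: these per-tensor sizes are later summed into the model's total parameter count, so a zero lm_head slot simply drops out of the sum. A standalone sketch with approximate Llama 3.2 1B shapes (illustrative numbers, not read from the repo):

    // rough arithmetic: what tying the lm_head to wte saves at a 128K padded vocab
    #include <stdio.h>
    #include <stddef.h>

    int main(void) {
        size_t Vp = 128256, C = 2048;                      // ~Llama 3.2 1B (illustrative)
        size_t sizes_untied[3] = {Vp * C, Vp * C, 12345};  // wte, lm_head, stand-in for the rest
        size_t sizes_tied[3]   = {Vp * C, 0,      12345};  // lm_head slot zeroed, as in the diff
        size_t untied = 0, tied = 0;
        for (int i = 0; i < 3; i++) { untied += sizes_untied[i]; tied += sizes_tied[i]; }
        printf("tying saves %zu params (~%.2fB)\n", untied - tied, (untied - tied) / 1e9);
        return 0;
    }
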
@@ -195,6 +201,10 @@ void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elemen
         *(ptrs[i]) = (floatX*)params_memory_iterator;
         params_memory_iterator += param_elements[i] * param_sizeof[i];
     }
+    // tied weights?
+    if (param_elements[1] == 0) {
+        params->wlmhead = nullptr;
+    }
     return params_memory;
 }

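
Why the explicit nullptr is needed: with param_elements[1] == 0 the iterator in the loop above does not advance, so ptrs[1] would otherwise be left pointing at the start of the next tensor. A small standalone sketch of that aliasing (tensor names and sizes here are made up, not the repo's):

    // a zero-sized slot in a carve-out loop aliases the next tensor unless nulled out
    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        size_t sizes[3] = {8, 0, 4};               // e.g. wte, lm_head (tied -> 0), ln1w
        float* pool = (float*)malloc((8 + 4) * sizeof(float));
        float* ptrs[3];
        char* it = (char*)pool;
        for (int i = 0; i < 3; i++) {
            ptrs[i] = (float*)it;
            it += sizes[i] * sizeof(float);
        }
        printf("slot 1 aliases slot 2? %s\n", ptrs[1] == ptrs[2] ? "yes" : "no");
        if (sizes[1] == 0) { ptrs[1] = NULL; }     // the same fix as in the diff
        free(pool);
        return 0;
    }
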
@@ -506,8 +516,9 @@ void llama3_write_to_checkpoint(LLama3 *model, const char* checkpoint_path) {
     model_header[7] = model->config.channels;
     model_header[8] = model->config.multiple_of;
     model_header[9] = model->config.use_scaled_rope;
-    model_header[10] = 3;
-    model_header[11] = 1;
+    model_header[10] = model->config.tied_weights;
+    model_header[11] = 3;
+    model_header[12] = model->config.tied_weights ? 2 : 1;
     fwriteCheck(model_header, sizeof(int), 256, model_file);
     float float_header[256];
     float_header[0] = model->config.ffn_dim_multiplier;
@@ -580,8 +591,9 @@ void llama3_build_from_checkpoint(LLama3 *model, const char* checkpoint_path, bo
     model->config.multiple_of = header_int[8];
     model->config.use_scaled_rope = header_int[9];
     model->config.use_biases = false;
-    int major_version = header_int[10]; // currently unused, e.g. 3
-    int minor_version = header_int[11]; // currently unused, e.g. 1 (so Llama 3.1)
+    model->config.tied_weights = header_int[10];
+    int major_version = header_int[11]; // currently unused, e.g. 3
+    int minor_version = header_int[12]; // 1 or 2
     // now the float section
     model->config.ffn_dim_multiplier = header_float[0];
     model->config.norm_eps = header_float[1];
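
To check which variant a checkpoint on disk is, the integer header is just 256 ints at the start of the file, so a throwaway reader can print the fields this change touches. A hedged sketch that assumes only the layout visible in this diff ([10] tied_weights, [11] major, [12] minor for new checkpoints; older files keep the version at [10]/[11]):

    // peek at the integer header of a llama3 checkpoint (new layout assumed)
    #include <stdio.h>

    int main(int argc, char** argv) {
        if (argc < 2) { fprintf(stderr, "usage: %s checkpoint.bin\n", argv[0]); return 1; }
        FILE* f = fopen(argv[1], "rb");
        if (!f) { perror("fopen"); return 1; }
        int header[256];
        if (fread(header, sizeof(int), 256, f) != 256) { fprintf(stderr, "short read\n"); fclose(f); return 1; }
        printf("tied_weights  = %d\n", header[10]);
        printf("major_version = %d\n", header[11]);
        printf("minor_version = %d\n", header[12]);
        fclose(f);
        return 0;
    }
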
@@ -740,7 +752,9 @@ void llama3_forward(LLama3 *model, const int* inputs, size_t B, size_t T) {
         }
     }

-    matmul_forward_cublaslt(acts.output, acts.lnf, params.wlmhead, NULL, B, T, C, Vp, main_stream);
+    floatX* lm_head = model->config.tied_weights ? params.wte : params.wlmhead;
+    matmul_forward_cublaslt(acts.output, acts.lnf, lm_head, NULL, B, T, C, Vp, main_stream);
+
     cudaCheck(cudaDeviceSynchronize());
 }

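
For readers less familiar with weight tying: routing params.wte into this matmul means the same (Vp, C) matrix is used as a lookup table on the way into the network and, effectively transposed, as the classifier on the way out. A toy CPU sketch of that idea (tiny shapes, not the repo's kernels):

    // one matrix, two roles: embedding lookup (row read) and classifier (row dot products)
    #include <stdio.h>

    #define V 4   // toy vocab size
    #define C 3   // toy channels

    int main(void) {
        float wte[V][C] = {{1,0,0},{0,1,0},{0,0,1},{1,1,1}};
        int token = 2;
        float x[C];
        for (int c = 0; c < C; c++) { x[c] = wte[token][c]; }   // embedding: pick row `token`
        // ... the transformer blocks and final rmsnorm would transform x here ...
        float logits[V];
        for (int v = 0; v < V; v++) {                            // tied classifier: x @ wte^T
            logits[v] = 0.0f;
            for (int c = 0; c < C; c++) { logits[v] += x[c] * wte[v][c]; }
        }
        for (int v = 0; v < V; v++) { printf("logit[%d] = %.1f\n", v, logits[v]); }
        return 0;
    }
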
@@ -836,7 +850,10 @@ void llama3_backward_and_reduce(LLama3 *model, int* inputs, const int* targets,
     // technically that is a small, inline backward() pass of calculating
     // total, final loss as the mean over all losses over all (B,T) positions in the batch
     // next: backward the classifier matmul
-    matmul_backward(model->acts.scratch_bt4c, grads.wlmhead, NULL, acts.output, acts.lnf, params.wlmhead, NULL, B, T, C, Vp, main_stream);
+    floatX* w_lm_head = model->config.tied_weights ? params.wte : params.wlmhead;
+    floatX* g_lm_head = model->config.tied_weights ? grads.wte : grads.wlmhead;
+
+    matmul_backward(model->acts.scratch_bt4c, g_lm_head, NULL, acts.output, acts.lnf, w_lm_head, NULL, B, T, C, Vp, main_stream);
     // backward the final layernorm
     floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
     rmsnorm_backward(dresidual, grads.lnfw, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_rstd, B, T, C, main_stream);
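
One thing to keep in mind with g_lm_head aliased to grads.wte: a tied matrix contributes to the loss through two uses (embedding and classifier), so its full gradient is the sum of both contributions; the embedding backward later in the pass then has to accumulate into the buffer this matmul_backward writes (the encoder backward in llm.c accumulates rather than overwrites). A toy scalar sketch of that sum rule, not the repo's kernels:

    // gradient of a twice-used parameter = sum of the per-use contributions
    #include <stdio.h>

    int main(void) {
        float w = 2.0f, x = 3.0f;
        float y1 = w * x;              // "embedding" use
        float y2 = w * y1;             // "classifier" use; loss L = y2 = w*w*x
        float dw_classifier = y1;      // dL/dw through the second use (y1 held fixed)
        float dw_embedding  = w * x;   // dL/dw through the first use (dL/dy1 = w, dy1/dw = x)
        float dw = dw_classifier + dw_embedding;
        printf("analytic 2*w*x = %.1f, accumulated = %.1f\n", 2 * w * x, dw);
        return 0;
    }
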
@@ -1076,6 +1093,8 @@ void llama3_update(LLama3 *model, float learning_rate, float beta1, float beta2,
         }

         ShardInfo tensor = llama3_get_tensor_at_layer(model, 0, i);
+        if (tensor.size == 0)
+            continue;
         ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1);
         ptrdiff_t local_offset_full = tensor.offset + shard.offset;
         ptrdiff_t local_offset_partial = tensor.offset / multi_gpu_config->num_processes;
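
The guard above keeps the optimizer loop away from the empty slot: with tied weights the lm_head tensor has zero elements and a null pointer, so the sharding and AdamW machinery should never see it. A minimal standalone sketch of the pattern (the tensor list and update are stand-ins, not the repo's functions):

    // skip zero-sized entries in a per-tensor update loop
    #include <stdio.h>
    #include <stddef.h>

    static void toy_update(float* p, size_t n) {
        for (size_t i = 0; i < n; i++) { p[i] -= 0.001f; }   // stand-in for AdamW
    }

    int main(void) {
        float wte[4] = {0}, ln1w[2] = {0};
        float* tensors[3] = {wte, NULL, ln1w};   // slot 1 = tied lm_head, nulled out
        size_t sizes[3]   = {4, 0, 2};
        for (int i = 0; i < 3; i++) {
            if (sizes[i] == 0) { continue; }     // the guard added in this commit
            toy_update(tensors[i], sizes[i]);
        }
        printf("updated all non-empty tensors\n");
        return 0;
    }
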
@@ -1144,6 +1163,10 @@ float llama3_estimate_mfu(LLama3 *model, int num_tokens, float dt) {
     second is the attention matmul, which is also usually a small contribution.
     */
     size_t N = model->num_parameters;
+    if (!model->config.tied_weights) {
+        N -= model->param_elements[0]; // remove embedding parameters, which can be significant at 128k vocab
+    }
+
     int L = model->config.num_layers;
     int C = model->config.channels;
     int T = model->seq_len;
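
Back-of-envelope for the comment above: in the usual ~6*N flops-per-token estimate, an untied embedding matrix does no matmul work yet still sits in N, and at a 128K padded vocab that is a noticeable slice, so it is subtracted; when tied it stays in N because the same matrix is the classifier matmul. Rough numbers with approximate Llama 3.1 8B and 3.2 1B shapes (not read from a checkpoint):

    // how large the token embedding is relative to N at a 128K padded vocab
    #include <stdio.h>

    int main(void) {
        double Vp = 128256.0;
        double C_8b = 4096.0, N_8b = 8.0e9;    // untied: wte subtracted from N
        double C_1b = 2048.0, N_1b = 1.24e9;   // tied: wte kept, it is the classifier
        printf("8B: wte ~ %.2fB params, ~%.1f%% of N\n", Vp * C_8b / 1e9, 100.0 * Vp * C_8b / N_8b);
        printf("1B: wte ~ %.2fB params, ~%.1f%% of N\n", Vp * C_1b / 1e9, 100.0 * Vp * C_1b / N_1b);
        return 0;
    }
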