43 changes: 39 additions & 4 deletions .github/workflows/ci_gpu.yml
@@ -113,6 +113,7 @@ jobs:
run: ./test_gpt2cu

build-and-test-llama3:
name: Build and test LLama3.2 1B
runs-on: ubicloud-gpu-standard-1-latest
env:
HF_TOKEN: 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
@@ -150,18 +151,52 @@ jobs:
- name: Build BF16 precision
run: PRECISION=BF16 make train_llama3cu test_llama3cu

- - name: Run default
+ - name: Run default (BF16)
run: ./test_llama3cu

- - name: Run no recompute GeLU
+ - name: Run no recompute GeLU (BF16)
run: ./test_llama3cu -r 0

- - name: Run no master weights
+ - name: Run no master weights (BF16)
run: ./test_llama3cu -w 0

- - name: Run recompute LN
+ - name: Run recompute LN (BF16)
run: ./test_llama3cu -r 2

build-and-test-llama3-untied:
name: Build and test LLama3.2 1B with untie weights
runs-on: ubicloud-gpu-standard-1-latest
env:
HF_TOKEN: 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
steps:
- name: Checkout code
uses: actions/checkout@v4
- run: echo "::add-mask::$HF_TOKEN"

- name: Install OpenMP
run: sudo apt-get update && sudo apt-get install -y libomp-dev

- name: Install dependencies
run: pip install -r requirements.txt

- name: Run preprocessing
run: python dev/data/tinyshakespeare.py --model_desc llama-3

- name: Train model
run: python train_llama3.py --write_tensors 1 --dtype float32 --untie 1 --depth 10

- name: Build FP32 precision
run: PRECISION=FP32 make test_llama3cu

- name: Run default
run: ./test_llama3cu

- name: Build BF16 precision
run: PRECISION=BF16 make train_llama3cu test_llama3cu

- name: Run default
run: ./test_llama3cu

unit-tests-gpu:
runs-on: ubicloud-gpu-standard-1-latest

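The new build-and-test-llama3-untied job repeats the Llama 3.2 1B pipeline but exports the reference tensors with --untie 1, so test_llama3cu is exercised against an untied checkpoint in both FP32 and BF16. A minimal sketch of running those same steps locally (assumes a CUDA machine from the repo root, with the repo's requirements installed and, if the model download needs it, an HF token in the environment):

import subprocess

# Commands lifted from the workflow above.
steps = [
    "python dev/data/tinyshakespeare.py --model_desc llama-3",
    "python train_llama3.py --write_tensors 1 --dtype float32 --untie 1 --depth 10",
    "PRECISION=FP32 make test_llama3cu",
    "./test_llama3cu",
    "PRECISION=BF16 make train_llama3cu test_llama3cu",
    "./test_llama3cu",
]
for cmd in steps:
    print(f"==> {cmd}")
    subprocess.run(cmd, shell=True, check=True)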
2 changes: 1 addition & 1 deletion llmc/rmsnorm.cuh
@@ -90,7 +90,7 @@ __global__ void fused_residual_rmsnorm_forward_kernel5(floatX* residual, floatX*
__syncthreads();

int idx = blockIdx.x * blockDim.y + threadIdx.y;
- if(idx > N) return;
+ if(idx >= N) return;

// adjust pointers to current token
residual += C * idx;
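The rmsnorm change is an off-by-one fix in the row guard: valid token rows are 0..N-1, and `idx > N` let the thread computing idx == N fall through and index one row past the end of the buffers, while `idx >= N` rejects it. A small pure-Python illustration of which rows each guard lets through (no CUDA involved, just the index arithmetic):

import math

def rows_touched(N, rows_per_block, strict):
    # emulate idx = blockIdx.x * blockDim.y + threadIdx.y and the guard
    touched = []
    for block in range(math.ceil(N / rows_per_block)):
        for t in range(rows_per_block):
            idx = block * rows_per_block + t
            out_of_range = (idx >= N) if strict else (idx > N)
            if not out_of_range:
                touched.append(idx)
    return touched

print(rows_touched(10, 4, strict=False))  # old guard: includes row 10, which is out of bounds
print(rows_touched(10, 4, strict=True))   # new guard: rows 0..9 only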
3 changes: 3 additions & 0 deletions test_llama3.cu
@@ -85,6 +85,9 @@ float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size
*(ptrs[i]) = params_memory_iterator;
params_memory_iterator += param_sizes[i];
}
if(param_sizes[1] == 0) {
params->wlmhead = nullptr;
}
return params_memory;
}

37 changes: 30 additions & 7 deletions train_llama3.cu
@@ -106,6 +106,7 @@ typedef struct {
float norm_eps; // epsilon used in layernorm, e.g. 1e-5
float rope_theta; // theta used in ROPE attention, e.g. 500000.0 (<-- new in Llama 3)
bool use_biases; // we always allocate memory for biases; to match llama3 they are not used
bool tied_weights; // untied for large models (3.1 8B/70B/405B), tied for small (3.2 1B/3B)
} LLama3Config;

// the parameters of the model
@@ -153,7 +154,12 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, LLama3Co
size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated
// now populate the parameter sizes
param_sizes[0] = Vp * C; // wte
- param_sizes[1] = Vp * C; // (3) lm_head (final classifier layer weights)
+ if(config.tied_weights) {
+     param_sizes[1] = 0; // no lm_head with tied weights
+ } else {
+     param_sizes[1] = Vp * C; // (3) lm_head (final classifier layer weights)
+ }

param_sizes[2] = L * C; // ln1w
param_sizes[3] = L * C; // ln1b; (1) all biases are zero it's ok
param_sizes[4] = L * (qkv_channels) * C; // qkvw
@@ -195,6 +201,10 @@ void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elemen
*(ptrs[i]) = (floatX*)params_memory_iterator;
params_memory_iterator += param_elements[i] * param_sizeof[i];
}
// tied weights?
if(param_elements[1] == 0) {
params->wlmhead = nullptr;
}
return params_memory;
}
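With tied weights the lm_head slot is sized to zero, nothing is allocated for it, and its pointer is set to nullptr (the forward and backward passes below substitute wte instead). A toy Python mirror of that bookkeeping, with hypothetical dimensions just to show the saving:

def param_slot_sizes(Vp, C, tied_weights):
    # slot 0 is always the token embedding wte; slot 1 is the classifier lm_head
    wte = Vp * C
    lm_head = 0 if tied_weights else Vp * C
    return wte, lm_head

Vp, C = 128256, 2048  # made-up, Llama-3-like padded vocab and channel count
print(param_slot_sizes(Vp, C, tied_weights=False))  # two full (Vp, C) matrices
print(param_slot_sizes(Vp, C, tied_weights=True))   # only wte; the lm_head slot is empty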

@@ -506,8 +516,9 @@ void llama3_write_to_checkpoint(LLama3 *model, const char* checkpoint_path) {
model_header[7] = model->config.channels;
model_header[8] = model->config.multiple_of;
model_header[9] = model->config.use_scaled_rope;
- model_header[10] = 3;
- model_header[11] = 1;
+ model_header[10] = model->config.tied_weights;
+ model_header[11] = 3;
+ model_header[12] = model->config.tied_weights ? 2 : 1;
fwriteCheck(model_header, sizeof(int), 256, model_file);
float float_header[256];
float_header[0] = model->config.ffn_dim_multiplier;
@@ -580,8 +591,9 @@ void llama3_build_from_checkpoint(LLama3 *model, const char* checkpoint_path, bo
model->config.multiple_of = header_int[8];
model->config.use_scaled_rope = header_int[9];
model->config.use_biases = false;
- int major_version = header_int[10]; // currently unused, e.g. 3
- int minor_version = header_int[11]; // currently unused, e.g. 1 (so Llama 3.1)
+ model->config.tied_weights = header_int[10];
+ int major_version = header_int[11]; // currently unused, e.g. 3
+ int minor_version = header_int[12]; // 1 or 2
// now the float section
model->config.ffn_dim_multiplier = header_float[0];
model->config.norm_eps = header_float[1];
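Header slot 10 now records tied_weights, and the major/minor version move to slots 11 and 12 (the C writer derives the minor version from tied_weights: 2 when tied, 1 when not). A hedged reader sketch for the layout written by llama3_write_to_checkpoint and write_model below, namely 256 int32 values followed by 256 float32 values; slots not shown in this diff are left out here:

import struct

def read_llama3_header(path):
    with open(path, "rb") as f:
        header_int = struct.unpack("256i", f.read(256 * 4))
        header_float = struct.unpack("256f", f.read(256 * 4))
    return {
        "tied_weights": bool(header_int[10]),
        "version": (header_int[11], header_int[12]),  # e.g. (3, 1) or (3, 2)
        "ffn_dim_multiplier": header_float[0],
        "norm_eps": header_float[1],
    }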
@@ -740,7 +752,9 @@ void llama3_forward(LLama3 *model, const int* inputs, size_t B, size_t T) {
}
}

- matmul_forward_cublaslt(acts.output, acts.lnf, params.wlmhead, NULL, B, T, C, Vp, main_stream);
+ floatX* lm_head = model->config.tied_weights ? params.wte : params.wlmhead;
+ matmul_forward_cublaslt(acts.output, acts.lnf, lm_head, NULL, B, T, C, Vp, main_stream);

cudaCheck(cudaDeviceSynchronize());
}
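In the forward pass the only change is which (Vp, C) matrix feeds the final matmul: wte when the weights are tied, wlmhead otherwise. The same selection in PyTorch terms (a sketch, not this repo's API):

import torch

def lm_logits(lnf, wte, wlmhead, tied_weights):
    # lnf: (B, T, C) final-RMSNorm activations; wte/wlmhead: (Vp, C); logits: (B, T, Vp)
    w = wte if tied_weights else wlmhead
    return lnf @ w.t()

B, T, C, Vp = 2, 4, 8, 16
lnf, wte = torch.randn(B, T, C), torch.randn(Vp, C)
print(lm_logits(lnf, wte, wlmhead=None, tied_weights=True).shape)  # torch.Size([2, 4, 16])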

@@ -836,7 +850,10 @@ void llama3_backward_and_reduce(LLama3 *model, int* inputs, const int* targets,
// technically that is a small, inline backward() pass of calculating
// total, final loss as the mean over all losses over all (B,T) positions in the batch
// next: backward the classifier matmul
- matmul_backward(model->acts.scratch_bt4c, grads.wlmhead, NULL, acts.output, acts.lnf, params.wlmhead, NULL, B, T, C, Vp, main_stream);
+ floatX* w_lm_head = model->config.tied_weights ? params.wte : params.wlmhead;
+ floatX* g_lm_head = model->config.tied_weights ? grads.wte : grads.wlmhead;
+
+ matmul_backward(model->acts.scratch_bt4c, g_lm_head, NULL, acts.output, acts.lnf, w_lm_head, NULL, B, T, C, Vp, main_stream);
// backward the final layernorm
floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
rmsnorm_backward(dresidual, grads.lnfw, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_rstd, B, T, C, main_stream);
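The backward pass mirrors the forward selection: when tied, the classifier matmul writes its weight gradient into grads.wte, so the shared matrix ends up with gradient from both of its roles (classifier and token embedding). PyTorch autograd does the same summation automatically once the Parameter is shared; a toy check:

import torch
import torch.nn as nn

V, C, B, T = 11, 8, 2, 5
wte = nn.Embedding(V, C)
lm_head = nn.Linear(C, V, bias=False)
wte.weight = lm_head.weight                    # tie: one shared Parameter, as in the Python model

tokens = torch.randint(0, V, (B, T))
logits = lm_head(wte(tokens))                  # both ops use the same tensor
logits.logsumexp(dim=-1).mean().backward()
print(wte.weight.grad is lm_head.weight.grad)  # True: gradients from both uses are summed into it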
@@ -1076,6 +1093,8 @@ void llama3_update(LLama3 *model, float learning_rate, float beta1, float beta2,
}

ShardInfo tensor = llama3_get_tensor_at_layer(model, 0, i);
if(tensor.size == 0)
continue;
ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1);
ptrdiff_t local_offset_full = tensor.offset + shard.offset;
ptrdiff_t local_offset_partial = tensor.offset / multi_gpu_config->num_processes;
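llama3_update iterates over every parameter tensor, and under tied weights the lm_head slot now has size zero, so it is skipped before any shard offsets are computed. The guard is just the usual empty-tensor early-out, e.g.:

def update_all(tensors, apply_adamw):
    # tensors: iterable of (name, numel) pairs; zero-sized slots
    # (the lm_head slot under tied weights) have nothing to update
    for name, numel in tensors:
        if numel == 0:
            continue
        apply_adamw(name)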
@@ -1144,6 +1163,10 @@ float llama3_estimate_mfu(LLama3 *model, int num_tokens, float dt) {
second is the attention matmul, which is also usually a small contribution.
*/
size_t N = model->num_parameters;
if(!model->config.tied_weights) {
N -= model->param_elements[0]; // remove embedding parameters, which can be significant at 128k vocab
}

int L = model->config.num_layers;
int C = model->config.channels;
int T = model->seq_len;
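llama3_estimate_mfu works off the usual ~6 matmul flops per parameter per token. Untied, the wte matrix is only used by the embedding lookup (a gather, not a matmul), so its elements are subtracted from N; tied, the same matrix is the classifier weight and stays counted. A sketch of just that adjustment (the attention-matmul term in the real function is omitted):

def matmul_flops_per_token(num_parameters, wte_elements, tied_weights):
    # 6 ~= 2 flops/param (forward) + 4 flops/param (backward) for the dense matmuls
    n = num_parameters
    if not tied_weights:
        n -= wte_elements  # embedding lookup contributes no matmul flops
    return 6 * n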
27 changes: 19 additions & 8 deletions train_llama3.py
@@ -312,6 +312,8 @@ def __init__(self, config):
ln_f = RMSNorm(config.n_embd, config.norm_eps),
))
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
if config.tied_embeddings:
self.transformer.wte.weight = self.lm_head.weight

# init all weights, use a torch rng object to be very careful
self.init_rng = torch.Generator()
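The Python model ties by assigning the lm_head Parameter to the embedding, so both modules hold the same tensor. state_dict() then still reports both keys, just backed by one storage, which is why write_tensors below writes the matrix only once for tied checkpoints. A quick sketch of that check:

import torch.nn as nn

wte = nn.Embedding(11, 8)
lm_head = nn.Linear(8, 11, bias=False)
wte.weight = lm_head.weight   # same tie direction as the diff above

print(wte.weight.data_ptr() == lm_head.weight.data_ptr())  # True: one shared (V, C) matrix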
@@ -433,10 +435,16 @@ def unpermute(w, n_heads, dim1, dim2):
return checkpoint

@classmethod
- def from_pretrained_llama3_hf(cls, model_id):
+ def from_pretrained_llama3_hf(cls, model_id, untie):
"""Loads pretrained LLaMA model weights from HuggingFace"""
from transformers import AutoModelForCausalLM, AutoTokenizer
model_args = MODEL_DICT[model_id]
if untie:
if not model_args.tied_embeddings:
print("Model embeddings are not tied, --untie has no effect.")
else:
print("Untying token embeddings and LM head.")
model_args.tied_embeddings = False

model = AutoModelForCausalLM.from_pretrained(model_id)
checkpoint = LLaMA.adapt_llama_state_dict_keys_hf(model.state_dict(), model_args)
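When --untie is set for a checkpoint whose embeddings are tied, the LM head needs its own copy of the (V, C) matrix so the two can diverge during training; how adapt_llama_state_dict_keys_hf fills that slot is outside this hunk, but the general pattern is just a detached clone of the shared weight (illustrative only):

import torch.nn as nn

def untie_lm_head(wte: nn.Embedding, lm_head: nn.Linear) -> None:
    # give the classifier an independent copy of the embedding matrix
    lm_head.weight = nn.Parameter(wte.weight.detach().clone())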
@@ -876,7 +884,7 @@ def write_bf16(tensor, file):
b = t.numpy().tobytes()
file.write(b)

- def write_tensors(model_tensors, L, file, dtype):
+ def write_tensors(model_tensors, L, tied, file, dtype):
# writes LLaMA 3 model's weights to a binary file
# things get a bit more complicated though:
# 1) We want to maintain the ability to finetune just the biases in the C code
@@ -894,7 +902,8 @@ def write_tensors(model_tensors, L, file, dtype):
assert dtype in {"float32", "bfloat16"}
write_fun = write_fp32 if dtype == "float32" else write_bf16
write_fun(model_tensors["transformer.wte.weight"], file) # (V, C)
- write_fun(model_tensors["lm_head.weight"], file) # (V, C) # <--- hack (3) here!
+ if not tied:
+     write_fun(model_tensors["lm_head.weight"], file) # (V, C) # <--- hack (3) here!
for i in range(L): # (L, C)
write_fun(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
for i in range(L): # (L, C)
@@ -954,8 +963,9 @@ def write_model(model, filename, dtype):
header_int[7] = model.config.n_embd
header_int[8] = model.config.multiple_of
header_int[9] = int(model.config.use_scaled_rope)
- header_int[10] = int(model.config.version.split('.')[0]) # major version
- header_int[11] = int(model.config.version.split('.')[1]) # minor version
+ header_int[10] = int(model.config.tied_embeddings)
+ header_int[11] = int(model.config.version.split('.')[0]) # major version
+ header_int[12] = int(model.config.version.split('.')[1]) # minor version
# float section of the header
header_float = torch.zeros(256, dtype=torch.float32)
header_float[0] = model.config.ffn_dim_multiplier
@@ -967,7 +977,7 @@ def write_state(model, x, y, logits, loss, filename):
with open(filename, "wb") as file:
file.write(header_int.numpy().tobytes()) # int header
file.write(header_float.numpy().tobytes()) # float header
- write_tensors(params, model.config.n_layer, file, dtype) # params
+ write_tensors(params, model.config.n_layer, model.config.tied_embeddings, file, dtype) # params
print(f"wrote {filename}")

def write_state(model, x, y, logits, loss, filename):
@@ -993,7 +1003,7 @@ def write_state(model, x, y, logits, loss, filename):
# loss (single float, result of the cross entropy loss)
write_fp32(loss.cpu(), file)
# gradients
- write_tensors(grads, model.config.n_layer, file, "float32")
+ write_tensors(grads, model.config.n_layer, model.config.tied_embeddings, file, "float32")
print(f"wrote {filename}")


@@ -1036,6 +1046,7 @@ def print0(*args, **kwargs):
parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints")
parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-1B", help="chose the llama model")
parser.add_argument("--depth", type=int, default=-1, help="load only a subset of the model's layers")
parser.add_argument("--untie", type=int, default=False, help="Untie token embeddings and LM-head, even if they are tied in the checkpoint.")
# token layout for each step of the optimization
parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions")
parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
@@ -1140,7 +1151,7 @@ def print0(*args, **kwargs):

# init the model
if args.use_hf:
- model = LLaMA.from_pretrained_llama3_hf(args.model)
+ model = LLaMA.from_pretrained_llama3_hf(args.model, args.untie)
else: # use Meta's checkpoint
assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist"
assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"