72 changes: 33 additions & 39 deletions ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb
@@ -83,7 +83,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface_hub version: 0.33.0\n",
"huggingface_hub version: 0.33.2\n",
"sentencepiece version: 0.2.0\n",
"torch version: 2.6.0\n"
]
@@ -1306,22 +1306,7 @@
"id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
"outputId": "0d8942cc-e5e2-4e77-ec41-1ac7bec7d94f"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "66e777955e8748df878f118f07f38dab",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"consolidated.00.pth: 0%| | 0.00/13.5G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"weights_file = hf_hub_download(\n",
" repo_id=\"meta-llama/Llama-2-7b\",\n",
@@ -1405,7 +1390,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 32,
"id": "3820e2a7-4f26-41bc-953b-f3879b0aff65",
"metadata": {
"id": "3820e2a7-4f26-41bc-953b-f3879b0aff65"
@@ -1422,19 +1407,40 @@
" return torch.nn.Parameter(torch.tensor(right))\n",
"\n",
"\n",
"def permute(w: torch.Tensor, n_heads, out_dim, in_dim):\n",
" return (w.view(n_heads, out_dim // n_heads // 2, 2, in_dim)\n",
"            .transpose(1, 2)  # move the size-2 pair axis next to the head axis\n",
" .reshape(out_dim, in_dim))\n",
"\n",
"\n",
"def load_weights_into_llama(model, param_config, params):\n",
"\n",
" cfg = LLAMA2_CONFIG_7B\n",
" \n",
" model.tok_emb.weight = assign(model.tok_emb.weight, params[\"tok_embeddings.weight\"])\n",
"\n",
" for l in range(param_config[\"n_layers\"]):\n",
"\n",
" # Load attention weights\n",
" # The original Meta/Llama checkpoints store Q and K so that the two numbers\n",
" # that form one complex RoPE pair sit directly next to each other inside the head dimension (interleaved layout).\n",
" # Our RoPE implementation, like the one in Hugging Face transformers, instead pairs element i with element i + head_dim//2\n",
" # (sliced / split-half layout), so the Q and K rows of the checkpoint have to be re-ordered.\n",
" # For example, with n_heads=2 and head_dim=8 (pair i consists of x_i and y_i):\n",
" # Meta (interleaved): [ h0: x0 y0 x1 y1 x2 y2 x3 y3, h1: ... ]\n",
" # Ours & HF (sliced): [ h0: x0 x1 x2 x3 y0 y1 y2 y3, h1: ... ]\n",
" # For more information, please see the discussion in the PR: https://github.com/rasbt/LLMs-from-scratch/pull/747\n",
" \n",
" # So, below, for q_raw and k_raw, we re-order the checkpoint weights using the permute helper defined above\n",
"\n",
" q_raw = params[f\"layers.{l}.attention.wq.weight\"]\n",
" model.trf_blocks[l].att.W_query.weight = assign(\n",
" model.trf_blocks[l].att.W_query.weight,\n",
" params[f\"layers.{l}.attention.wq.weight\"]\n",
" permute(q_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n",
" )\n",
" k_raw = params[f\"layers.{l}.attention.wk.weight\"]\n",
" model.trf_blocks[l].att.W_key.weight = assign(\n",
" model.trf_blocks[l].att.W_key.weight,\n",
" params[f\"layers.{l}.attention.wk.weight\"]\n",
" permute(k_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n",
" )\n",
" model.trf_blocks[l].att.W_value.weight = assign(\n",
" model.trf_blocks[l].att.W_value.weight,\n",
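Reviewer note (not part of the diff): below is a minimal, self-contained sketch of what the permute helper above does to the Q/K weight rows. The shapes (n_heads=2, head_dim=4, in_dim=3) are toy values chosen purely for illustration; the helper body is copied from the hunk above.

import torch

def permute(w, n_heads, out_dim, in_dim):
    # same reshape / transpose / reshape as the helper added in this PR
    return (w.view(n_heads, out_dim // n_heads // 2, 2, in_dim)
            .transpose(1, 2)
            .reshape(out_dim, in_dim))

n_heads, head_dim, in_dim = 2, 4, 3
out_dim = n_heads * head_dim  # 8 weight rows = 2 RoPE pairs per head

# Row i stands for output dimension i of the checkpoint weight matrix.
# In the Meta (interleaved) layout, rows (0, 1) and (2, 3) are the two RoPE pairs of head 0.
w = torch.arange(out_dim * in_dim, dtype=torch.float32).view(out_dim, in_dim)

print(w[:, 0].reshape(n_heads, head_dim))
# tensor([[ 0.,  3.,  6.,  9.],    head 0: pair0a pair0b pair1a pair1b (pairs adjacent)
#         [12., 15., 18., 21.]])
print(permute(w, n_heads, out_dim, in_dim)[:, 0].reshape(n_heads, head_dim))
# tensor([[ 0.,  6.,  3.,  9.],    head 0: pair0a pair1a | pair0b pair1b (split halves)
#         [12., 18., 15., 21.]])

After the permute, the first members of all pairs occupy the first half of each head and the second members the second half, i.e. the split-half layout that a rotate_half-style RoPE (as in Hugging Face) expects.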
@@ -1489,7 +1495,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 33,
"id": "240987e8-a023-462e-9376-9edfb27559ec",
"metadata": {
"colab": {
@@ -1504,7 +1510,7 @@
"output_type": "stream",
"text": [
"Output text:\n",
" Every effort has been made to ensure that the information contained in this website is accurate and up to date and correct at the time of publication\n"
" Every effort has been made to ensure the accuracy of the information contained in this website. However, the information contained in this website is not\n"
]
}
],
@@ -1544,7 +1550,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 35,
"id": "nbvAV7vaz6yc",
"metadata": {
"colab": {
@@ -1568,27 +1574,14 @@
"outputId": "724f5508-d976-4e31-b3d7-95fa65b2c1e8"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3b2448a60f5f4ba5b2c686037c8ecd78",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"consolidated.00.pth: 0%| | 0.00/13.5G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output text:\n",
" What do llamas eat?\n",
"Llamas and alpacas are herbivores, which means they eat grasses, leaves, grass\n"
"\n",
"Llamas are herbivores, which means they eat plants for their food. They feed on a variety\n"
]
}
],
@@ -1601,6 +1594,7 @@
" local_dir=\"Llama-2-7b-chat\"\n",
")\n",
"\n",
"weights = torch.load(weights_file, weights_only=True)\n",
"model = Llama2Model(LLAMA2_CONFIG_7B)\n",
"load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n",
"model.to(device);\n",