Skip to content

Commit 296b9d3

Browse files
committed
Improve weight tying handling
1 parent 1412b13 commit 296b9d3

11 files changed

+548
-177
lines changed

ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1400,14 +1400,17 @@
14001400
},
14011401
"outputs": [],
14021402
"source": [
1403-
"def assign(left, right):\n",
1403+
"def assign(left, right, tensor_name=\"unknown\"):\n",
14041404
" if left.shape != right.shape:\n",
1405-
" raise ValueError(f\"Shape mismatch. Left: {left.shape}, Right: {right.shape}\")\n",
1405+
" raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
1406+
" \n",
1407+
" with torch.no_grad():\n",
1408+
" if isinstance(right, torch.Tensor):\n",
1409+
" left.copy_(right)\n",
1410+
" else:\n",
1411+
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
14061412
"\n",
1407-
" if isinstance(right, torch.Tensor):\n",
1408-
" return torch.nn.Parameter(right.clone().detach())\n",
1409-
" else:\n",
1410-
" return torch.nn.Parameter(torch.tensor(right))\n",
1413+
" return left \n",
14111414
"\n",
14121415
"\n",
14131416
"def permute(w: torch.Tensor, n_heads, out_dim, in_dim):\n",

0 commit comments

Comments (0)