Skip to content

Commit dc763fb

Browse files
committed
Improve MHA einsum
1 parent 80d4732 commit dc763fb

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
"name": "stdout",
5959
"output_type": "stream",
6060
"text": [
61-
"PyTorch version: 2.6.0+cu124\n"
61+
"PyTorch version: 2.8.0\n"
6262
]
6363
}
6464
],
@@ -89,7 +89,7 @@
8989
},
9090
{
9191
"cell_type": "code",
92-
"execution_count": null,
92+
"execution_count": 2,
9393
"id": "1db27f43-86f4-478f-89df-fbc2182a129b",
9494
"metadata": {
9595
"id": "1db27f43-86f4-478f-89df-fbc2182a129b"
@@ -114,7 +114,7 @@
114114
},
115115
{
116116
"cell_type": "code",
117-
"execution_count": 2,
117+
"execution_count": 3,
118118
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
119119
"metadata": {
120120
"colab": {
@@ -205,7 +205,7 @@
205205
},
206206
{
207207
"cell_type": "code",
208-
"execution_count": 3,
208+
"execution_count": 4,
209209
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
210210
"metadata": {
211211
"colab": {
@@ -326,7 +326,7 @@
326326
},
327327
{
328328
"cell_type": "code",
329-
"execution_count": 4,
329+
"execution_count": 5,
330330
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
331331
"metadata": {
332332
"colab": {
@@ -434,7 +434,7 @@
434434
},
435435
{
436436
"cell_type": "code",
437-
"execution_count": 5,
437+
"execution_count": 6,
438438
"id": "92481814-068d-439b-a65c-b1310ebbe0aa",
439439
"metadata": {
440440
"colab": {
@@ -466,7 +466,6 @@
466466
" self.num_heads = num_heads\n",
467467
" self.head_dim = d_out // num_heads\n",
468468
"\n",
469-
" # Initialize parameters for Q, K, V\n",
470469
" self.W_query = nn.Parameter(torch.randn(d_out, d_in))\n",
471470
" self.W_key = nn.Parameter(torch.randn(d_out, d_in))\n",
472471
" self.W_value = nn.Parameter(torch.randn(d_out, d_in))\n",
@@ -483,8 +482,6 @@
483482
" self.out_proj = nn.Linear(d_out, d_out)\n",
484483
" self.dropout = nn.Dropout(dropout)\n",
485484
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
486-
"\n",
487-
" # Initialize parameters\n",
488485
" self.reset_parameters()\n",
489486
"\n",
490487
"\n",

0 commit comments

Comments
 (0)