|
58 | 58 | "name": "stdout",
|
59 | 59 | "output_type": "stream",
|
60 | 60 | "text": [
|
61 |
| - "PyTorch version: 2.6.0+cu124\n" |
| 61 | + "PyTorch version: 2.8.0\n" |
62 | 62 | ]
|
63 | 63 | }
|
64 | 64 | ],
|
|
89 | 89 | },
|
90 | 90 | {
|
91 | 91 | "cell_type": "code",
|
92 |
| - "execution_count": null, |
| 92 | + "execution_count": 2, |
93 | 93 | "id": "1db27f43-86f4-478f-89df-fbc2182a129b",
|
94 | 94 | "metadata": {
|
95 | 95 | "id": "1db27f43-86f4-478f-89df-fbc2182a129b"
|
|
114 | 114 | },
|
115 | 115 | {
|
116 | 116 | "cell_type": "code",
|
117 |
| - "execution_count": 2, |
| 117 | + "execution_count": 3, |
118 | 118 | "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
|
119 | 119 | "metadata": {
|
120 | 120 | "colab": {
|
|
205 | 205 | },
|
206 | 206 | {
|
207 | 207 | "cell_type": "code",
|
208 |
| - "execution_count": 3, |
| 208 | + "execution_count": 4, |
209 | 209 | "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
|
210 | 210 | "metadata": {
|
211 | 211 | "colab": {
|
|
326 | 326 | },
|
327 | 327 | {
|
328 | 328 | "cell_type": "code",
|
329 |
| - "execution_count": 4, |
| 329 | + "execution_count": 5, |
330 | 330 | "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
|
331 | 331 | "metadata": {
|
332 | 332 | "colab": {
|
|
434 | 434 | },
|
435 | 435 | {
|
436 | 436 | "cell_type": "code",
|
437 |
| - "execution_count": 5, |
| 437 | + "execution_count": 6, |
438 | 438 | "id": "92481814-068d-439b-a65c-b1310ebbe0aa",
|
439 | 439 | "metadata": {
|
440 | 440 | "colab": {
|
|
466 | 466 | " self.num_heads = num_heads\n",
|
467 | 467 | " self.head_dim = d_out // num_heads\n",
|
468 | 468 | "\n",
|
469 |
| - " # Initialize parameters for Q, K, V\n", |
470 | 469 | " self.W_query = nn.Parameter(torch.randn(d_out, d_in))\n",
|
471 | 470 | " self.W_key = nn.Parameter(torch.randn(d_out, d_in))\n",
|
472 | 471 | " self.W_value = nn.Parameter(torch.randn(d_out, d_in))\n",
|
|
483 | 482 | " self.out_proj = nn.Linear(d_out, d_out)\n",
|
484 | 483 | " self.dropout = nn.Dropout(dropout)\n",
|
485 | 484 | " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
|
486 |
| - "\n", |
487 |
| - " # Initialize parameters\n", |
488 | 485 | " self.reset_parameters()\n",
|
489 | 486 | "\n",
|
490 | 487 | "\n",
|
|
0 commit comments