Skip to content

Commit fc101b7

Browse files
rasbtmissflash
andauthored
Added Apple Silicon GPU device update (#820)
* Added Apple Silicon GPU device * Added Apple Silicon GPU device * delete: remove unused model.pth file from understanding-buffers * update * update --------- Co-authored-by: missflash <[email protected]>
1 parent 8e17031 commit fc101b7

File tree

2 files changed

+834
-813
lines changed

2 files changed

+834
-813
lines changed

ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
"name": "stdout",
5959
"output_type": "stream",
6060
"text": [
61+
"Using device: cuda\n",
6162
"PyTorch version: 2.8.0\n"
6263
]
6364
}
@@ -66,7 +67,14 @@
6667
"import torch\n",
6768
"\n",
6869
"torch.manual_seed(123)\n",
69-
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
70+
"if torch.backends.mps.is_available():\n",
71+
" device = torch.device(\"mps\") # Apple Silicon GPU (Metal)\n",
72+
"elif torch.cuda.is_available():\n",
73+
" device = torch.device(\"cuda\") # NVIDIA GPU\n",
74+
"else:\n",
75+
" device = torch.device(\"cpu\") # CPU fallback\n",
76+
"\n",
77+
"print(f\"Using device: {device}\")\n",
7078
"print(f\"PyTorch version: {torch.__version__}\")\n",
7179
"\n",
7280
"batch_size = 8\n",

0 commit comments

Comments
 (0)