
Commit f54abc5

Merge pull request #13123 from Lightning-AI/mps_accelerator
MPS Accelerator
1 parent 5572797 commit f54abc5


46 files changed: +812 additions, -163 deletions

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
@@ -78,9 +78,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `XLAEnvironment` cluster environment plugin ([#11330](https://github.com/PyTorchLightning/pytorch-lightning/pull/11330))
 
 
+- Added support for calling unknown methods with `DummyLogger` ([#13224](https://github.com/PyTorchLightning/pytorch-lightning/pull/13224))
 
 
-- Added support for calling unknown methods with `DummyLogger` ([#13224](https://github.com/PyTorchLightning/pytorch-lightning/pull/13224))
+- Added Apple Silicon Support via `MPSAccelerator` ([#13123](https://github.com/PyTorchLightning/pytorch-lightning/pull/13123))
+
 
 ### Changed
 

docs/source-pytorch/accelerators/mps.rst

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
.. _mps:

Accelerator: Apple Silicon training
===================================

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. Add callout items below this line

.. displayitem::
    :header: Prepare your code (Optional)
    :description: Prepare your code to run on any hardware
    :col_css: col-md-4
    :button_link: accelerator_prepare.html
    :height: 150
    :tag: basic

.. displayitem::
    :header: Basic
    :description: Learn the basics of Apple silicon gpu training.
    :col_css: col-md-4
    :button_link: mps_basic.html
    :height: 150
    :tag: basic

.. raw:: html

        </div>
    </div>

docs/source-pytorch/accelerators/mps_basic.rst

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
:orphan:

.. _mps_basic:

MPS training (basic)
====================
**Audience:** Users looking to train on their Apple silicon GPUs.

.. warning::

   Both the MPS accelerator and the PyTorch backend are still experimental.
   As such, not all operations are currently supported. However, with ongoing development from the PyTorch team, an increasingly large number of operations are becoming available.
   You can use ``PYTORCH_ENABLE_MPS_FALLBACK=1 python your_script.py`` to fall back to the CPU for unsupported operations.


----

What is Apple silicon?
----------------------
Apple silicon chips are a unified system on a chip (SoC) developed by Apple based on the ARM design.
Among other things, they feature CPU cores, GPU cores, a neural engine, and memory shared between all of these components.

----

So it's a CPU?
--------------
Apple silicon includes CPU cores among several other features. However, the full hardware-acceleration potential of the M-series SoCs cannot be reached when running on the ``CPUAccelerator`` alone, because they also feature a GPU and a neural engine.

To use them, Lightning supports the ``MPSAccelerator``.

----

Run on Apple silicon gpus
-------------------------
Set the following Trainer arguments to run on Apple silicon gpus (MPS devices).

.. code::

   trainer = Trainer(accelerator="mps", devices=1)

.. note::
   The ``MPSAccelerator`` only supports 1 device at a time. Currently there are no machines with multiple MPS-capable GPUs.

----

What does MPS stand for?
------------------------
MPS is short for `Metal Performance Shaders <https://developer.apple.com/metal/>`_, the technology used behind the scenes for GPU communication and computation.
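
To put the ``Trainer(accelerator="mps", devices=1)`` snippet above into context, here is a minimal end-to-end sketch. It is not part of this commit: the ``TinyModel`` module and the random dataset are made up for illustration, and it assumes PyTorch >= 1.12 running on an Apple silicon machine.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning import LightningModule, Trainer


    class TinyModel(LightningModule):
        """Hypothetical module used only to demonstrate MPS training."""

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    if __name__ == "__main__":
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        # "mps" selects the new MPSAccelerator; it only supports a single device.
        trainer = Trainer(accelerator="mps", devices=1, max_epochs=1)
        trainer.fit(TinyModel(), DataLoader(dataset, batch_size=16))

As the warning notes, operators that MPS does not yet implement can be routed back to the CPU by exporting ``PYTORCH_ENABLE_MPS_FALLBACK=1`` before launching the script.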

docs/source-pytorch/extensions/accelerator.rst

Lines changed: 3 additions & 1 deletion
@@ -4,14 +4,15 @@
 Accelerator
 ###########
 
-The Accelerator connects a Lightning Trainer to arbitrary hardware (CPUs, GPUs, TPUs, IPUs, ...).
+The Accelerator connects a Lightning Trainer to arbitrary hardware (CPUs, GPUs, TPUs, IPUs, MPS, ...).
 Currently there are accelerators for:
 
 - CPU
 - :doc:`GPU <../accelerators/gpu>`
 - :doc:`TPU <../accelerators/tpu>`
 - :doc:`IPU <../accelerators/ipu>`
 - :doc:`HPU <../accelerators/hpu>`
+- :doc:`MPS <../accelerators/mps>`
 
 The Accelerator is part of the Strategy which manages communication across multiple devices (distributed communication).
 Whenever the Trainer, the loops or any other component in Lightning needs to talk to hardware, it calls into the Strategy and the Strategy calls into the Accelerator.

@@ -127,4 +128,5 @@ Accelerator API
 GPUAccelerator
 HPUAccelerator
 IPUAccelerator
+MPSAccelerator
 TPUAccelerator
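
As the updated list suggests, the new accelerator can be selected either by its registered short name or by passing an instance to the Trainer. A small sketch, not from this diff, but assuming the string/instance handling shown in ``accelerator_connector.py`` further below:

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import MPSAccelerator

    # Select by the name registered via `register_accelerators` ...
    trainer = Trainer(accelerator="mps", devices=1)

    # ... or pass an Accelerator instance explicitly.
    trainer = Trainer(accelerator=MPSAccelerator(), devices=1)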

docs/source-pytorch/index.rst

Lines changed: 1 addition & 0 deletions
@@ -210,6 +210,7 @@ Current Lightning Users
    Train on single or multiple HPUs <accelerators/hpu>
    Train on single or multiple IPUs <accelerators/ipu>
    Train on single or multiple TPUs <accelerators/tpu>
+   Train on MPS <accelerators/mps>
    Use a pretrained model <advanced/pretrained>
    model/own_your_loop
 

src/pytorch_lightning/accelerators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 from pytorch_lightning.accelerators.gpu import GPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.hpu import HPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.ipu import IPUAccelerator  # noqa: F401
+from pytorch_lightning.accelerators.mps import MPSAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.registry import AcceleratorRegistry, call_register_accelerators  # noqa: F401
 from pytorch_lightning.accelerators.tpu import TPUAccelerator  # noqa: F401
 

src/pytorch_lightning/accelerators/gpu.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
     @staticmethod
     def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
         """Accelerator device parsing logic."""
-        return device_parser.parse_gpu_ids(devices)
+        return device_parser.parse_gpu_ids(devices, include_cuda=True)
 
     @staticmethod
     def get_parallel_devices(devices: List[int]) -> List[torch.device]:
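
The new ``include_cuda`` / ``include_mps`` flags let both accelerators share the same ``parse_gpu_ids`` helper while opting into different backends. A hypothetical check of the resulting behaviour (the printed values are illustrative, not guaranteed):

    import torch

    from pytorch_lightning.accelerators import GPUAccelerator, MPSAccelerator

    if torch.cuda.is_available():
        # On a machine with two CUDA GPUs this is expected to yield [0, 1].
        print(GPUAccelerator.parse_devices(2))
    if MPSAccelerator.is_available():
        # MPS exposes a single device, so this is expected to yield [0].
        print(MPSAccelerator.parse_devices(1))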

src/pytorch_lightning/accelerators/mps.py

Lines changed: 95 additions & 0 deletions

@@ -0,0 +1,95 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE, _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.types import _DEVICE

_MPS_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available()


class MPSAccelerator(Accelerator):
    """Accelerator for Metal Apple Silicon GPU devices."""

    def setup_environment(self, root_device: torch.device) -> None:
        """
        Raises:
            MisconfigurationException:
                If the selected device is not MPS.
        """
        super().setup_environment(root_device)
        if root_device.type != "mps":
            raise MisconfigurationException(f"Device should be MPS, got {root_device} instead.")

    def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
        """Get M1 (cpu + gpu) stats from ``psutil`` package."""
        return get_device_stats()

    @staticmethod
    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
        """Accelerator device parsing logic."""
        parsed_devices = device_parser.parse_gpu_ids(devices, include_mps=True)
        return parsed_devices

    @staticmethod
    def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
        """Gets parallel devices for the Accelerator."""
        parsed_devices = MPSAccelerator.parse_devices(devices)
        assert parsed_devices is not None

        return [torch.device("mps", i) for i in range(len(parsed_devices))]

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return 1

    @staticmethod
    def is_available() -> bool:
        """MPS is only available for certain torch builds starting at torch>=1.12."""
        return _MPS_AVAILABLE

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        accelerator_registry.register(
            "mps",
            cls,
            description=cls.__class__.__name__,
        )


# device metrics
_VM_PERCENT = "M1_vm_percent"
_PERCENT = "M1_percent"
_SWAP_PERCENT = "M1_swap_percent"


def get_device_stats() -> Dict[str, float]:
    if not _PSUTIL_AVAILABLE:
        raise ModuleNotFoundError(
            "Fetching M1 device stats requires `psutil` to be installed."
            " Install it by running `pip install -U psutil`."
        )
    import psutil

    return {
        _VM_PERCENT: psutil.virtual_memory().percent,
        _PERCENT: psutil.cpu_percent(),
        _SWAP_PERCENT: psutil.swap_memory().percent,
    }
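
A small usage sketch for the accelerator defined above, assuming ``psutil`` is installed. Because Apple silicon shares memory between the CPU and the GPU, the reported stats are process-wide CPU, virtual-memory and swap percentages rather than per-GPU counters:

    import torch

    from pytorch_lightning.accelerators import MPSAccelerator

    if MPSAccelerator.is_available():
        # Returns the "M1_vm_percent", "M1_percent" and "M1_swap_percent" keys defined above.
        stats = MPSAccelerator().get_device_stats(torch.device("mps"))
        print(stats)

These are presumably also the values that a ``DeviceStatsMonitor`` callback would log when training on MPS.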

src/pytorch_lightning/lite/lite.py

Lines changed: 1 addition & 0 deletions
@@ -466,6 +466,7 @@ def _supported_device_types() -> Sequence[_AcceleratorType]:
             _AcceleratorType.CPU,
             _AcceleratorType.GPU,
             _AcceleratorType.TPU,
+            _AcceleratorType.MPS,
         )
 
     @staticmethod
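
With ``_AcceleratorType.MPS`` now listed as a supported device type, LightningLite should accept the MPS accelerator as well. A minimal sketch, assuming the usual subclass-and-run pattern on an Apple silicon machine:

    import torch

    from pytorch_lightning.lite import LightningLite


    class Lite(LightningLite):
        def run(self):
            # Tensors moved with `to_device` are expected to land on the MPS device.
            x = self.to_device(torch.randn(8, 32))
            print(x.device)


    Lite(accelerator="mps", devices=1).run()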

src/pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 12 additions & 7 deletions
@@ -25,6 +25,7 @@
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.accelerators.hpu import HPUAccelerator
 from pytorch_lightning.accelerators.ipu import IPUAccelerator
+from pytorch_lightning.accelerators.mps import MPSAccelerator
 from pytorch_lightning.accelerators.registry import AcceleratorRegistry
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
 from pytorch_lightning.plugins import (

@@ -178,7 +179,7 @@ def __init__(
         self._precision_flag: Optional[Union[int, str]] = None
         self._precision_plugin_flag: Optional[PrecisionPlugin] = None
         self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
-        self._parallel_devices: List[Union[int, torch.device]] = []
+        self._parallel_devices: List[Union[int, torch.device, str]] = []
         self._layer_sync: Optional[LayerSync] = NativeSyncBatchNorm() if sync_batchnorm else None
         self.checkpoint_io: Optional[CheckpointIO] = None
         self._amp_type_flag: Optional[LightningEnum] = None

@@ -407,7 +408,7 @@ def _check_device_config_and_set_final_flags(
         if self._devices_flag == "auto" and self._accelerator_flag is None:
             raise MisconfigurationException(
                 f"You passed `devices={devices}` but haven't specified"
-                " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu'|'hpu)` for the devices mapping."
+                " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu'|'hpu'|'mps')` for the devices mapping."
             )
 
     def _map_deprecated_devices_specific_info_to_accelerator_and_device_flag(

@@ -484,6 +485,8 @@ def _choose_accelerator(self) -> str:
             return "ipu"
         if _HPU_AVAILABLE:
             return "hpu"
+        if MPSAccelerator.is_available():
+            return "mps"
         if torch.cuda.is_available() and torch.cuda.device_count() > 0:
             return "gpu"
         return "cpu"

@@ -571,11 +574,13 @@ def _choose_strategy(self) -> Union[Strategy, str]:
         if self._num_nodes_flag > 1:
             return DDPStrategy.strategy_name
         if len(self._parallel_devices) <= 1:
-            device = (
-                device_parser.determine_root_gpu_device(self._parallel_devices)  # type: ignore
-                if self._accelerator_flag == "gpu"
-                else "cpu"
-            )
+            # TODO: Change this once gpu accelerator was renamed to cuda accelerator
+            if isinstance(self._accelerator_flag, (GPUAccelerator, MPSAccelerator)) or (
+                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("gpu", "mps")
+            ):
+                device = device_parser.determine_root_gpu_device(self._parallel_devices)
+            else:
+                device = "cpu"
             # TODO: lazy initialized device, then here could be self._strategy_flag = "single_device"
             return SingleDeviceStrategy(device=device)  # type: ignore
         if len(self._parallel_devices) > 1:
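
One consequence of the new branch in ``_choose_accelerator``: on an Apple silicon machine without TPU/IPU/HPU hardware, ``accelerator="auto"`` should now resolve to MPS instead of CPU. A hedged sketch of how that could be observed:

    from pytorch_lightning import Trainer

    # On Apple silicon with torch>=1.12 this is expected to pick the MPS accelerator;
    # elsewhere it falls back to GPU (if CUDA is available) or to CPU.
    trainer = Trainer(accelerator="auto", devices="auto")
    print(type(trainer.accelerator).__name__)  # e.g. "MPSAccelerator"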
