diff --git a/docs/source-pytorch/accelerators/gpu_intermediate.rst b/docs/source-pytorch/accelerators/gpu_intermediate.rst index 930d4654a4c02..dbd2dfd790bd6 100644 --- a/docs/source-pytorch/accelerators/gpu_intermediate.rst +++ b/docs/source-pytorch/accelerators/gpu_intermediate.rst @@ -21,8 +21,10 @@ Lightning supports multiple ways of doing distributed training. | - Data Parallel (``strategy='dp'``) (multiple-gpus, 1 machine) -- DistributedDataParallel (``strategy='ddp'``) (multiple-gpus across many machines (python script based)). -- DistributedDataParallel (``strategy='ddp_spawn'``) (multiple-gpus across many machines (spawn based)). +- DistributedDataParallel (multiple-gpus across many machines) + - Regular (``strategy='ddp'``) + - Spawn (``strategy='ddp_spawn'``) + - Fork (``strategy='ddp_fork'``) - Horovod (``strategy='horovod'``) (multi-machine, multi-gpu, configured at runtime) - Bagua (``strategy='bagua'``) (multiple-gpus across many machines with advanced training algorithms) @@ -199,6 +201,61 @@ You can then call your scripts anywhere python some_file.py --accelerator 'gpu' --devices 8 --strategy 'ddp' +Distributed Data Parallel Fork +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +DDP Fork is an alternative to Spawn that can be used in interactive Python and Jupyter notebooks, Google Colab, Kaggle notebooks, and so on: + +.. code-block:: python + + # train on 8 GPUs in a Jupyter notebook + trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp_fork") + +Data Parallel (``strategy="dp"``) is the only other strategy supported in interactive environments, but it is slower, is discouraged by PyTorch, and has other limitations. +Among the native distributed strategies, regular DDP (``strategy="ddp"``) is still recommended as the go-to strategy over Spawn and Fork for its speed and stability, but it can only be used with scripts. + + +Comparison of DDP variants and tradeoffs +**************************************** + +.. list-table:: DDP variants and their tradeoffs + :widths: 40 20 20 20 + :header-rows: 1 + + * - + - DDP + - DDP Spawn + - DDP Fork + * - Works in Jupyter notebooks / IPython environments + - No + - No + - Yes + * - Supports multi-node + - Yes + - Yes + - Yes + * - Supported platforms + - Linux, Mac, Win + - Linux, Mac, Win + - Linux, Mac + * - Requires all objects to be picklable + - No + - Yes + - No + * - Is the guard ``if __name__=="__main__"`` required? + - Yes + - Yes + - No + * - Limitations in the main process + - None + - None + - GPU operations such as moving tensors to the GPU or calling ``torch.cuda`` functions before invoking ``Trainer.fit`` are not allowed. + * - Process creation time + - Slow + - Slow + - Fast + + Horovod ^^^^^^^ `Horovod `_ allows the same training script to be used for single-GPU, diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index a6388429504b4..ae635a10c6051 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -99,6 +99,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added Apple Silicon Support via `MPSAccelerator` ([#13123](https://github.com/PyTorchLightning/pytorch-lightning/pull/13123)) +- Added support for DDP Fork ([#13405](https://github.com/PyTorchLightning/pytorch-lightning/pull/13405)) + ### Changed diff --git a/src/pytorch_lightning/accelerators/cuda.py b/src/pytorch_lightning/accelerators/cuda.py index 89d1a5b284b0c..a474ef9a99031 100644 --- a/src/pytorch_lightning/accelerators/cuda.py +++ b/src/pytorch_lightning/accelerators/cuda.py @@ -52,7 +52,7 @@ def setup(self, trainer: "pl.Trainer") -> None: def set_nvidia_flags(local_rank: int) -> None: # set the correct cuda visible devices (using pci order) os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - all_gpu_ids = ",".join(str(x) for x in range(torch.cuda.device_count())) + all_gpu_ids = ",".join(str(x) for x in range(device_parser.num_cuda_devices())) devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") @@ -84,11 +84,11 @@ def get_parallel_devices(devices: List[int]) -> List[torch.device]: @staticmethod def auto_device_count() -> int: """Get the devices when set to auto.""" - return torch.cuda.device_count() + return device_parser.num_cuda_devices() @staticmethod def is_available() -> bool: - return torch.cuda.device_count() > 0 + return device_parser.num_cuda_devices() > 0 @classmethod def register_accelerators(cls, accelerator_registry: Dict) -> None: @@ -162,6 +162,6 @@ def _to_float(x: str) -> float: def _get_gpu_id(device_id: int) -> str: """Get the unmasked real GPU IDs.""" # All devices if `CUDA_VISIBLE_DEVICES` unset - default = ",".join(str(i) for i in range(torch.cuda.device_count())) + default = ",".join(str(i) for i in range(device_parser.num_cuda_devices())) cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",") return cuda_visible_devices[device_id].strip() diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index f5cebcd89a63d..86bddaf676e01 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -468,6 +468,7 @@ def _supported_strategy_types() -> Sequence[_StrategyType]: _StrategyType.DP, _StrategyType.DDP, _StrategyType.DDP_SPAWN, + _StrategyType.DDP_FORK, _StrategyType.DEEPSPEED, _StrategyType.DDP_SHARDED, _StrategyType.DDP_SHARDED_SPAWN, diff --git a/src/pytorch_lightning/profilers/pytorch.py b/src/pytorch_lightning/profilers/pytorch.py index c9340444a06eb..079aafe37ec8b 100644 --- a/src/pytorch_lightning/profilers/pytorch.py +++ b/src/pytorch_lightning/profilers/pytorch.py @@ -24,6 +24,7 @@ from torch.autograd.profiler import record_function from pytorch_lightning.profilers.profiler import Profiler +from pytorch_lightning.utilities.device_parser import is_cuda_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE from pytorch_lightning.utilities.rank_zero import rank_zero_warn @@ -368,7 +369,7 @@ def _default_activities(self) -> List["ProfilerActivity"]: return activities if self._profiler_kwargs.get("use_cpu", True): activities.append(ProfilerActivity.CPU) - if self._profiler_kwargs.get("use_cuda", torch.cuda.is_available()): + if self._profiler_kwargs.get("use_cuda", is_cuda_available()): activities.append(ProfilerActivity.CUDA) return activities diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 
e3c37aa2f2ff1..a0b0a7865869f 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -22,6 +22,7 @@ from torch.distributed.constants import default_pg_timeout from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel +from typing_extensions import Literal import pytorch_lightning as pl from pytorch_lightning.overrides import LightningDistributedModule @@ -71,6 +72,7 @@ def __init__( ddp_comm_wrapper: Optional[callable] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, + start_method: Literal["spawn", "fork", "forkserver"] = "spawn", **kwargs: Any, ): super().__init__( @@ -88,6 +90,7 @@ def __init__( self._local_rank = 0 self._process_group_backend: Optional[str] = process_group_backend self._timeout: Optional[timedelta] = timeout + self._start_method = start_method @property def num_nodes(self) -> int: @@ -124,7 +127,7 @@ def process_group_backend(self) -> Optional[str]: return self._process_group_backend def _configure_launcher(self): - self._launcher = _SpawnLauncher(self) + self._launcher = _SpawnLauncher(self, start_method=self._start_method) def setup(self, trainer: "pl.Trainer") -> None: os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) @@ -280,17 +283,20 @@ def post_training_step(self): @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: - strategy_registry.register( - "ddp_spawn_find_unused_parameters_false", - cls, - description="DDPSpawn Strategy with `find_unused_parameters` as False", - find_unused_parameters=False, - ) - strategy_registry.register( - cls.strategy_name, - cls, - description=f"{cls.__class__.__name__}", - ) + for start_method in ("spawn", "fork"): + strategy_registry.register( + f"ddp_{start_method}_find_unused_parameters_false", + cls, + description=f"DDP {start_method.title()} strategy with `find_unused_parameters` as False", + find_unused_parameters=False, + start_method=start_method, + ) + strategy_registry.register( + f"ddp_{start_method}", + cls, + description=f"DDP {start_method.title()} strategy", + start_method=start_method, + ) def teardown(self) -> None: log.detail(f"{self.__class__.__name__}: tearing down strategy") diff --git a/src/pytorch_lightning/strategies/launchers/spawn.py b/src/pytorch_lightning/strategies/launchers/spawn.py index 0a92ceee5aacf..91482b66e5de3 100644 --- a/src/pytorch_lightning/strategies/launchers/spawn.py +++ b/src/pytorch_lightning/strategies/launchers/spawn.py @@ -20,6 +20,7 @@ import torch import torch.multiprocessing as mp from torch import Tensor +from typing_extensions import Literal import pytorch_lightning as pl from pytorch_lightning.strategies.launchers.base import _Launcher @@ -34,27 +35,40 @@ class _SpawnLauncher(_Launcher): r"""Spawns processes that run a given function in parallel, and joins them all at the end. The main process in which this launcher is invoked creates N so-called worker processes (using - :func:`torch.multiprocessing.spawn`) that run the given function. + :func:`torch.multiprocessing.start_processes`) that run the given function. Worker processes have a rank that ranges from 0 to N - 1. Note: - This launcher requires all objects to be pickleable. - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``. + - With start method 'fork' the user must ensure that no CUDA context gets created in the main process before + the launcher is invoked. 
E.g., one should avoid creating cuda tensors or calling ``torch.cuda.*`` functions + before calling ``Trainer.fit``. Args: strategy: A reference to the strategy that is used together with this launcher. + start_method: The method used to start the processes. + - 'spawn': The default start method. Requires all objects to be pickleable. + - 'fork': Preferable for IPython/Jupyter environments where 'spawn' is not available. Not available on + the Windows platform, for example. + - 'forkserver': Alternative implementation to 'fork'. """ - def __init__(self, strategy: Strategy) -> None: + def __init__(self, strategy: Strategy, start_method: Literal["spawn", "fork", "forkserver"] = "spawn") -> None: self._strategy = strategy - self._start_method = "spawn" + self._start_method = start_method + if start_method not in mp.get_all_start_methods(): + raise ValueError( + f"The start method '{self._start_method}' is not available on this platform. Available methods are:" + f" {', '.join(mp.get_all_start_methods())}" + ) @property def is_interactive_compatible(self) -> bool: - # The start method 'spawn' is currently the only one that works with DDP and CUDA support - # The start method 'fork' is the only one supported in Jupyter environments but not compatible with CUDA - # For more context, see https://github.com/Lightning-AI/lightning/issues/7550 - return self._start_method == "fork" and self._strategy.root_device.type != "cuda" + # The start method 'spawn' is not supported in interactive environments + # The start method 'fork' is the only one supported in Jupyter environments, with constraints around CUDA + # initialization. For more context, see https://github.com/Lightning-AI/lightning/issues/7550 + return self._start_method == "fork" def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any: """Spawns processes that run the given function in parallel.
@@ -75,7 +89,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port) context = mp.get_context(self._start_method) return_queue = context.SimpleQueue() - mp.spawn( + mp.start_processes( self._wrapping_function, args=(trainer, function, args, kwargs, return_queue), nprocs=self._strategy.num_processes, diff --git a/src/pytorch_lightning/strategies/launchers/xla_spawn.py b/src/pytorch_lightning/strategies/launchers/xla_spawn.py index 9c47e3b325cac..9a3028840b142 100644 --- a/src/pytorch_lightning/strategies/launchers/xla_spawn.py +++ b/src/pytorch_lightning/strategies/launchers/xla_spawn.py @@ -50,8 +50,7 @@ class _XLASpawnLauncher(_SpawnLauncher): """ def __init__(self, strategy: "Strategy") -> None: - super().__init__(strategy) - self._start_method = "fork" + super().__init__(strategy=strategy, start_method="fork") @property def is_interactive_compatible(self) -> bool: diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py index 3a30677a5b87b..1f7effcf75008 100644 --- a/src/pytorch_lightning/strategies/tpu_spawn.py +++ b/src/pytorch_lightning/strategies/tpu_spawn.py @@ -70,9 +70,9 @@ def __init__( cluster_environment=XLAEnvironment(), checkpoint_io=checkpoint_io, precision_plugin=precision_plugin, + start_method="fork", ) self.debug = debug - self.start_method = "fork" @property def checkpoint_io(self) -> CheckpointIO: @@ -123,7 +123,6 @@ def _configure_launcher(self): self._launcher = _XLASpawnLauncher(self) def setup(self, trainer: "pl.Trainer") -> None: - self.start_method = "fork" self.accelerator.setup(trainer) if self.debug: diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py index f72aba305e8b9..9f87a68b4df7d 100644 --- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -267,7 +267,7 @@ def _check_config_and_set_final_flags( if strategy == "ddp_cpu": raise MisconfigurationException( "`Trainer(strategy='ddp_cpu')` is not a valid strategy," - " you can use `Trainer(strategy='ddp'|'ddp_spawn', accelerator='cpu')` instead." + " you can use `Trainer(strategy='ddp'|'ddp_spawn'|'ddp_fork', accelerator='cpu')` instead." ) if strategy == "tpu_spawn": raise MisconfigurationException( @@ -496,7 +496,7 @@ def _choose_accelerator(self) -> str: return "hpu" if MPSAccelerator.is_available(): return "mps" - if torch.cuda.is_available() and torch.cuda.device_count() > 0: + if CUDAAccelerator.is_available(): return "cuda" return "cpu" @@ -614,7 +614,14 @@ def _check_strategy_and_fallback(self) -> None: f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, " "but GPU accelerator is not used." ) - + if ( + strategy_flag in ("ddp_fork", "ddp_fork_find_unused_parameters_false") + and "fork" not in torch.multiprocessing.get_all_start_methods() + ): + raise ValueError( + f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this" + f" platform. We recommend `Trainer(strategy='ddp_spawn')` instead."
+ ) if strategy_flag: self._strategy_flag = strategy_flag diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 46e991d1bbbab..d10225fea2d65 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -1759,7 +1759,7 @@ def _log_device_info(self) -> None: rank_zero_info(f"HPU available: {_HPU_AVAILABLE}, using: {num_hpus} HPUs") # TODO: Integrate MPS Accelerator here, once gpu maps to both - if torch.cuda.is_available() and not isinstance(self.accelerator, CUDAAccelerator): + if CUDAAccelerator.is_available() and not isinstance(self.accelerator, CUDAAccelerator): rank_zero_warn( "GPU available but not used. Set `accelerator` and `devices` using" f" `Trainer(accelerator='gpu', devices={CUDAAccelerator.auto_device_count()})`.", diff --git a/src/pytorch_lightning/tuner/auto_gpu_select.py b/src/pytorch_lightning/tuner/auto_gpu_select.py index d87eba64494f0..a42e55a61321d 100644 --- a/src/pytorch_lightning/tuner/auto_gpu_select.py +++ b/src/pytorch_lightning/tuner/auto_gpu_select.py @@ -15,6 +15,7 @@ import torch +from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -31,7 +32,7 @@ def pick_multiple_gpus(nb: int) -> List[int]: " Please select a valid number of GPU resources when using auto_select_gpus." ) - num_gpus = torch.cuda.device_count() + num_gpus = device_parser.num_cuda_devices() if nb > num_gpus: raise MisconfigurationException(f"You requested {nb} GPUs but your machine only has {num_gpus} GPUs.") nb = num_gpus if nb == -1 else nb @@ -51,7 +52,7 @@ def pick_single_gpu(exclude_gpus: List[int]) -> int: """ previously_used_gpus = [] unused_gpus = [] - for i in range(torch.cuda.device_count()): + for i in range(device_parser.num_cuda_devices()): if i in exclude_gpus: continue diff --git a/src/pytorch_lightning/utilities/device_parser.py b/src/pytorch_lightning/utilities/device_parser.py index 881a02a809ec2..c76933e489db7 100644 --- a/src/pytorch_lightning/utilities/device_parser.py +++ b/src/pytorch_lightning/utilities/device_parser.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import multiprocessing from typing import Any, List, MutableSequence, Optional, Tuple, Union import torch +import torch.cuda from pytorch_lightning.plugins.environments import TorchElasticEnvironment from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus @@ -250,7 +252,7 @@ def _get_all_available_cuda_gpus() -> List[int]: Returns: a list of all available CUDA gpus """ - return list(range(torch.cuda.device_count())) + return list(range(num_cuda_devices())) def _check_unique(device_ids: List[int]) -> None: @@ -330,3 +332,27 @@ def parse_hpus(devices: Optional[Union[int, str, List[int]]]) -> Optional[int]: raise MisconfigurationException("`devices` for `HPUAccelerator` must be int, string or None.") return int(devices) if isinstance(devices, str) else devices + + +def num_cuda_devices() -> int: + """Returns the number of GPUs available. + + Unlike :func:`torch.cuda.device_count`, this function will do its best not to create a CUDA context for fork + support, if the platform allows it. 
+ """ + if "fork" not in torch.multiprocessing.get_all_start_methods(): + return torch.cuda.device_count() + with multiprocessing.get_context("fork").Pool(1) as pool: + return pool.apply(torch.cuda.device_count) + + +def is_cuda_available() -> bool: + """Returns a bool indicating if CUDA is currently available. + + Unlike :func:`torch.cuda.is_available`, this function will do its best not to create a CUDA context for fork + support, if the platform allows it. + """ + if "fork" not in torch.multiprocessing.get_all_start_methods(): + return torch.cuda.is_available() + with multiprocessing.get_context("fork").Pool(1) as pool: + return pool.apply(torch.cuda.is_available) diff --git a/src/pytorch_lightning/utilities/enums.py b/src/pytorch_lightning/utilities/enums.py index b7f714d230971..91f8466b77500 100644 --- a/src/pytorch_lightning/utilities/enums.py +++ b/src/pytorch_lightning/utilities/enums.py @@ -214,6 +214,7 @@ class _StrategyType(LightningEnum): DDP = "ddp" DDP2 = "ddp2" DDP_SPAWN = "ddp_spawn" + DDP_FORK = "ddp_fork" TPU_SPAWN = "tpu_spawn" DEEPSPEED = "deepspeed" HOROVOD = "horovod" @@ -229,6 +230,7 @@ def interactive_compatible_types() -> list[_StrategyType]: return [ _StrategyType.DP, _StrategyType.TPU_SPAWN, + _StrategyType.DDP_FORK, ] def is_interactive_compatible(self) -> bool: diff --git a/tests/tests_pytorch/accelerators/test_accelerator_connector.py b/tests/tests_pytorch/accelerators/test_accelerator_connector.py index 65dea7bfd5ea1..a04418b62ebd9 100644 --- a/tests/tests_pytorch/accelerators/test_accelerator_connector.py +++ b/tests/tests_pytorch/accelerators/test_accelerator_connector.py @@ -97,7 +97,7 @@ def _test_strategy_choice_ddp_and_cpu(tmpdir, ddp_strategy_class): "SLURM_LOCALID": "0", }, ) -@mock.patch("torch.cuda.device_count", return_value=0) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) def test_custom_cluster_environment_in_slurm_environment(_, tmpdir): """Test that we choose the custom cluster even when SLURM or TE flags are around.""" @@ -134,7 +134,7 @@ def creates_processes_externally(self) -> bool: "SLURM_LOCALID": "0", }, ) -@mock.patch("torch.cuda.device_count", return_value=0) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) def test_custom_accelerator(device_count_mock, setup_distributed_mock): class Accel(Accelerator): @@ -193,7 +193,7 @@ class Strat(DDPStrategy): "SLURM_LOCALID": "0", }, ) -@mock.patch("torch.cuda.device_count", return_value=0) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) def test_dist_backend_accelerator_mapping(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", accelerator="cpu", devices=2) @@ -202,7 +202,7 @@ def test_dist_backend_accelerator_mapping(*_): assert trainer.strategy.local_rank == 0 -@mock.patch("torch.cuda.device_count", return_value=2) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) def test_ipython_incompatible_backend_error(_, monkeypatch): monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"): @@ -219,18 +219,26 @@ def test_ipython_incompatible_backend_error(_, monkeypatch): Trainer(strategy="dp") -@mock.patch("torch.cuda.device_count", 
return_value=2) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch): monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) trainer = Trainer(strategy="dp", accelerator="gpu") - assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible + assert trainer.strategy.launcher is None +@RunIf(skip_windows=True) @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True) -def test_ipython_compatible_strategy_tpu(mock_tpu_acc_avail, monkeypatch): +def test_ipython_compatible_strategy_tpu(_, monkeypatch): monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) trainer = Trainer(accelerator="tpu") - assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible + assert trainer.strategy.launcher.is_interactive_compatible + + +@RunIf(skip_windows=True) +def test_ipython_compatible_strategy_ddp_fork(monkeypatch): + monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) + trainer = Trainer(strategy="ddp_fork", accelerator="cpu") + assert trainer.strategy.launcher.is_interactive_compatible @pytest.mark.parametrize( @@ -244,8 +252,8 @@ def test_ipython_compatible_strategy_tpu(mock_tpu_acc_avail, monkeypatch): ], ) @pytest.mark.parametrize("devices", [1, 2]) -@mock.patch("torch.cuda.is_available", return_value=True) -@mock.patch("torch.cuda.device_count", return_value=2) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) def test_accelerator_choice_multi_node_gpu( mock_is_available, mock_device_count, tmpdir, strategy, strategy_class, devices ): @@ -253,7 +261,7 @@ def test_accelerator_choice_multi_node_gpu( assert isinstance(trainer.strategy, strategy_class) -@mock.patch("torch.cuda.is_available", return_value=False) +@mock.patch("pytorch_lightning.accelerators.cuda.device_parser.num_cuda_devices", return_value=0) def test_accelerator_cpu(_): trainer = Trainer(accelerator="cpu") assert isinstance(trainer.accelerator, CPUAccelerator) @@ -275,8 +283,8 @@ def test_accelerator_cpu(_): Trainer(accelerator="cpu", gpus=1) -@mock.patch("torch.cuda.device_count", return_value=2) -@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) @pytest.mark.parametrize("device_count", (["0"], [0, "1"], ["GPU"], [["0", "1"], [0, 1]], [False])) def test_accelererator_invalid_type_devices(mock_is_available, mock_device_count, device_count): with pytest.raises( @@ -409,15 +417,25 @@ def test_amp_level_raises_error_with_native(): def test_strategy_choice_ddp_spawn_cpu(): - trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", accelerator="cpu", devices=2) + trainer = Trainer(strategy="ddp_spawn", accelerator="cpu", devices=2) assert isinstance(trainer.accelerator, CPUAccelerator) assert isinstance(trainer.strategy, DDPSpawnStrategy) assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment) + assert trainer.strategy.launcher._start_method == "spawn" + + +@RunIf(skip_windows=True) +def test_strategy_choice_ddp_fork_cpu(): + trainer = Trainer(strategy="ddp_fork", accelerator="cpu", devices=2) + assert isinstance(trainer.accelerator, 
CPUAccelerator) + assert isinstance(trainer.strategy, DDPSpawnStrategy) + assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment) + assert trainer.strategy.launcher._start_method == "fork" @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) -@mock.patch("torch.cuda.device_count", return_value=2) -@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) def test_strategy_choice_ddp(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp", accelerator="gpu", devices=1) assert isinstance(trainer.accelerator, CUDAAccelerator) @@ -426,8 +444,8 @@ def test_strategy_choice_ddp(*_): @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) -@mock.patch("torch.cuda.device_count", return_value=2) -@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) def test_strategy_choice_ddp_spawn(cuda_available_mock, device_count_mock): trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", accelerator="gpu", devices=1) assert isinstance(trainer.accelerator, CUDAAccelerator) @@ -472,10 +490,10 @@ def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy): }, ) @mock.patch("torch.cuda.set_device") -@mock.patch("torch.cuda.device_count", return_value=2) -@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) -@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) def test_strategy_choice_ddp_te(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp", accelerator="gpu", devices=2) assert isinstance(trainer.accelerator, CUDAAccelerator) @@ -496,7 +514,7 @@ def test_strategy_choice_ddp_te(*_): "TORCHELASTIC_RUN_ID": "1", }, ) -@mock.patch("torch.cuda.device_count", return_value=0) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) def test_strategy_choice_ddp_cpu_te(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", accelerator="cpu", devices=2) @@ -519,10 +537,9 @@ def test_strategy_choice_ddp_cpu_te(*_): }, ) @mock.patch("torch.cuda.set_device") -@mock.patch("torch.cuda.device_count", return_value=1) -@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) -@mock.patch("torch.cuda.is_available", return_value=True) def test_strategy_choice_ddp_kubeflow(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp", accelerator="gpu", devices=1) assert isinstance(trainer.accelerator, CUDAAccelerator) @@ -542,7 +559,7 @@ def test_strategy_choice_ddp_kubeflow(*_): "RANK": "1", }, ) -@mock.patch("torch.cuda.device_count", 
return_value=0) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) def test_strategy_choice_ddp_cpu_kubeflow(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", accelerator="cpu", devices=2) @@ -564,7 +581,7 @@ def test_strategy_choice_ddp_cpu_kubeflow(*_): "SLURM_LOCALID": "0", }, ) -@mock.patch("torch.cuda.device_count", return_value=0) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) @pytest.mark.parametrize("strategy", ["ddp", DDPStrategy()]) def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock, strategy): @@ -614,22 +631,19 @@ def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch): Trainer(accelerator="ipu", precision=64) -@mock.patch("torch.cuda.is_available", return_value=False) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) @mock.patch("pytorch_lightning.utilities.imports._TPU_AVAILABLE", return_value=False) @mock.patch("pytorch_lightning.utilities.imports._IPU_AVAILABLE", return_value=False) @mock.patch("pytorch_lightning.utilities.imports._HPU_AVAILABLE", return_value=False) -def test_devices_auto_choice_cpu( - is_ipu_available_mock, is_tpu_available_mock, is_gpu_available_mock, is_hpu_available_mock -): +def test_devices_auto_choice_cpu(*_): trainer = Trainer(accelerator="auto", devices="auto") assert trainer.num_devices == 1 -@mock.patch("torch.cuda.is_available", return_value=True) -@mock.patch("torch.cuda.device_count", return_value=2) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) @RunIf(mps=False) def test_devices_auto_choice_gpu(is_gpu_available_mock, device_count_mock): - trainer = Trainer(accelerator="auto", devices="auto") assert isinstance(trainer.accelerator, CUDAAccelerator) assert trainer.num_devices == 2 @@ -733,3 +747,13 @@ def test_accelerator_specific_checkpoint_io(*_): ckpt_plugin = TorchCheckpointIO() trainer = Trainer(accelerator="hpu", strategy=HPUParallelStrategy(), plugins=[ckpt_plugin]) assert trainer.strategy.checkpoint_io is ckpt_plugin + + +@pytest.mark.parametrize("strategy", ["ddp_fork", "ddp_fork_find_unused_parameters_false"]) +@mock.patch( + "pytorch_lightning.trainer.connectors.accelerator_connector.torch.multiprocessing.get_all_start_methods", + return_value=[], +) +def test_ddp_fork_on_unsupported_platform(_, strategy): + with pytest.raises(ValueError, match="process forking is not supported on this platform"): + Trainer(strategy=strategy) diff --git a/tests/tests_pytorch/accelerators/test_common.py b/tests/tests_pytorch/accelerators/test_common.py index 9395c7e84c709..8c4ac8f3fd4ae 100644 --- a/tests/tests_pytorch/accelerators/test_common.py +++ b/tests/tests_pytorch/accelerators/test_common.py @@ -18,8 +18,8 @@ from pytorch_lightning.strategies import DDPStrategy -@mock.patch("torch.cuda.device_count", return_value=2) -def test_auto_device_count(device_count_mock): +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +def test_auto_device_count(_): assert CPUAccelerator.auto_device_count() == 1 assert CUDAAccelerator.auto_device_count() == 2 assert TPUAccelerator.auto_device_count() == 8 diff --git 
a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py index de02cba564c0a..12aca123eacc1 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py @@ -45,6 +45,7 @@ from pytorch_lightning.strategies import DDP2Strategy, ParallelStrategy from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.enums import DeviceType, DistributedType from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY @@ -902,8 +903,8 @@ def test_trainer_config_device_ids(): ], ) def test_root_gpu_property(monkeypatch, gpus, expected_root_gpu, strategy): - monkeypatch.setattr(torch.cuda, "is_available", lambda: True) - monkeypatch.setattr(torch.cuda, "device_count", lambda: 16) + monkeypatch.setattr(device_parser, "is_cuda_available", lambda: True) + monkeypatch.setattr(device_parser, "num_cuda_devices", lambda: 16) with pytest.deprecated_call( match="`Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. " "Please use `Trainer.strategy.root_device.index` instead." @@ -920,7 +921,7 @@ def test_root_gpu_property(monkeypatch, gpus, expected_root_gpu, strategy): ], ) def test_root_gpu_property_0_passing(monkeypatch, gpus, expected_root_gpu, strategy): - monkeypatch.setattr(torch.cuda, "device_count", lambda: 0) + monkeypatch.setattr(device_parser, "num_cuda_devices", lambda: 0) with pytest.deprecated_call( match="`Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. " "Please use `Trainer.strategy.root_device.index` instead." @@ -940,8 +941,8 @@ def test_root_gpu_property_0_passing(monkeypatch, gpus, expected_root_gpu, strat ], ) def test_trainer_gpu_parse(monkeypatch, gpus, expected_num_gpus, strategy): - monkeypatch.setattr(torch.cuda, "is_available", lambda: True) - monkeypatch.setattr(torch.cuda, "device_count", lambda: 16) + monkeypatch.setattr(device_parser, "is_cuda_available", lambda: True) + monkeypatch.setattr(device_parser, "num_cuda_devices", lambda: 16) with pytest.deprecated_call( match="`Trainer.num_gpus` was deprecated in v1.6 and will be removed in v1.8." " Please use `Trainer.num_devices` instead." @@ -957,7 +958,7 @@ def test_trainer_gpu_parse(monkeypatch, gpus, expected_num_gpus, strategy): ], ) def test_trainer_num_gpu_0(monkeypatch, gpus, expected_num_gpus, strategy): - monkeypatch.setattr(torch.cuda, "device_count", lambda: 0) + monkeypatch.setattr(device_parser, "num_cuda_devices", lambda: 0) with pytest.deprecated_call( match="`Trainer.num_gpus` was deprecated in v1.6 and will be removed in v1.8." " Please use `Trainer.num_devices` instead." @@ -1019,8 +1020,8 @@ def test_trainer_config_ipus(monkeypatch, trainer_kwargs, expected_ipus): ) def test_trainer_num_processes(monkeypatch, trainer_kwargs, expected_num_processes): if trainer_kwargs.get("accelerator") == "gpu": - monkeypatch.setattr(torch.cuda, "is_available", lambda: True) - monkeypatch.setattr(torch.cuda, "device_count", lambda: 16) + monkeypatch.setattr(device_parser, "is_cuda_available", lambda: True) + monkeypatch.setattr(device_parser, "num_cuda_devices", lambda: 16) trainer = Trainer(**trainer_kwargs) with pytest.deprecated_call( match="`Trainer.num_processes` is deprecated in v1.6 and will be removed in v1.8. 
" @@ -1044,8 +1045,8 @@ def test_trainer_num_processes(monkeypatch, trainer_kwargs, expected_num_process def test_trainer_data_parallel_device_ids(monkeypatch, trainer_kwargs, expected_data_parallel_device_ids): """Test multi type argument with bool.""" if trainer_kwargs.get("accelerator") == "gpu": - monkeypatch.setattr(torch.cuda, "is_available", lambda: True) - monkeypatch.setattr(torch.cuda, "device_count", lambda: 2) + monkeypatch.setattr(device_parser, "is_cuda_available", lambda: True) + monkeypatch.setattr(device_parser, "num_cuda_devices", lambda: 2) trainer = Trainer(**trainer_kwargs) with pytest.deprecated_call( @@ -1127,8 +1128,8 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint): ], ) def test_trainer_gpus(monkeypatch, trainer_kwargs): - monkeypatch.setattr(torch.cuda, "is_available", lambda: True) - monkeypatch.setattr(torch.cuda, "device_count", lambda: 4) + monkeypatch.setattr(device_parser, "is_cuda_available", lambda: True) + monkeypatch.setattr(device_parser, "num_cuda_devices", lambda: 4) trainer = Trainer(**trainer_kwargs) with pytest.deprecated_call( match=( @@ -1139,6 +1140,7 @@ def test_trainer_gpus(monkeypatch, trainer_kwargs): assert trainer.gpus == trainer_kwargs["devices"] +@RunIf(skip_windows=True) def test_trainer_tpu_cores(monkeypatch): monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "is_available", lambda _: True) trainer = Trainer(accelerator="tpu", devices=8) diff --git a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py index c54afd0931cff..b39c6dafc1696 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py @@ -20,6 +20,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel from tests_pytorch.callbacks.test_callbacks import OldStatefulCallback +from tests_pytorch.helpers.runif import RunIf def test_v2_0_0_deprecated_num_processes(): @@ -27,13 +28,14 @@ def test_v2_0_0_deprecated_num_processes(): _ = Trainer(num_processes=2) -@mock.patch("torch.cuda.is_available", return_value=True) -@mock.patch("torch.cuda.device_count", return_value=2) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) def test_v2_0_0_deprecated_gpus(*_): with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."): _ = Trainer(gpus=0) +@RunIf(skip_windows=True) @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True) @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8) def test_v2_0_0_deprecated_tpu_cores(*_): diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py index 7166be0981846..6d0c0fe891695 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_pytorch/lite/test_lite.py @@ -267,6 +267,7 @@ def test_seed_everything(): _StrategyType.DP, _StrategyType.DDP, _StrategyType.DDP_SPAWN, + pytest.param(_StrategyType.DDP_FORK, marks=RunIf(skip_windows=True)), pytest.param(_StrategyType.DEEPSPEED, marks=RunIf(deepspeed=True)), pytest.param(_StrategyType.DDP_SHARDED, marks=RunIf(fairscale=True)), pytest.param(_StrategyType.DDP_SHARDED_SPAWN, marks=RunIf(fairscale=True)), @@ -295,6 +296,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy): _StrategyType.DP, _StrategyType.DDP, 
_StrategyType.DDP_SPAWN, + pytest.param(_StrategyType.DDP_FORK, marks=RunIf(skip_windows=True)), pytest.param(_StrategyType.DEEPSPEED, marks=RunIf(deepspeed=True)), pytest.param(_StrategyType.DDP_SHARDED, marks=RunIf(fairscale=True)), pytest.param(_StrategyType.DDP_SHARDED_SPAWN, marks=RunIf(fairscale=True)), diff --git a/tests/tests_pytorch/models/test_gpu.py b/tests/tests_pytorch/models/test_gpu.py index bdc61ca399e12..1a2d72a12118e 100644 --- a/tests/tests_pytorch/models/test_gpu.py +++ b/tests/tests_pytorch/models/test_gpu.py @@ -83,8 +83,8 @@ def device_count(): def is_available(): return True - monkeypatch.setattr(torch.cuda, "is_available", is_available) - monkeypatch.setattr(torch.cuda, "device_count", device_count) + monkeypatch.setattr(device_parser, "is_cuda_available", is_available) + monkeypatch.setattr(device_parser, "num_cuda_devices", device_count) @pytest.fixture @@ -92,7 +92,7 @@ def mocked_device_count_0(monkeypatch): def device_count(): return 0 - monkeypatch.setattr(torch.cuda, "device_count", device_count) + monkeypatch.setattr(device_parser, "num_cuda_devices", device_count) # Asking for a gpu when non are available will result in a MisconfigurationException @@ -185,8 +185,8 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun "TORCHELASTIC_RUN_ID": "1", }, ) -@mock.patch("torch.cuda.device_count", return_value=1) -@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) @pytest.mark.parametrize("gpus", [[0, 1, 2], 2, "0", [0, 2]]) def test_torchelastic_gpu_parsing(mocked_device_count, mocked_is_available, gpus): """Ensure when using torchelastic and nproc_per_node is set to the default of 1 per GPU device That we omit diff --git a/tests/tests_pytorch/plugins/test_amp_plugins.py b/tests/tests_pytorch/plugins/test_amp_plugins.py index 5b6b3db334219..132d13c054926 100644 --- a/tests/tests_pytorch/plugins/test_amp_plugins.py +++ b/tests/tests_pytorch/plugins/test_amp_plugins.py @@ -45,8 +45,8 @@ class MyApexPlugin(ApexMixedPrecisionPlugin): "SLURM_LOCALID": "0", }, ) -@mock.patch("torch.cuda.is_available", return_value=True) -@mock.patch("torch.cuda.device_count", return_value=2) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) @pytest.mark.parametrize("strategy,devices", [("ddp", 2), ("ddp_spawn", 2)]) @pytest.mark.parametrize( "amp,custom_plugin,plugin_cls", @@ -272,16 +272,16 @@ def test_precision_selection_raises(monkeypatch): with pytest.raises(MisconfigurationException, match=r"amp_type='apex', precision='bf16'\)` but it's not supported"): Trainer(amp_backend="apex", precision="bf16") - with mock.patch("torch.cuda.device_count", return_value=1), pytest.raises( + with mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1), pytest.raises( MisconfigurationException, match="Sharded plugins are not supported with apex" ): - with mock.patch("torch.cuda.is_available", return_value=True): + with mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True): Trainer(amp_backend="apex", precision=16, accelerator="gpu", devices=1, strategy="ddp_fully_sharded") import pytorch_lightning.plugins.precision.apex_amp as apex monkeypatch.setattr(apex, "_APEX_AVAILABLE", 
False) - with mock.patch("torch.cuda.device_count", return_value=1), mock.patch( - "torch.cuda.is_available", return_value=True + with mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1), mock.patch( + "pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True ), pytest.raises(MisconfigurationException, match="asked for Apex AMP but you have not installed it"): Trainer(amp_backend="apex", precision=16, accelerator="gpu", devices=1) diff --git a/tests/tests_pytorch/plugins/test_cluster_integration.py b/tests/tests_pytorch/plugins/test_cluster_integration.py index c413b0015db61..b9f39336d11f7 100644 --- a/tests/tests_pytorch/plugins/test_cluster_integration.py +++ b/tests/tests_pytorch/plugins/test_cluster_integration.py @@ -85,8 +85,8 @@ def test_ranks_available_manual_strategy_selection(mock_gpu_acc_available, strat dict(strategy="ddp_spawn", accelerator="gpu", devices=[1, 2]), ], ) -@mock.patch("torch.cuda.is_available", return_value=True) -@mock.patch("torch.cuda.device_count", return_value=4) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=4) def test_ranks_available_automatic_strategy_selection(mock0, mock1, trainer_kwargs): """Test that the rank information is readily available after Trainer initialization.""" num_nodes = 2 diff --git a/tests/tests_pytorch/strategies/launchers/__init__.py b/tests/tests_pytorch/strategies/launchers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_pytorch/strategies/launchers/test_spawn.py b/tests/tests_pytorch/strategies/launchers/test_spawn.py new file mode 100644 index 0000000000000..3bb2e94175477 --- /dev/null +++ b/tests/tests_pytorch/strategies/launchers/test_spawn.py @@ -0,0 +1,40 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest import mock +from unittest.mock import ANY, Mock + +import pytest + +from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher + + +@mock.patch("pytorch_lightning.strategies.launchers.spawn.mp.get_all_start_methods", return_value=[]) +def test_spawn_launcher_forking_on_unsupported_platform(_): + with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"): + _SpawnLauncher(strategy=Mock(), start_method="fork") + + +@pytest.mark.parametrize("start_method", ["spawn", "fork"]) +@mock.patch("pytorch_lightning.strategies.launchers.spawn.mp") +def test_spawn_launcher_start_method(mp_mock, start_method): + mp_mock.get_all_start_methods.return_value = [start_method] + launcher = _SpawnLauncher(strategy=Mock(), start_method=start_method) + launcher.launch(function=Mock()) + mp_mock.get_context.assert_called_with(start_method) + mp_mock.start_processes.assert_called_with( + ANY, + args=ANY, + nprocs=ANY, + start_method=start_method, + ) diff --git a/tests/tests_pytorch/strategies/test_bagua_strategy.py b/tests/tests_pytorch/strategies/test_bagua_strategy.py index c9ccae43edbf3..79ec701964f8f 100644 --- a/tests/tests_pytorch/strategies/test_bagua_strategy.py +++ b/tests/tests_pytorch/strategies/test_bagua_strategy.py @@ -118,6 +118,6 @@ def test_bagua_not_available(monkeypatch): import pytorch_lightning.strategies.bagua as imports monkeypatch.setattr(imports, "_BAGUA_AVAILABLE", False) - with mock.patch("torch.cuda.device_count", return_value=1): + with mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1): with pytest.raises(MisconfigurationException, match="you must have `Bagua` installed"): Trainer(strategy="bagua", accelerator="gpu", devices=1) diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py index 003fe2250b575..4610f6153386b 100644 --- a/tests/tests_pytorch/strategies/test_ddp.py +++ b/tests/tests_pytorch/strategies/test_ddp.py @@ -80,11 +80,13 @@ def test_multi_gpu_model_ddp_fit_test(tmpdir, as_module): @RunIf(skip_windows=True) @pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't requires GPU machine") -@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) def test_torch_distributed_backend_env_variables(tmpdir): """This test set `undefined` as torch backend and should raise an `Backend.UNDEFINED` ValueError.""" _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"} - with patch.dict(os.environ, _environ), patch("torch.cuda.device_count", return_value=2): + with patch.dict(os.environ, _environ), patch( + "pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2 + ): with pytest.deprecated_call(match="Environment variable `PL_TORCH_DISTRIBUTED_BACKEND` was deprecated in v1.6"): with pytest.raises(ValueError, match="Invalid backend: 'undefined'"): model = BoringModel() @@ -101,9 +103,9 @@ def test_torch_distributed_backend_env_variables(tmpdir): @RunIf(skip_windows=True) @mock.patch("torch.cuda.set_device") -@mock.patch("torch.cuda.is_available", return_value=True) -@mock.patch("torch.cuda.device_count", return_value=1) -@mock.patch("pytorch_lightning.accelerators.cuda.CUDAAccelerator.is_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) 
+@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) +@mock.patch("pytorch_lightning.accelerators.gpu.CUDAAccelerator.is_available", return_value=True) @mock.patch.dict(os.environ, {"PL_TORCH_DISTRIBUTED_BACKEND": "gloo"}, clear=True) def test_ddp_torch_dist_is_available_in_setup( mock_gpu_is_available, mock_device_count, mock_cuda_available, mock_set_device, tmpdir diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py index c1120fa4e2be9..2790f014c7212 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py @@ -29,8 +29,8 @@ def test_invalid_on_cpu(tmpdir): @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) -@mock.patch("torch.cuda.device_count", return_value=1) -@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) @RunIf(fairscale_fully_sharded=True) def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index d77319249b23d..79562134f9ccb 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -176,7 +176,7 @@ def test_deepspeed_strategy_env(tmpdir, monkeypatch, deepspeed_config): @RunIf(deepspeed=True) -@mock.patch("torch.cuda.device_count", return_value=1) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) @pytest.mark.parametrize("precision", [16, "mixed"]) @pytest.mark.parametrize( "amp_backend", diff --git a/tests/tests_pytorch/strategies/test_dp.py b/tests/tests_pytorch/strategies/test_dp.py index 4a1c504e12bf8..30e0e5b19a845 100644 --- a/tests/tests_pytorch/strategies/test_dp.py +++ b/tests/tests_pytorch/strategies/test_dp.py @@ -154,8 +154,8 @@ def _assert_extra_outputs(self, outputs): assert out.dtype is torch.float -@mock.patch("torch.cuda.device_count", return_value=2) -@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) def test_dp_raise_exception_with_batch_transfer_hooks(mock_is_available, mock_device_count, tmpdir): """Test that an exception is raised when overriding batch_transfer_hooks in DP model.""" diff --git a/tests/tests_pytorch/strategies/test_strategy_registry.py b/tests/tests_pytorch/strategies/test_strategy_registry.py index 5f9e6208c4fa5..f5576fa14eb8a 100644 --- a/tests/tests_pytorch/strategies/test_strategy_registry.py +++ b/tests/tests_pytorch/strategies/test_strategy_registry.py @@ -79,7 +79,8 @@ def test_deepspeed_strategy_registry_with_trainer(tmpdir, strategy): assert isinstance(trainer.strategy, DeepSpeedStrategy) -def test_tpu_spawn_debug_strategy_registry(tmpdir): +@RunIf(skip_windows=True) +def test_tpu_spawn_debug_strategy_registry(): strategy = "tpu_spawn_debug" @@ -105,22 +106,41 @@ def test_fsdp_strategy_registry(tmpdir): 
 @pytest.mark.parametrize(
-    "strategy_name, strategy",
+    "strategy_name, strategy, expected_init_params",
     [
-        ("ddp_find_unused_parameters_false", DDPStrategy),
-        ("ddp_spawn_find_unused_parameters_false", DDPSpawnStrategy),
-        ("ddp_sharded_spawn_find_unused_parameters_false", DDPSpawnShardedStrategy),
-        ("ddp_sharded_find_unused_parameters_false", DDPShardedStrategy),
+        (
+            "ddp_find_unused_parameters_false",
+            DDPStrategy,
+            {"find_unused_parameters": False},
+        ),
+        (
+            "ddp_spawn_find_unused_parameters_false",
+            DDPSpawnStrategy,
+            {"find_unused_parameters": False, "start_method": "spawn"},
+        ),
+        pytest.param(
+            "ddp_fork_find_unused_parameters_false",
+            DDPSpawnStrategy,
+            {"find_unused_parameters": False, "start_method": "fork"},
+            marks=RunIf(skip_windows=True),
+        ),
+        (
+            "ddp_sharded_spawn_find_unused_parameters_false",
+            DDPSpawnShardedStrategy,
+            {"find_unused_parameters": False},
+        ),
+        (
+            "ddp_sharded_find_unused_parameters_false",
+            DDPShardedStrategy,
+            {"find_unused_parameters": False},
+        ),
     ],
 )
-def test_ddp_find_unused_parameters_strategy_registry(tmpdir, strategy_name, strategy):
-
+def test_ddp_find_unused_parameters_strategy_registry(tmpdir, strategy_name, strategy, expected_init_params):
     trainer = Trainer(default_root_dir=tmpdir, strategy=strategy_name)
-
     assert isinstance(trainer.strategy, strategy)
-
     assert strategy_name in StrategyRegistry
-    assert StrategyRegistry[strategy_name]["init_params"] == {"find_unused_parameters": False}
+    assert StrategyRegistry[strategy_name]["init_params"] == expected_init_params
     assert StrategyRegistry[strategy_name]["strategy"] == strategy
diff --git a/tests/tests_pytorch/trainer/flags/test_env_vars.py b/tests/tests_pytorch/trainer/flags/test_env_vars.py
index e7c9a13a0cd3c..9e7bd70468482 100644
--- a/tests/tests_pytorch/trainer/flags/test_env_vars.py
+++ b/tests/tests_pytorch/trainer/flags/test_env_vars.py
@@ -46,8 +46,8 @@ def test_passing_env_variables_defaults():
 
 
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_DEVICES": "2"})
-@mock.patch("torch.cuda.device_count", return_value=2)
-@mock.patch("torch.cuda.is_available", return_value=True)
+@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2)
+@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True)
 def test_passing_env_variables_devices(cuda_available_mock, device_count_mock):
     """Testing overwriting trainer arguments."""
     trainer = Trainer()
diff --git a/tests/tests_pytorch/trainer/properties/test_auto_gpu_select.py b/tests/tests_pytorch/trainer/properties/test_auto_gpu_select.py
index 3800f5bc8c529..aa9f15bc43c18 100644
--- a/tests/tests_pytorch/trainer/properties/test_auto_gpu_select.py
+++ b/tests/tests_pytorch/trainer/properties/test_auto_gpu_select.py
@@ -42,13 +42,13 @@ def test_pick_multiple_gpus(nb, expected_gpu_idxs, expected_error):
     assert expected_gpu_idxs == pick_multiple_gpus(nb)
 
 
-@mock.patch("torch.cuda.device_count", return_value=1)
+@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1)
 def test_pick_multiple_gpus_more_than_available(*_):
     with pytest.raises(MisconfigurationException, match="You requested 3 GPUs but your machine only has 1 GPUs"):
         pick_multiple_gpus(3)
 
 
-@mock.patch("torch.cuda.device_count", return_value=2)
+@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2)
 @mock.patch("pytorch_lightning.trainer.connectors.accelerator_connector.pick_multiple_gpus", return_value=[1])
 def test_auto_select_gpus(*_):
diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
index 4929b2a801a70..35a8a0a8d5789 100644
--- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
+++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
@@ -16,13 +16,13 @@
 from unittest import mock
 
 import pytest
-import torch
 from torch.utils.data import DataLoader
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.strategies.ipu import IPUStrategy
+from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.datasets import RandomIterableDataset
 from tests_pytorch.helpers.runif import RunIf
@@ -126,8 +126,8 @@ def test_num_stepping_batches_accumulate_gradients(accumulate_grad_batches, expe
 )
 def test_num_stepping_batches_gpu(trainer_kwargs, estimated_steps, monkeypatch):
     """Test stepping batches with GPU strategies."""
-    monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
-    monkeypatch.setattr(torch.cuda, "device_count", lambda: 7)
+    monkeypatch.setattr(device_parser, "is_cuda_available", lambda: True)
+    monkeypatch.setattr(device_parser, "num_cuda_devices", lambda: 7)
     trainer = Trainer(max_epochs=1, devices=7, accelerator="gpu", **trainer_kwargs)
     model = BoringModel()
     trainer._data_connector.attach_data(model)
diff --git a/tests/tests_pytorch/trainer/test_supporters.py b/tests/tests_pytorch/trainer/test_supporters.py
index 22b10c8451b70..324070fa87602 100644
--- a/tests/tests_pytorch/trainer/test_supporters.py
+++ b/tests/tests_pytorch/trainer/test_supporters.py
@@ -314,8 +314,8 @@ def test_nested_calc_num_data(input_data, compute_func, expected_length):
 
 
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
-@mock.patch("torch.cuda.device_count", return_value=2)
-@mock.patch("torch.cuda.is_available", return_value=True)
+@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2)
+@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True)
 @pytest.mark.parametrize("use_fault_tolerant", [False, True])
 @pytest.mark.parametrize("replace_sampler_ddp", [False, True])
 def test_combined_data_loader_validation_test(
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index c46c0168db558..3c82e6de84a65 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -52,6 +52,7 @@
     SingleDeviceStrategy,
 )
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
+from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_12
@@ -1231,8 +1232,8 @@ def __init__(self, **kwargs):
     "trainer_params",
     [{"max_epochs": 1, "accelerator": "gpu", "devices": 1}, {"max_epochs": 1, "accelerator": "gpu", "devices": [0]}],
 )
-@mock.patch("torch.cuda.is_available", return_value=True)
-@mock.patch("torch.cuda.device_count", return_value=1)
+@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True)
+@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1)
 def test_trainer_omegaconf(_, __, trainer_params):
     config = OmegaConf.create(trainer_params)
     Trainer(**config)
@@ -2080,8 +2081,8 @@ def training_step(self, batch, batch_idx):
 )
 def test_trainer_config_strategy(monkeypatch, trainer_kwargs, strategy_cls, strategy_name, accelerator_cls, devices):
     if trainer_kwargs.get("accelerator") == "gpu":
-        monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
-        monkeypatch.setattr(torch.cuda, "device_count", lambda: trainer_kwargs["devices"])
+        monkeypatch.setattr(device_parser, "is_cuda_available", lambda: True)
+        monkeypatch.setattr(device_parser, "num_cuda_devices", lambda: trainer_kwargs["devices"])
 
     trainer = Trainer(**trainer_kwargs)
@@ -2147,8 +2148,8 @@ def test_dataloaders_are_not_loaded_if_disabled_through_limit_batches(running_st
 )
 def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_device_ids):
     if trainer_kwargs.get("accelerator") == "gpu":
-        monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
-        monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
+        monkeypatch.setattr(device_parser, "is_cuda_available", lambda: True)
+        monkeypatch.setattr(device_parser, "num_cuda_devices", lambda: 4)
     elif trainer_kwargs.get("accelerator") == "ipu":
         monkeypatch.setattr(pytorch_lightning.accelerators.ipu.IPUAccelerator, "is_available", lambda _: True)
         monkeypatch.setattr(pytorch_lightning.strategies.ipu, "_IPU_AVAILABLE", lambda: True)
diff --git a/tests/tests_pytorch/trainer/test_trainer_cli.py b/tests/tests_pytorch/trainer/test_trainer_cli.py
index 989a06f4193ed..468650e234f81 100644
--- a/tests/tests_pytorch/trainer/test_trainer_cli.py
+++ b/tests/tests_pytorch/trainer/test_trainer_cli.py
@@ -17,11 +17,10 @@
 from unittest import mock
 
 import pytest
-import torch
 
 import tests_pytorch.helpers.utils as tutils
 from pytorch_lightning import Trainer
-from pytorch_lightning.utilities import argparse
+from pytorch_lightning.utilities import argparse, device_parser
 
 
 @mock.patch("argparse.ArgumentParser.parse_args")
@@ -167,8 +166,8 @@ def test_argparse_args_parsing_fast_dev_run(cli_args, expected):
 def test_argparse_args_parsing_devices(cli_args, expected_parsed, monkeypatch):
     """Test multi type argument with bool."""
-    monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
-    monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
+    monkeypatch.setattr(device_parser, "is_cuda_available", lambda: True)
+    monkeypatch.setattr(device_parser, "num_cuda_devices", lambda: 1)
 
     cli_args = cli_args.split(" ") if cli_args else []
     with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
diff --git a/tests/tests_pytorch/utilities/test_cli.py b/tests/tests_pytorch/utilities/test_cli.py
index caafa9a3ca719..59efd41d26140 100644
--- a/tests/tests_pytorch/utilities/test_cli.py
+++ b/tests/tests_pytorch/utilities/test_cli.py
@@ -201,8 +201,8 @@ def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
 )
 def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu):
     """Test parsing of gpus and instantiation of Trainer."""
-    monkeypatch.setattr("torch.cuda.device_count", lambda: 2)
-    monkeypatch.setattr("torch.cuda.is_available", lambda: True)
+    monkeypatch.setattr("pytorch_lightning.utilities.device_parser.num_cuda_devices", lambda: 2)
+    monkeypatch.setattr("pytorch_lightning.utilities.device_parser.is_cuda_available", lambda: True)
     cli_args = cli_args.split(" ") if cli_args else []
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
diff --git a/tests/tests_pytorch/utilities/test_device_parser.py b/tests/tests_pytorch/utilities/test_device_parser.py
new file mode 100644
index 0000000000000..d496db487f55c
--- /dev/null
+++ b/tests/tests_pytorch/utilities/test_device_parser.py
@@ -0,0 +1,31 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest import mock
+
+import pytest
+import torch
+
+from pytorch_lightning.utilities import device_parser
+
+
+@pytest.mark.skipif(
+    "fork" in torch.multiprocessing.get_all_start_methods(), reason="Requires platform without forking support"
+)
+@mock.patch("torch.cuda.is_available", return_value=True)
+@mock.patch("torch.cuda.device_count", return_value=2)
+def test_num_cuda_devices_without_forking(*_):
+    """This merely tests that on platforms without fork support our helper functions fall back to the default
+    implementation for determining cuda availability."""
+    assert device_parser.is_cuda_available()
+    assert device_parser.num_cuda_devices() == 2