Better error message when trying to re-initialize CUDA in forked subprocess #14709

Merged
merged 11 commits on Sep 28, 2022
24 changes: 23 additions & 1 deletion src/lightning_lite/strategies/launchers/multiprocessing.py
@@ -24,7 +24,7 @@
from lightning_lite.strategies.launchers.base import _Launcher
from lightning_lite.strategies.strategy import Strategy
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from lightning_lite.utilities.imports import _IS_INTERACTIVE, _TORCH_GREATER_EQUAL_1_11
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states


@@ -82,6 +82,9 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
*args: Optional positional arguments to be passed to the given function.
**kwargs: Optional keyword arguments to be passed to the given function.
"""
if self._start_method in ("fork", "forkserver"):
_check_bad_cuda_fork()

# The default cluster environment in Lightning chooses a random free port number
# This needs to be done in the main process here before starting processes to ensure each rank will connect
# through the same port
@@ -166,3 +169,22 @@ def restore(self) -> None:
torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
torch.backends.cudnn.benchmark = self.cudnn_benchmark
_set_rng_states(self.rng_states)


def _check_bad_cuda_fork() -> None:
"""Checks whether it is safe to fork and initialize CUDA in the new processes, and raises an exception if not.

The error message replaces PyTorch's 'Cannot re-initialize CUDA in forked subprocess' with helpful advice for
Lightning users.
"""
if not torch.cuda.is_initialized():
return

message = (
"Lightning can't create new processes if CUDA is already initialized. Did you manually call"
" `torch.cuda.*` functions, move the model to the device, or allocate memory on the GPU in any"
" other way? Please remove any such calls, or change the selected strategy."
)
if _IS_INTERACTIVE:
message += " You will have to restart the Python kernel."
raise RuntimeError(message)
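
For context (not part of the diff), a minimal sketch of the failure mode this check catches. It assumes a machine with at least one CUDA device; the import path follows the file added above:

```python
import torch

from lightning_lite.strategies.launchers.multiprocessing import _check_bad_cuda_fork

# Any CUDA call in the main process initializes the CUDA context ...
if torch.cuda.is_available():
    torch.zeros(1, device="cuda")

# ... after which a fork-based launch would crash inside the child processes.
# The new check raises an immediate, descriptive RuntimeError instead.
_check_bad_cuda_fork()
```
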
4 changes: 4 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -47,6 +47,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).



- Added a more descriptive error message when attempting to fork processes with pre-initialized CUDA context ([#14709](https://github.com/Lightning-AI/lightning/issues/14709))



### Changed

- The `Trainer.{fit,validate,test,predict,tune}` methods now raise a useful error message if the input is not a `LightningModule` ([#13892](https://github.com/Lightning-AI/lightning/pull/13892))
@@ -27,6 +27,7 @@

import pytorch_lightning as pl
from lightning_lite.strategies.launchers.base import _Launcher
from lightning_lite.strategies.launchers.multiprocessing import _check_bad_cuda_fork
from lightning_lite.utilities import move_data_to_device
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
from lightning_lite.utilities.types import _PATH
@@ -90,6 +91,9 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
**kwargs: Optional keyword arguments to be passed to the given function.
"""
self._check_torchdistx_support()
if self._start_method in ("fork", "forkserver"):
_check_bad_cuda_fork()

# The default cluster environment in Lightning chooses a random free port number
# This needs to be done in the main process here before starting processes to ensure each rank will connect
# through the same port
10 changes: 10 additions & 0 deletions tests/tests_lite/strategies/launchers/test_multiprocessing.py
@@ -84,3 +84,13 @@ def test_global_state_snapshot():
assert torch.are_deterministic_algorithms_enabled()
assert not torch.backends.cudnn.benchmark
assert torch.initial_seed() == 123


@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
@mock.patch("torch.cuda.is_initialized", return_value=True)
@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_check_for_bad_cuda_fork(mp_mock, _, start_method):
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
with pytest.raises(RuntimeError, match="Lightning can't create new processes if CUDA is already initialized"):
launcher.launch(function=Mock())
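
From the user's side, the new error surfaces roughly as below. This is a hedged sketch, not taken from the PR: the `ddp_fork` strategy name and the `BoringModel` demo import reflect the Lightning API around this release and are assumptions here, and a multi-GPU machine is assumed.

```python
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

torch.zeros(1).cuda()  # initializes CUDA in the main process

# The fork-based launcher now fails fast with the descriptive message
# instead of PyTorch's "Cannot re-initialize CUDA in forked subprocess".
trainer = Trainer(accelerator="cuda", devices=2, strategy="ddp_fork")
trainer.fit(BoringModel())  # raises RuntimeError: "Lightning can't create new processes if CUDA is already initialized. ..."
```
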