Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/lightning_fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
- Fixed parsing of defaults for `--accelerator` and `--precision` in Fabric CLI when `accelerator` and `precision` are set to non-default values in the code ([#16818](https://github.com/Lightning-AI/lightning/pull/16818))
- Fixed `Fabric(strategy="auto")` support ([#16916](https://github.com/Lightning-AI/lightning/pull/16916))


## [1.9.2] - 2023-02-15
Expand Down
8 changes: 6 additions & 2 deletions src/lightning_fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(
self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment()

# 4. Instantiate Strategy - Part 1
if self._strategy_flag is None:
if self._strategy_flag in (None, "auto"):
self._strategy_flag = self._choose_strategy()
# In specific cases, ignore user selection and fall back to a different strategy
self._check_strategy_and_fallback()
Expand Down Expand Up @@ -184,7 +184,11 @@ def _check_config_and_set_final_flags(
if strategy is not None:
self._strategy_flag = strategy

if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
if (
strategy not in (None, "auto")
and strategy not in self._registered_strategies
and not isinstance(strategy, Strategy)
):
raise ValueError(
f"You selected an invalid strategy name: `strategy={strategy!r}`."
" It must be either a string or an instance of `lightning.fabric.strategies.Strategy`."
Expand Down
1 change: 1 addition & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
- Fixed `Fabric(strategy="auto")` support. It will choose DDP over DDP-spawn, contrary to `strategy=None` (default) ([#16916](https://github.com/Lightning-AI/lightning/pull/16916))


## [1.9.2] - 2023-02-15
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __init__(
self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment()

# 4. Instantiate Strategy - Part 1
if self._strategy_flag is None:
if self._strategy_flag in (None, "auto"):
self._strategy_flag = self._choose_strategy()
# In specific cases, ignore user selection and fall back to a different strategy
self._check_strategy_and_fallback()
Expand Down Expand Up @@ -273,7 +273,11 @@ def _check_config_and_set_final_flags(
" you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead."
)

if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
if (
strategy not in (None, "auto")
and strategy not in self._registered_strategies
and not isinstance(strategy, Strategy)
):
raise ValueError(
f"You selected an invalid strategy name: `strategy={strategy!r}`."
" It must be either a string or an instance of `pytorch_lightning.strategies.Strategy`."
Expand Down Expand Up @@ -639,6 +643,9 @@ def _choose_strategy(self) -> Union[Strategy, str]:
if len(self._parallel_devices) > 1:
if _IS_INTERACTIVE:
return "ddp_fork"
if self._strategy_flag == "auto":
# None chooses "ddp_spawn" for backwards compatibility, auto chooses "ddp" for future compatibility
return "ddp"
return "ddp_spawn"

return DDPStrategy.strategy_name
Expand Down
16 changes: 16 additions & 0 deletions tests/tests_fabric/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,3 +893,19 @@ def get_defaults(cls):
# defaults should match on the intersection of argument names
for name, connector_default in connector_defaults.items():
assert connector_default == fabric_defaults[name]


@mock.patch("lightning_fabric.accelerators.cuda.num_cuda_devices", return_value=2)
@mock.patch("lightning_fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_connector_auto_selection(*_):
connector = _Connector(accelerator="auto", strategy=None, devices="auto")
assert isinstance(connector.accelerator, CUDAAccelerator)
assert isinstance(connector.strategy, DDPStrategy)
assert isinstance(connector.strategy.launcher, _SubprocessScriptLauncher)
assert connector._devices_flag == [0, 1]

connector = _Connector(accelerator="auto", strategy="auto", devices="auto")
assert isinstance(connector.accelerator, CUDAAccelerator)
assert isinstance(connector.strategy, DDPStrategy)
assert isinstance(connector.strategy.launcher, _SubprocessScriptLauncher)
assert connector._devices_flag == [0, 1]
Original file line number Diff line number Diff line change
Expand Up @@ -892,3 +892,15 @@ def get_defaults(cls):
for name, connector_default in connector_defaults.items():
name = lut.get(name, name)
assert connector_default == trainer_defaults[name]


def test_connector_auto_selection(cuda_count_2, mps_count_0):
trainer = Trainer(accelerator="auto", strategy=None, devices="auto")
assert isinstance(trainer.accelerator, CUDAAccelerator)
assert isinstance(trainer.strategy, DDPSpawnStrategy)
assert trainer.num_devices == 2

trainer = Trainer(accelerator="auto", strategy="auto", devices="auto")
assert isinstance(trainer.accelerator, CUDAAccelerator)
assert isinstance(trainer.strategy, DDPStrategy)
assert trainer.num_devices == 2