diff --git a/.gitignore b/.gitignore index 7040a912974e1..0f03c69600bed 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,4 @@ src/lightning_app/ui/* *examples/template_react_ui* hars* artifacts/* +*docs/examples* diff --git a/docs/source-app/api_reference/components.rst b/docs/source-app/api_reference/components.rst index 76a99402ddecc..c5f99f0f96629 100644 --- a/docs/source-app/api_reference/components.rst +++ b/docs/source-app/api_reference/components.rst @@ -20,5 +20,6 @@ ___________________ ~python.popen.PopenPythonScript ~python.tracer.TracerPythonScript + ~training.LightningTrainingComponent ~serve.gradio.ServeGradio ~serve.serve.ModelInferenceAPI diff --git a/examples/app_multi_node/app.py b/examples/app_multi_node/app.py new file mode 100644 index 0000000000000..6e405a346a143 --- /dev/null +++ b/examples/app_multi_node/app.py @@ -0,0 +1,11 @@ +from lightning import LightningApp +from lightning.app.components.training import LightningTrainingComponent +from lightning.app.utilities.packaging.cloud_compute import CloudCompute + +app = LightningApp( + LightningTrainingComponent( + "train.py", + num_nodes=2, + cloud_compute=CloudCompute("gpu-fast-multi"), + ), +) diff --git a/examples/app_multi_node/.gitignore b/examples/app_multi_node/bare/.gitignore similarity index 100% rename from examples/app_multi_node/.gitignore rename to examples/app_multi_node/bare/.gitignore diff --git a/examples/app_multi_node/multi_node.py b/examples/app_multi_node/bare/multi_node.py similarity index 100% rename from examples/app_multi_node/multi_node.py rename to examples/app_multi_node/bare/multi_node.py diff --git a/examples/app_multi_node/requirements.txt b/examples/app_multi_node/bare/requirements.txt similarity index 100% rename from examples/app_multi_node/requirements.txt rename to examples/app_multi_node/bare/requirements.txt diff --git a/examples/app_multi_node/train.py b/examples/app_multi_node/train.py new file mode 100644 index 0000000000000..f14809354f405 --- /dev/null +++ b/examples/app_multi_node/train.py @@ -0,0 +1,7 @@ +from lightning.pytorch import Trainer +from lightning.pytorch.demos.boring_classes import BoringModel + +if __name__ == "__main__": + model = BoringModel() + trainer = Trainer(max_epochs=1) + trainer.fit(model) diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 95a7000818b78..89fcd615430aa 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Add support for `Lightning App Commands` through the `configure_commands` hook on the Lightning Flow and the `ClientCommand` ([#13602](https://github.com/Lightning-AI/lightning/pull/13602)) +- Adds `LightningTrainingComponent`. `LightningTrainingComponent` orchestrates multi-node training in the cloud ([#13830](https://github.com/Lightning-AI/lightning/pull/13830)) + ### Changed - Update the Lightning App docs ([#13537](https://github.com/Lightning-AI/lightning/pull/13537)) diff --git a/src/lightning_app/components/python/tracer.py b/src/lightning_app/components/python/tracer.py index fa955646acbbf..b98c782e138e4 100644 --- a/src/lightning_app/components/python/tracer.py +++ b/src/lightning_app/components/python/tracer.py @@ -2,16 +2,24 @@ import os import signal import sys -from typing import Any, Dict, List, Optional, Union +from copy import deepcopy +from typing import Any, Dict, List, Optional, TypedDict, Union from lightning_app import LightningWork +from lightning_app.storage.drive import Drive from lightning_app.storage.payload import Payload from lightning_app.utilities.app_helpers import _collect_child_process_pids +from lightning_app.utilities.packaging.tarfile import clean_tarfile, extract_tarfile from lightning_app.utilities.tracer import Tracer logger = logging.getLogger(__name__) +class Code(TypedDict): + drive: Drive + name: str + + class TracerPythonScript(LightningWork): def on_before_run(self): """Called before the python script is executed.""" @@ -31,6 +39,7 @@ def __init__( script_args: Optional[Union[list, str]] = None, outputs: Optional[List[str]] = None, env: Optional[Dict] = None, + code: Optional[Code] = None, **kwargs, ): """The TracerPythonScript class enables to easily run a python script. @@ -97,17 +106,46 @@ def __init__( if isinstance(script_args, str): script_args = script_args.split(" ") self.script_args = script_args if script_args else [] + self.original_args = deepcopy(self.script_args) self.env = env self.outputs = outputs or [] for name in self.outputs: setattr(self, name, None) + self.params = None + self.drive = code.get("drive") if code else None + self.code_name = code.get("name") if code else None + self.restart_count = 0 + + def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[int] = None, **kwargs): + """ + Arguments: + params: A dictionary of arguments to be be added to script_args. + restart_count: Passes an incrementing counter to enable the re-execution of LightningWorks. + """ + if restart_count: + self.restart_count = restart_count + + if params: + self.params = params + self.script_args = self.original_args + [self._to_script_args(k, v) for k, v in params.items()] + + if self.drive: + assert self.code_name + if os.path.exists(self.code_name): + clean_tarfile(self.code_name, "r:gz") + + if self.code_name in self.drive.list(): + self.drive.get(self.code_name) + extract_tarfile(self.code_name, ".", "r:gz") - def run(self, **kwargs): if not os.path.exists(self.script_path): raise FileNotFoundError(f"The provided `script_path` {self.script_path}` wasn't found.") + kwargs = {k: v.value if isinstance(v, Payload) else v for k, v in kwargs.items()} + init_globals = globals() init_globals.update(kwargs) + self.on_before_run() env_copy = os.environ.copy() if self.env: @@ -125,5 +163,11 @@ def on_exit(self): for child_pid in _collect_child_process_pids(os.getpid()): os.kill(child_pid, signal.SIGTERM) + @staticmethod + def _to_script_args(k: str, v: str) -> str: + if k.startswith("--"): + return f"{k}={v}" + return f"--{k}={v}" + __all__ = ["TracerPythonScript"] diff --git a/src/lightning_app/components/training.py b/src/lightning_app/components/training.py new file mode 100644 index 0000000000000..9773fe9670e52 --- /dev/null +++ b/src/lightning_app/components/training.py @@ -0,0 +1,192 @@ +import logging +import os +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +from lightning import CloudCompute +from lightning_app import LightningFlow, structures +from lightning_app.components.python import TracerPythonScript +from lightning_app.storage.path import Path + +_logger = logging.getLogger(__name__) + + +class PyTorchLightningScriptRunner(TracerPythonScript): + def __init__( + self, + script_path: str, + script_args: Optional[Union[list, str]] = None, + node_rank: int = 1, + num_nodes: int = 1, + sanity_serving: bool = False, + cloud_compute: Optional[CloudCompute] = None, + parallel: bool = True, + raise_exception: bool = True, + env: Optional[Dict[str, Any]] = None, + **kwargs, + ): + super().__init__( + script_path, + script_args, + raise_exception=raise_exception, + parallel=parallel, + cloud_compute=cloud_compute, + **kwargs, + ) + self.node_rank = node_rank + self.num_nodes = num_nodes + self.best_model_path = None + self.best_model_score = None + self.monitor = None + self.sanity_serving = sanity_serving + self.has_finished = False + self.env = env + + def configure_tracer(self): + from pytorch_lightning import Trainer + + tracer = super().configure_tracer() + tracer.add_traced(Trainer, "__init__", pre_fn=self._trainer_init_pre_middleware) + return tracer + + def run(self, internal_urls: Optional[List[Tuple[str, str]]] = None, **kwargs) -> None: + if not internal_urls: + # Note: This is called only once. + _logger.info(f"The node {self.node_rank} started !") + return None + + if self.env: + os.environ.update(self.env) + + distributed_env_vars = { + "MASTER_ADDR": internal_urls[0][0], + "MASTER_PORT": str(internal_urls[0][1]), + "NODE_RANK": str(self.node_rank), + "PL_TRAINER_NUM_NODES": str(self.num_nodes), + "PL_TRAINER_DEVICES": "auto", + "PL_TRAINER_ACCELERATOR": "auto", + } + + os.environ.update(distributed_env_vars) + return super().run(**kwargs) + + def on_after_run(self, script_globals): + from pytorch_lightning import Trainer + from pytorch_lightning.cli import LightningCLI + + for v in script_globals.values(): + if isinstance(v, LightningCLI): + trainer = v.trainer + break + elif isinstance(v, Trainer): + trainer = v + break + else: + raise RuntimeError("No trainer instance found.") + + self.monitor = trainer.checkpoint_callback.monitor + + if trainer.checkpoint_callback.best_model_score: + self.best_model_path = Path(trainer.checkpoint_callback.best_model_path) + self.best_model_score = float(trainer.checkpoint_callback.best_model_score) + else: + self.best_model_path = Path(trainer.checkpoint_callback.last_model_path) + + self.has_finished = True + + def _trainer_init_pre_middleware(self, trainer, *args, **kwargs): + if self.node_rank != 0: + return {}, args, kwargs + + from pytorch_lightning.serve import ServableModuleValidator + + callbacks = kwargs.get("callbacks", []) + if self.sanity_serving: + callbacks = callbacks + [ServableModuleValidator()] + kwargs["callbacks"] = callbacks + return {}, args, kwargs + + @property + def is_running_in_cloud(self) -> bool: + return "LIGHTNING_APP_STATE_URL" in os.environ + + +class LightningTrainingComponent(LightningFlow): + def __init__( + self, + script_path: str, + script_args: Optional[Union[list, str]] = None, + num_nodes: int = 1, + cloud_compute: CloudCompute = CloudCompute("default"), + sanity_serving: bool = False, + script_runner: Type[TracerPythonScript] = PyTorchLightningScriptRunner, + **script_runner_kwargs, + ): + """This component enables performing distributed multi-node multi-device training. + + Example:: + + from lightning import LightningApp + from lightning.app.components.training import LightningTrainingComponent + from lightning.app.utilities.packaging.cloud_compute import CloudCompute + + app = LightningApp( + LightningTrainingComponent( + "train.py", + num_nodes=2, + cloud_compute=CloudCompute("gpu"), + ), + ) + + Arguments: + script_path: Path to the script to be executed. + script_args: The arguments to be pass to the script. + num_nodes: Number of nodes. + cloud_compute: The cloud compute object used in the cloud. + sanity_serving: Whether to validate that the model correctly implements + the ServableModule API + """ + super().__init__() + self.ws = structures.List() + self.has_initialized = False + self.script_path = script_path + self.script_args = script_args + self.num_nodes = num_nodes + self._cloud_compute = cloud_compute # TODO: Add support for cloudCompute + self.sanity_serving = sanity_serving + self._script_runner = script_runner + self._script_runner_kwargs = script_runner_kwargs + + def run(self, **run_kwargs): + if not self.has_initialized: + for node_rank in range(self.num_nodes): + self.ws.append( + self._script_runner( + script_path=self.script_path, + script_args=self.script_args, + cloud_compute=self._cloud_compute, + node_rank=node_rank, + sanity_serving=self.sanity_serving, + num_nodes=self.num_nodes, + **self._script_runner_kwargs, + ) + ) + + self.has_initialized = True + + for work in self.ws: + if all(w.internal_ip for w in self.ws): + internal_urls = [(w.internal_ip, w.port) for w in self.ws] + work.run(internal_urls=internal_urls, **run_kwargs) + if all(w.has_finished for w in self.ws): + for w in self.ws: + w.stop() + else: + work.run() + + @property + def best_model_score(self) -> Optional[float]: + return self.ws[0].best_model_score + + @property + def best_model_paths(self) -> List[Optional[Path]]: + return [self.ws[node_idx].best_mode_path for node_idx in range(len(self.ws))] diff --git a/src/lightning_app/core/flow.py b/src/lightning_app/core/flow.py index d1af891476a02..f6b6e34e81538 100644 --- a/src/lightning_app/core/flow.py +++ b/src/lightning_app/core/flow.py @@ -209,6 +209,7 @@ def _attach_backend(flow: "LightningFlow", backend): LightningFlow._attach_backend(flow, backend) for work in structure.works: backend._wrap_run_method(_LightningAppRef().get_current(), work) + work._backend = backend for name in flow._structures: getattr(flow, name)._backend = backend diff --git a/src/lightning_app/runners/backends/backend.py b/src/lightning_app/runners/backends/backend.py index c370c7098b778..87bb103823fd2 100644 --- a/src/lightning_app/runners/backends/backend.py +++ b/src/lightning_app/runners/backends/backend.py @@ -87,7 +87,6 @@ def _prepare_queues(self, app): app.commands_metadata_queue = self.queues.get_commands_metadata_queue(**kw) app.error_queue = self.queues.get_error_queue(**kw) app.delta_queue = self.queues.get_delta_queue(**kw) - app.error_queue = self.queues.get_error_queue(**kw) app.api_publish_state_queue = self.queues.get_api_state_publish_queue(**kw) app.api_delta_queue = self.queues.get_api_delta_queue(**kw) app.request_queues = {} diff --git a/src/lightning_app/source_code/local.py b/src/lightning_app/source_code/local.py index a42347ac42101..05669dff2f6a5 100644 --- a/src/lightning_app/source_code/local.py +++ b/src/lightning_app/source_code/local.py @@ -94,6 +94,7 @@ def upload(self, url: str) -> None: raise OSError( "cannot upload directory code whose total fize size is greater than 2GB (2e9 bytes)" ) from None + uploader = FileUploader( presigned_url=url, source_file=str(self.package_path), diff --git a/src/lightning_app/structures/dict.py b/src/lightning_app/structures/dict.py index 2aa02d4ebfa50..93e2b161b2e7a 100644 --- a/src/lightning_app/structures/dict.py +++ b/src/lightning_app/structures/dict.py @@ -58,7 +58,10 @@ def __init__(self, **kwargs: T): def __setitem__(self, k, v): from lightning_app import LightningFlow, LightningWork - if "." in k: + if not isinstance(k, str): + raise Exception("The provided key should be an string") + + if isinstance(k, str) and "." in k: raise Exception(f"The provided name {k} contains . which is forbidden.") if self._backend: @@ -67,7 +70,7 @@ def __setitem__(self, k, v): _set_child_name(self, v, k) elif isinstance(v, LightningWork): self._backend._wrap_run_method(_LightningAppRef().get_current(), v) - v._name = f"{self.name}.{k}" + v._name = f"{self.name}.{k}" super().__setitem__(k, v) @property diff --git a/src/lightning_app/testing/testing.py b/src/lightning_app/testing/testing.py index bdf37cacf04a7..cc03f5badec2b 100644 --- a/src/lightning_app/testing/testing.py +++ b/src/lightning_app/testing/testing.py @@ -23,6 +23,7 @@ from lightning_app.utilities.cloud import _get_project from lightning_app.utilities.imports import _is_playwright_available, requires from lightning_app.utilities.network import _configure_session, LightningClient +from lightning_app.utilities.proxies import ProxyWorkRun if _is_playwright_available(): import playwright @@ -114,6 +115,8 @@ def run_work_isolated(work, *args, start_server: bool = False, **kwargs): # pop the stopped status. call_hash = work._calls["latest_call_hash"] work._calls[call_hash]["statuses"].pop(-1) + if isinstance(work.run, ProxyWorkRun): + work.run = work.run.work_run def browser_context_args(browser_context_args: Dict) -> Dict: diff --git a/src/lightning_app/utilities/network.py b/src/lightning_app/utilities/network.py index 98c7db3d46ff8..a9ebcf37ab564 100644 --- a/src/lightning_app/utilities/network.py +++ b/src/lightning_app/utilities/network.py @@ -48,7 +48,7 @@ def _configure_session() -> Session: return http -def _check_service_url_is_ready(url: str, timeout: float = 0.5) -> bool: +def _check_service_url_is_ready(url: str, timeout: float = 1) -> bool: try: response = requests.get(url, timeout=timeout) return response.status_code in (200, 404) diff --git a/src/lightning_app/utilities/packaging/tarfile.py b/src/lightning_app/utilities/packaging/tarfile.py new file mode 100644 index 0000000000000..123e4e2e0942a --- /dev/null +++ b/src/lightning_app/utilities/packaging/tarfile.py @@ -0,0 +1,39 @@ +import os +import shutil +import tarfile + + +def clean_tarfile(file_path: str, mode: str) -> None: + """This utility removes all files extracted from a tarfile.""" + + if not os.path.exists(file_path): + return None + + with tarfile.open(file_path, mode=mode) as tar_ref: + for member in tar_ref.getmembers(): + p = member.path + if p == "." or not os.path.exists(p): + continue + try: + if os.path.isfile(p): + os.remove(p) + else: + shutil.rmtree(p) + except (FileNotFoundError, OSError, PermissionError): + pass + + if os.path.exists(file_path): + os.remove(file_path) + + +def extract_tarfile(file_path: str, extract_path: str, mode: str) -> None: + """This utility extracts all files from a tarfile.""" + if not os.path.exists(file_path): + return None + + with tarfile.open(file_path, mode=mode) as tar_ref: + for member in tar_ref.getmembers(): + try: + tar_ref.extract(member, path=extract_path, set_attrs=False) + except PermissionError: + raise PermissionError(f"Could not extract tar file {file_path}") diff --git a/src/lightning_app/utilities/state.py b/src/lightning_app/utilities/state.py index 0802a426e7349..5cd7979de09d9 100644 --- a/src/lightning_app/utilities/state.py +++ b/src/lightning_app/utilities/state.py @@ -3,7 +3,7 @@ import logging import os from copy import deepcopy -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from deepdiff import DeepDiff from requests import Session @@ -168,7 +168,7 @@ def __getattr__(self, name: str) -> Union[Any, "AppState"]: # The state needs to be fetched on access if it doesn't exist. self._request_state() - if name in self._state["vars"]: + if name in self._state.get("vars", {}): value = self._state["vars"][name] if isinstance(value, dict): return _maybe_create_drive("root." + ".".join(self._my_affiliation), value) @@ -187,12 +187,23 @@ def __getattr__(self, name: str) -> Union[Any, "AppState"]: state=self._state["flows"][name], ) + elif name in self._state.get("structures", {}): + return AppState( + self._host, + self._port, + last_state=self._last_state["structures"][name], + state=self._state["structures"][name], + ) + raise AttributeError( f"Failed to access '{name}' through `AppState`. The state provides:" f" Variables: {list(self._state['vars'].keys())}," f" Components: {list(self._state.get('flows', {}).keys()) + list(self._state.get('works', {}).keys())}", ) + def __getitem__(self, key: str): + return self.__getattr__(key) + def __setattr__(self, name: str, value: Any) -> None: if name in self._APP_PRIVATE_KEYS: object.__setattr__(self, name, value) @@ -226,6 +237,48 @@ def __repr__(self) -> str: def __bool__(self) -> bool: return bool(self._state) + def __len__(self) -> int: + # The state needs to be fetched on access if it doesn't exist. + self._request_state() + + keys = [] + for component in ["flows", "works", "structures"]: + keys.extend(list(self._state.get(component, {}))) + return len(keys) + + def items(self) -> List[Dict[str, Any]]: + # The state needs to be fetched on access if it doesn't exist. + self._request_state() + + items = [] + for component in ["flows", "works"]: + state = self._state.get(component, {}) + last_state = self._last_state.get(component, {}) + for name, state_value in state.items(): + v = AppState( + self._host, + self._port, + last_state=last_state[name], + state=state_value, + ) + items.append((name, v)) + + structures = self._state.get("structures", {}) + last_structures = self._last_state.get("structures", {}) + if structures: + for component in ["flows", "works"]: + state = structures.get(component, {}) + last_state = last_structures.get(component, {}) + for name, state_value in state.items(): + v = AppState( + self._host, + self._port, + last_state=last_state[name], + state=state_value, + ) + items.append((name, v)) + return items + @staticmethod def _configure_session() -> Session: return _configure_session() diff --git a/tests/tests_app/components/python/test_python.py b/tests/tests_app/components/python/test_python.py index 61969ef1c4c51..678655d6ee908 100644 --- a/tests/tests_app/components/python/test_python.py +++ b/tests/tests_app/components/python/test_python.py @@ -1,11 +1,15 @@ import os +import tarfile import pytest from tests_app import _PROJECT_ROOT from lightning_app.components.python import PopenPythonScript, TracerPythonScript +from lightning_app.components.python.tracer import Code +from lightning_app.storage.drive import Drive from lightning_app.testing.helpers import RunIf from lightning_app.testing.testing import run_work_isolated +from lightning_app.utilities.component import _set_work_context COMPONENTS_SCRIPTS_FOLDER = str(os.path.join(_PROJECT_ROOT, "tests/tests_app/components/python/scripts/")) @@ -69,3 +73,51 @@ def test_tracer_python_script_with_kwargs(): ) run_work_isolated(python_script) assert python_script.has_failed + + +def test_tracer_component_with_code(): + """This test ensures the Tracer Component gets the latest code from the code object that is provided and + arguments are cleaned.""" + + drive = Drive("lit://code") + drive.component_name = "something" + code = Code(drive=drive, name="sample.tar.gz") + + with open("file.py", "w") as f: + f.write('raise Exception("An error")') + + with tarfile.open("sample.tar.gz", "w:gz") as tar: + tar.add("file.py") + + drive.put("sample.tar.gz") + os.remove("file.py") + os.remove("sample.tar.gz") + + python_script = TracerPythonScript("file.py", script_args=["--b=1"], raise_exception=False, code=code) + run_work_isolated(python_script, params={"a": "1"}, restart_count=0) + assert python_script.status.message == "An error" + + with open("file.py", "w") as f: + f.write("import sys\n") + f.write("print(sys.argv)\n") + + with tarfile.open("sample.tar.gz", "w:gz") as tar: + tar.add("file.py") + + _set_work_context() + drive.put("sample.tar.gz") + os.remove("file.py") + os.remove("sample.tar.gz") + + with open("file.py", "w") as f: + f.write('raise Exception("An error")') + + call_hash = python_script._calls["latest_call_hash"] + python_script._calls[call_hash]["statuses"].pop(-1) + python_script._calls[call_hash]["statuses"].pop(-1) + + run_work_isolated(python_script, params={"a": "1"}, restart_count=1) + assert python_script.has_succeeded + assert python_script.script_args == ["--b=1", "--a=1"] + os.remove("file.py") + os.remove("sample.tar.gz") diff --git a/tests/tests_app/utilities/test_state.py b/tests/tests_app/utilities/test_state.py index 0740ffc615b87..3b9f1b790cfc7 100644 --- a/tests/tests_app/utilities/test_state.py +++ b/tests/tests_app/utilities/test_state.py @@ -7,6 +7,7 @@ import lightning_app from lightning_app import LightningApp, LightningFlow, LightningWork +from lightning_app.structures import Dict, List from lightning_app.utilities.app_helpers import AppStatePlugin, BaseStatePlugin from lightning_app.utilities.state import AppState @@ -280,3 +281,41 @@ def test_app_state_with_no_env_var(**__): assert state._host == "http://127.0.0.1" assert state._port == 7501 assert state._url == "http://127.0.0.1:7501" + + +class FlowStructures(LightningFlow): + def __init__(self): + super().__init__() + self.w_list = List(Work(), Work()) + self.w_dict = Dict(**{"toto": Work(), "toto_2": Work()}) + + def run(self): + self._exit() + + +class FlowStructuresEmpty(LightningFlow): + def __init__(self): + super().__init__() + self.w_list = List() + self.w_dict = Dict() + + def run(self): + self._exit() + + +def test_app_state_with_structures(): + app = LightningApp(FlowStructures()) + state = AppState() + state._last_state = app.state + state._state = app.state + assert state.w_list["0"].counter == 0 + assert len(state.w_list) == 2 + assert state.w_dict["toto"].counter == 0 + assert [k for k, _ in state.w_dict.items()] == ["toto", "toto_2"] + assert [k for k, _ in state.w_list.items()] == ["0", "1"] + + app = LightningApp(FlowStructuresEmpty()) + state = AppState() + state._last_state = app.state + state._state = app.state + assert state.w_list diff --git a/tests/tests_app_examples/test_multi_node.py b/tests/tests_app_examples/test_multi_node.py new file mode 100644 index 0000000000000..4b5c80c0cd9cb --- /dev/null +++ b/tests/tests_app_examples/test_multi_node.py @@ -0,0 +1,29 @@ +import os + +from tests_app import _PROJECT_ROOT + +from lightning_app.testing.testing import application_testing, LightningTestApp + + +class LightningTestMultiNodeApp(LightningTestApp): + def on_before_run_once(self): + res = super().on_before_run_once() + if all(w.has_finished for w in self.works): + return True + return res + + +def test_multi_node_example(): + cwd = os.getcwd() + new_cwd = os.path.join(_PROJECT_ROOT, "examples/app_multi_node") + os.chdir(new_cwd) + command_line = [ + "app.py", + "--blocking", + "False", + "--open-ui", + "False", + ] + result = application_testing(LightningTestMultiNodeApp, command_line) + assert result.exit_code == 0 + os.chdir(cwd)