(app) Introduce LightningTrainingComponent #13830


Merged
merged 52 commits into master from add_lightning_training_component_2 on Jul 29, 2022

Commits (52)
b9fc5b8
update
tchaton Jul 23, 2022
5e607e3
update
tchaton Jul 23, 2022
ee161b3
update
tchaton Jul 23, 2022
3fffc0b
update
tchaton Jul 23, 2022
6ddd07a
update
tchaton Jul 23, 2022
acfb717
update
tchaton Jul 23, 2022
9dc1864
update
tchaton Jul 23, 2022
abce33a
update
tchaton Jul 23, 2022
463540f
Merge branch 'master' into add_lightning_training_component
tchaton Jul 23, 2022
6e00b78
update
tchaton Jul 23, 2022
32d0206
Merge branch 'add_lightning_training_component' of https://github.com…
tchaton Jul 23, 2022
aa548c9
update
tchaton Jul 23, 2022
7b8e831
update
tchaton Jul 25, 2022
b0a3c52
update
tchaton Jul 25, 2022
04fe16d
update
tchaton Jul 25, 2022
0e1b06e
update
tchaton Jul 25, 2022
a389de6
update
tchaton Jul 25, 2022
3c9d1f8
update
tchaton Jul 25, 2022
fa28c53
update
tchaton Jul 25, 2022
f6060da
Merge branch 'master' into add_lightning_training_component_2
tchaton Jul 25, 2022
a087275
update
tchaton Jul 25, 2022
6200539
Merge branch 'add_lightning_training_component_2' of https://github.c…
tchaton Jul 25, 2022
b8c1ff3
update
tchaton Jul 25, 2022
7860ce4
update
tchaton Jul 26, 2022
0701fc8
update
tchaton Jul 26, 2022
c0df7c3
update
tchaton Jul 26, 2022
e596856
update
tchaton Jul 26, 2022
253aa43
update
tchaton Jul 26, 2022
0373fc7
update
tchaton Jul 26, 2022
c5a8e44
Merge branch 'master' into add_lightning_training_component_2
tchaton Jul 27, 2022
71f1cfe
update
tchaton Jul 28, 2022
b300bec
Merge branch 'add_lightning_training_component_2' of https://github.c…
tchaton Jul 28, 2022
e5a4a09
update
tchaton Jul 28, 2022
e6ff7ac
Merge branch 'master' into add_lightning_training_component_2
tchaton Jul 28, 2022
7cc1c39
update
tchaton Jul 28, 2022
29a1639
Merge branch 'add_lightning_training_component_2' of https://github.c…
tchaton Jul 28, 2022
eec5e6d
update
tchaton Jul 28, 2022
17b8c96
update
tchaton Jul 28, 2022
01bcc29
Update tests/tests_app/components/python/test_python.py
Felonious-Spellfire Jul 28, 2022
ba572cb
Update src/lightning_app/utilities/packaging/tarfile.py
Felonious-Spellfire Jul 28, 2022
2f5e4b0
Update src/lightning_app/components/training.py
Felonious-Spellfire Jul 28, 2022
1e000e5
Update src/lightning_app/components/training.py
Felonious-Spellfire Jul 28, 2022
4ff559f
Update src/lightning_app/components/python/tracer.py
Felonious-Spellfire Jul 28, 2022
4c9efe4
Update src/lightning_app/CHANGELOG.md
Felonious-Spellfire Jul 28, 2022
d49ed8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2022
3888ba1
update
tchaton Jul 28, 2022
a72397b
update
tchaton Jul 28, 2022
f679105
Merge branch 'add_lightning_training_component_2' of https://github.c…
tchaton Jul 28, 2022
3fb367a
Merge branch 'master' into add_lightning_training_component_2
tchaton Jul 28, 2022
d3ee31b
update
tchaton Jul 29, 2022
620c960
Merge branch 'add_lightning_training_component_2' of https://github.c…
tchaton Jul 29, 2022
fc4d807
Merge branch 'master' into add_lightning_training_component_2
tchaton Jul 29, 2022
1 change: 1 addition & 0 deletions .gitignore
@@ -163,3 +163,4 @@ src/lightning_app/ui/*
*examples/template_react_ui*
hars*
artifacts/*
*docs/examples*
1 change: 1 addition & 0 deletions docs/source-app/api_reference/components.rst
@@ -20,5 +20,6 @@ ___________________

~python.popen.PopenPythonScript
~python.tracer.TracerPythonScript
~training.LightningTrainingComponent
~serve.gradio.ServeGradio
~serve.serve.ModelInferenceAPI
11 changes: 11 additions & 0 deletions examples/app_multi_node/app.py
@@ -0,0 +1,11 @@
from lightning import LightningApp
from lightning.app.components.training import LightningTrainingComponent
from lightning.app.utilities.packaging.cloud_compute import CloudCompute

app = LightningApp(
LightningTrainingComponent(
"train.py",
num_nodes=2,
cloud_compute=CloudCompute("gpu-fast-multi"),
),
)
7 changes: 7 additions & 0 deletions examples/app_multi_node/train.py
@@ -0,0 +1,7 @@
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":
model = BoringModel()
trainer = Trainer(max_epochs=1)
trainer.fit(model)
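Note that train.py contains no multi-node logic of its own. The LightningTrainingComponent added in src/lightning_app/components/training.py (below) exports the standard PyTorch Lightning environment variables on every node before the script starts, so the Trainer configures distributed training automatically. A minimal sketch of the environment node 1 of this 2-node app would receive (the address and port values are illustrative, not taken from this PR):

# Illustrative only: the component fills these in with the real internal IP and
# port of node 0 at runtime; the Trainer in train.py reads them on startup.
node_1_env = {
    "MASTER_ADDR": "10.0.0.1",
    "MASTER_PORT": "52365",
    "NODE_RANK": "1",
    "PL_TRAINER_NUM_NODES": "2",
    "PL_TRAINER_DEVICES": "auto",
    "PL_TRAINER_ACCELERATOR": "auto",
}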
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Add support for `Lightning App Commands` through the `configure_commands` hook on the Lightning Flow and the `ClientCommand` ([#13602](https://github.com/Lightning-AI/lightning/pull/13602))

- Add `LightningTrainingComponent`, which orchestrates multi-node training in the cloud ([#13830](https://github.com/Lightning-AI/lightning/pull/13830))

### Changed

- Update the Lightning App docs ([#13537](https://github.com/Lightning-AI/lightning/pull/13537))
48 changes: 46 additions & 2 deletions src/lightning_app/components/python/tracer.py
@@ -2,16 +2,24 @@
import os
import signal
import sys
from typing import Any, Dict, List, Optional, Union
from copy import deepcopy
from typing import Any, Dict, List, Optional, TypedDict, Union

from lightning_app import LightningWork
from lightning_app.storage.drive import Drive
from lightning_app.storage.payload import Payload
from lightning_app.utilities.app_helpers import _collect_child_process_pids
from lightning_app.utilities.packaging.tarfile import clean_tarfile, extract_tarfile
from lightning_app.utilities.tracer import Tracer

logger = logging.getLogger(__name__)


class Code(TypedDict):
drive: Drive
name: str


class TracerPythonScript(LightningWork):
def on_before_run(self):
"""Called before the python script is executed."""
@@ -31,6 +39,7 @@ def __init__(
script_args: Optional[Union[list, str]] = None,
outputs: Optional[List[str]] = None,
env: Optional[Dict] = None,
code: Optional[Code] = None,
**kwargs,
):
"""The TracerPythonScript class enables to easily run a python script.
@@ -97,17 +106,46 @@ def __init__(
if isinstance(script_args, str):
script_args = script_args.split(" ")
self.script_args = script_args if script_args else []
self.original_args = deepcopy(self.script_args)
self.env = env
self.outputs = outputs or []
for name in self.outputs:
setattr(self, name, None)
self.params = None
self.drive = code.get("drive") if code else None
self.code_name = code.get("name") if code else None
self.restart_count = 0

def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[int] = None, **kwargs):
"""
Arguments:
params: A dictionary of arguments to be added to script_args.
restart_count: Passes an incrementing counter to enable the re-execution of LightningWorks.
"""
if restart_count:
self.restart_count = restart_count

if params:
self.params = params
self.script_args = self.original_args + [self._to_script_args(k, v) for k, v in params.items()]

if self.drive:
assert self.code_name
if os.path.exists(self.code_name):
clean_tarfile(self.code_name, "r:gz")

if self.code_name in self.drive.list():
self.drive.get(self.code_name)
extract_tarfile(self.code_name, ".", "r:gz")

def run(self, **kwargs):
if not os.path.exists(self.script_path):
raise FileNotFoundError(f"The provided `script_path` {self.script_path} wasn't found.")

kwargs = {k: v.value if isinstance(v, Payload) else v for k, v in kwargs.items()}

init_globals = globals()
init_globals.update(kwargs)

self.on_before_run()
env_copy = os.environ.copy()
if self.env:
@@ -125,5 +163,11 @@ def on_exit(self):
for child_pid in _collect_child_process_pids(os.getpid()):
os.kill(child_pid, signal.SIGTERM)

@staticmethod
def _to_script_args(k: str, v: str) -> str:
if k.startswith("--"):
return f"{k}={v}"
return f"--{k}={v}"


__all__ = ["TracerPythonScript"]
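Beyond the new `code` argument (a Drive plus tarball name that is fetched and extracted before the script runs), the reworked `run` accepts a `params` dictionary that is appended to the original script arguments as `--key=value` pairs via `_to_script_args`. A small, hypothetical sketch of that behaviour (the script path and parameter names are invented for illustration):

from lightning_app.components.python import TracerPythonScript

work = TracerPythonScript("train.py", script_args="--max_epochs=1")
# Running with params rebuilds script_args from the preserved original_args,
# so repeated calls do not accumulate stale arguments.
work.run(params={"lr": 0.01, "--batch_size": 32}, restart_count=1)
# work.script_args is now ["--max_epochs=1", "--lr=0.01", "--batch_size=32"]
# and work.restart_count is 1.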
192 changes: 192 additions & 0 deletions src/lightning_app/components/training.py
@@ -0,0 +1,192 @@
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from lightning import CloudCompute
from lightning_app import LightningFlow, structures
from lightning_app.components.python import TracerPythonScript
from lightning_app.storage.path import Path

_logger = logging.getLogger(__name__)


class PyTorchLightningScriptRunner(TracerPythonScript):
def __init__(
self,
script_path: str,
script_args: Optional[Union[list, str]] = None,
node_rank: int = 1,
num_nodes: int = 1,
sanity_serving: bool = False,
cloud_compute: Optional[CloudCompute] = None,
parallel: bool = True,
raise_exception: bool = True,
env: Optional[Dict[str, Any]] = None,
**kwargs,
):
super().__init__(
script_path,
script_args,
raise_exception=raise_exception,
parallel=parallel,
cloud_compute=cloud_compute,
**kwargs,
)
self.node_rank = node_rank
self.num_nodes = num_nodes
self.best_model_path = None
self.best_model_score = None
self.monitor = None
self.sanity_serving = sanity_serving
self.has_finished = False
self.env = env

def configure_tracer(self):
from pytorch_lightning import Trainer

tracer = super().configure_tracer()
tracer.add_traced(Trainer, "__init__", pre_fn=self._trainer_init_pre_middleware)
return tracer

def run(self, internal_urls: Optional[List[Tuple[str, str]]] = None, **kwargs) -> None:
if not internal_urls:
# Note: This is called only once.
_logger.info(f"The node {self.node_rank} started !")
return None

if self.env:
os.environ.update(self.env)

distributed_env_vars = {
"MASTER_ADDR": internal_urls[0][0],
"MASTER_PORT": str(internal_urls[0][1]),
"NODE_RANK": str(self.node_rank),
"PL_TRAINER_NUM_NODES": str(self.num_nodes),
"PL_TRAINER_DEVICES": "auto",
"PL_TRAINER_ACCELERATOR": "auto",
}

os.environ.update(distributed_env_vars)
return super().run(**kwargs)

def on_after_run(self, script_globals):
from pytorch_lightning import Trainer
from pytorch_lightning.cli import LightningCLI

for v in script_globals.values():
if isinstance(v, LightningCLI):
trainer = v.trainer
break
elif isinstance(v, Trainer):
trainer = v
break
else:
raise RuntimeError("No trainer instance found.")

self.monitor = trainer.checkpoint_callback.monitor

if trainer.checkpoint_callback.best_model_score:
self.best_model_path = Path(trainer.checkpoint_callback.best_model_path)
self.best_model_score = float(trainer.checkpoint_callback.best_model_score)
else:
self.best_model_path = Path(trainer.checkpoint_callback.last_model_path)

self.has_finished = True

def _trainer_init_pre_middleware(self, trainer, *args, **kwargs):
if self.node_rank != 0:
return {}, args, kwargs

from pytorch_lightning.serve import ServableModuleValidator

callbacks = kwargs.get("callbacks", [])
if self.sanity_serving:
callbacks = callbacks + [ServableModuleValidator()]
kwargs["callbacks"] = callbacks
return {}, args, kwargs

@property
def is_running_in_cloud(self) -> bool:
return "LIGHTNING_APP_STATE_URL" in os.environ


class LightningTrainingComponent(LightningFlow):
def __init__(
self,
script_path: str,
script_args: Optional[Union[list, str]] = None,
num_nodes: int = 1,
cloud_compute: CloudCompute = CloudCompute("default"),
sanity_serving: bool = False,
script_runner: Type[TracerPythonScript] = PyTorchLightningScriptRunner,
**script_runner_kwargs,
):
"""This component enables performing distributed multi-node multi-device training.

Example::

from lightning import LightningApp
from lightning.app.components.training import LightningTrainingComponent
from lightning.app.utilities.packaging.cloud_compute import CloudCompute

app = LightningApp(
LightningTrainingComponent(
"train.py",
num_nodes=2,
cloud_compute=CloudCompute("gpu"),
),
)

Arguments:
script_path: Path to the script to be executed.
script_args: The arguments to be passed to the script.
num_nodes: Number of nodes.
cloud_compute: The cloud compute object used in the cloud.
sanity_serving: Whether to validate that the model correctly implements
the ServableModule API.
"""
super().__init__()
self.ws = structures.List()
self.has_initialized = False
self.script_path = script_path
self.script_args = script_args
self.num_nodes = num_nodes
self._cloud_compute = cloud_compute # TODO: Add support for cloudCompute
self.sanity_serving = sanity_serving
self._script_runner = script_runner
self._script_runner_kwargs = script_runner_kwargs

def run(self, **run_kwargs):
if not self.has_initialized:
for node_rank in range(self.num_nodes):
self.ws.append(
self._script_runner(
script_path=self.script_path,
script_args=self.script_args,
cloud_compute=self._cloud_compute,
node_rank=node_rank,
sanity_serving=self.sanity_serving,
num_nodes=self.num_nodes,
**self._script_runner_kwargs,
)
)

self.has_initialized = True

for work in self.ws:
if all(w.internal_ip for w in self.ws):
internal_urls = [(w.internal_ip, w.port) for w in self.ws]
work.run(internal_urls=internal_urls, **run_kwargs)
if all(w.has_finished for w in self.ws):
for w in self.ws:
w.stop()
else:
work.run()

@property
def best_model_score(self) -> Optional[float]:
return self.ws[0].best_model_score

@property
def best_model_paths(self) -> List[Optional[Path]]:
return [self.ws[node_idx].best_model_path for node_idx in range(len(self.ws))]
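Because the flow takes the runner class itself (`script_runner`), per-node behaviour can be customised without changing LightningTrainingComponent. A hypothetical sketch, assuming PyTorchLightningScriptRunner is importable from the same module as the component and treating the extra environment variable as purely illustrative:

import os

from lightning import LightningApp
from lightning.app.components.training import LightningTrainingComponent, PyTorchLightningScriptRunner
from lightning.app.utilities.packaging.cloud_compute import CloudCompute


class MyNodeRunner(PyTorchLightningScriptRunner):
    def run(self, internal_urls=None, **kwargs):
        # Illustrative customisation: tag every node before the regular
        # multi-node bootstrapping (MASTER_ADDR, NODE_RANK, ...) happens.
        os.environ.setdefault("MY_EXPERIMENT_NAME", "baseline")
        return super().run(internal_urls=internal_urls, **kwargs)


app = LightningApp(
    LightningTrainingComponent(
        "train.py",
        num_nodes=2,
        cloud_compute=CloudCompute("gpu-fast-multi"),
        script_runner=MyNodeRunner,
    ),
)

Any extra keyword arguments given to LightningTrainingComponent are likewise forwarded to each node's runner through `**script_runner_kwargs`.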
1 change: 1 addition & 0 deletions src/lightning_app/core/flow.py
@@ -209,6 +209,7 @@ def _attach_backend(flow: "LightningFlow", backend):
LightningFlow._attach_backend(flow, backend)
for work in structure.works:
backend._wrap_run_method(_LightningAppRef().get_current(), work)
work._backend = backend

for name in flow._structures:
getattr(flow, name)._backend = backend
1 change: 0 additions & 1 deletion src/lightning_app/runners/backends/backend.py
@@ -87,7 +87,6 @@ def _prepare_queues(self, app):
app.commands_metadata_queue = self.queues.get_commands_metadata_queue(**kw)
app.error_queue = self.queues.get_error_queue(**kw)
app.delta_queue = self.queues.get_delta_queue(**kw)
app.error_queue = self.queues.get_error_queue(**kw)
app.api_publish_state_queue = self.queues.get_api_state_publish_queue(**kw)
app.api_delta_queue = self.queues.get_api_delta_queue(**kw)
app.request_queues = {}
1 change: 1 addition & 0 deletions src/lightning_app/source_code/local.py
@@ -94,6 +94,7 @@ def upload(self, url: str) -> None:
raise OSError(
"cannot upload directory code whose total fize size is greater than 2GB (2e9 bytes)"
) from None

uploader = FileUploader(
presigned_url=url,
source_file=str(self.package_path),
7 changes: 5 additions & 2 deletions src/lightning_app/structures/dict.py
@@ -58,7 +58,10 @@ def __init__(self, **kwargs: T):
def __setitem__(self, k, v):
from lightning_app import LightningFlow, LightningWork

if "." in k:
if not isinstance(k, str):
raise Exception("The provided key should be an string")

if isinstance(k, str) and "." in k:
raise Exception(f"The provided name {k} contains . which is forbidden.")

if self._backend:
@@ -67,7 +70,7 @@ def __setitem__(self, k, v):
_set_child_name(self, v, k)
elif isinstance(v, LightningWork):
self._backend._wrap_run_method(_LightningAppRef().get_current(), v)
v._name = f"{self.name}.{k}"
v._name = f"{self.name}.{k}"
super().__setitem__(k, v)

@property
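This hunk makes structures.Dict reject non-string keys outright, in addition to the existing ban on dotted names. A small sketch of the resulting behaviour, assuming a trivial hypothetical work class:

from lightning_app import LightningWork, structures


class NoopWork(LightningWork):
    def run(self):
        pass


d = structures.Dict()
d["node_0"] = NoopWork()  # accepted: plain string key
# d[0] = NoopWork()       # raises Exception: the key must be a string
# d["a.b"] = NoopWork()   # raises Exception: "." is forbidden in the name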
3 changes: 3 additions & 0 deletions src/lightning_app/testing/testing.py
@@ -23,6 +23,7 @@
from lightning_app.utilities.cloud import _get_project
from lightning_app.utilities.imports import _is_playwright_available, requires
from lightning_app.utilities.network import _configure_session, LightningClient
from lightning_app.utilities.proxies import ProxyWorkRun

if _is_playwright_available():
import playwright
@@ -114,6 +115,8 @@ def run_work_isolated(work, *args, start_server: bool = False, **kwargs):
# pop the stopped status.
call_hash = work._calls["latest_call_hash"]
work._calls[call_hash]["statuses"].pop(-1)
if isinstance(work.run, ProxyWorkRun):
work.run = work.run.work_run


def browser_context_args(browser_context_args: Dict) -> Dict:
2 changes: 1 addition & 1 deletion src/lightning_app/utilities/network.py
@@ -48,7 +48,7 @@ def _configure_session() -> Session:
return http


def _check_service_url_is_ready(url: str, timeout: float = 0.5) -> bool:
def _check_service_url_is_ready(url: str, timeout: float = 1) -> bool:
try:
response = requests.get(url, timeout=timeout)
return response.status_code in (200, 404)