(app) Introduce LightningTrainingComponent #13830
Merged
52 commits
b9fc5b8
update
tchaton 5e607e3
update
tchaton ee161b3
update
tchaton 3fffc0b
update
tchaton 6ddd07a
update
tchaton acfb717
update
tchaton 9dc1864
update
tchaton abce33a
update
tchaton 463540f
Merge branch 'master' into add_lightning_training_component
tchaton 6e00b78
update
tchaton 32d0206
Merge branch 'add_lightning_training_component' of https://github.com…
tchaton aa548c9
update
tchaton 7b8e831
update
tchaton b0a3c52
update
tchaton 04fe16d
update
tchaton 0e1b06e
update
tchaton a389de6
update
tchaton 3c9d1f8
update
tchaton fa28c53
update
tchaton f6060da
Merge branch 'master' into add_lightning_training_component_2
tchaton a087275
update
tchaton 6200539
Merge branch 'add_lightning_training_component_2' of https://github.c…
tchaton b8c1ff3
update
tchaton 7860ce4
update
tchaton 0701fc8
update
tchaton c0df7c3
update
tchaton e596856
update
tchaton 253aa43
update
tchaton 0373fc7
update
tchaton c5a8e44
Merge branch 'master' into add_lightning_training_component_2
tchaton 71f1cfe
update
tchaton b300bec
Merge branch 'add_lightning_training_component_2' of https://github.c…
tchaton e5a4a09
update
tchaton e6ff7ac
Merge branch 'master' into add_lightning_training_component_2
tchaton 7cc1c39
update
tchaton 29a1639
Merge branch 'add_lightning_training_component_2' of https://github.c…
tchaton eec5e6d
update
tchaton 17b8c96
update
tchaton 01bcc29
Update tests/tests_app/components/python/test_python.py
Felonious-Spellfire ba572cb
Update src/lightning_app/utilities/packaging/tarfile.py
Felonious-Spellfire 2f5e4b0
Update src/lightning_app/components/training.py
Felonious-Spellfire 1e000e5
Update src/lightning_app/components/training.py
Felonious-Spellfire 4ff559f
Update src/lightning_app/components/python/tracer.py
Felonious-Spellfire 4c9efe4
Update src/lightning_app/CHANGELOG.md
Felonious-Spellfire d49ed8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 3888ba1
update
tchaton a72397b
update
tchaton f679105
Merge branch 'add_lightning_training_component_2' of https://github.c…
tchaton 3fb367a
Merge branch 'master' into add_lightning_training_component_2
tchaton d3ee31b
update
tchaton 620c960
Merge branch 'add_lightning_training_component_2' of https://github.c…
tchaton fc4d807
Merge branch 'master' into add_lightning_training_component_2
tchaton
@@ -163,3 +163,4 @@ src/lightning_app/ui/*
*examples/template_react_ui*
hars*
artifacts/*
*docs/examples*
@@ -0,0 +1,11 @@
from lightning import LightningApp
from lightning.app.components.training import LightningTrainingComponent
from lightning.app.utilities.packaging.cloud_compute import CloudCompute

app = LightningApp(
    LightningTrainingComponent(
        "train.py",
        num_nodes=2,
        cloud_compute=CloudCompute("gpu-fast-multi"),
    ),
)
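For review purposes, the coordination that `LightningTrainingComponent.run` performs for a two-node app like this one can be approximated without Lightning installed. The sketch below is only an illustration: `FakeWork` and `flow_step` are hypothetical stand-ins, and the addresses and ports are made up. It mirrors the structure of the flow's loop in this PR: each work is first run without arguments until every node reports an address, then each work is run with the full address list, and all works are stopped once every node has finished.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class FakeWork:
    """Hypothetical stand-in for a LightningWork; not part of the Lightning API."""

    node_rank: int
    internal_ip: str = ""  # empty until the simulated machine is up
    port: int = 0
    has_finished: bool = False
    stopped: bool = False
    received_urls: Optional[List[Tuple[str, int]]] = None

    def run(self, internal_urls: Optional[List[Tuple[str, int]]] = None) -> None:
        if not internal_urls:
            # First call: only "boot" the node so that it gets an address.
            self.internal_ip = f"10.0.0.{self.node_rank + 1}"
            self.port = 5000 + self.node_rank
            return
        # Later call: every address is known, so the simulated training runs.
        self.received_urls = internal_urls
        self.has_finished = True

    def stop(self) -> None:
        self.stopped = True


def flow_step(works: List[FakeWork]) -> None:
    """One pass of the coordination loop, structured like the flow's run()."""
    for work in works:
        if all(w.internal_ip for w in works):
            internal_urls = [(w.internal_ip, w.port) for w in works]
            work.run(internal_urls=internal_urls)
            if all(w.has_finished for w in works):
                for w in works:
                    w.stop()
        else:
            work.run()


works = [FakeWork(node_rank=rank) for rank in range(2)]
flow_step(works)  # pass 1: both nodes boot and obtain an address
flow_step(works)  # pass 2: both nodes train with the full address list
print([w.stopped for w in works])
```

In the real app the flow's `run` is invoked repeatedly by the Lightning event loop, so the two explicit `flow_step` calls here only stand in for successive iterations.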
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -0,0 +1,7 @@
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":
    model = BoringModel()
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)
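Note that this script never sets `num_nodes`, `devices`, or `accelerator`. The runner in this PR exports `PL_TRAINER_NUM_NODES`, `PL_TRAINER_DEVICES`, and `PL_TRAINER_ACCELERATOR` instead, which suggests the `Trainer` is expected to pick them up as defaults. A minimal sketch of that naming convention, assuming two hypothetical helpers (`trainer_defaults_from_env` and `resolve_trainer_kwargs`, neither of which is a Lightning API) and a deliberately simplified value parser:

```python
from typing import Any, Dict

PREFIX = "PL_TRAINER_"  # prefix of the variables exported by the runner in this PR


def trainer_defaults_from_env(environ: Dict[str, str]) -> Dict[str, Any]:
    """Hypothetical helper: map PL_TRAINER_<ARG> variables to keyword arguments."""
    defaults: Dict[str, Any] = {}
    for key, value in environ.items():
        if not key.startswith(PREFIX):
            continue
        arg_name = key[len(PREFIX):].lower()
        # Simplified parsing: digit-only strings become ints, the rest stay strings.
        defaults[arg_name] = int(value) if value.isdigit() else value
    return defaults


def resolve_trainer_kwargs(environ: Dict[str, str], **explicit: Any) -> Dict[str, Any]:
    """Arguments passed explicitly by the script win over environment defaults."""
    return {**trainer_defaults_from_env(environ), **explicit}


env = {
    "PL_TRAINER_NUM_NODES": "2",
    "PL_TRAINER_DEVICES": "auto",
    "PL_TRAINER_ACCELERATOR": "auto",
    "MASTER_ADDR": "10.0.0.1",  # ignored here: no PL_TRAINER_ prefix
}
resolved = resolve_trainer_kwargs(env, max_epochs=1)
print(resolved)
```

The design point is that the same `train.py` can run unchanged on one machine or on several nodes, because the cluster-specific settings arrive through the environment rather than through the script.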
@@ -0,0 +1,192 @@
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from lightning import CloudCompute
from lightning_app import LightningFlow, structures
from lightning_app.components.python import TracerPythonScript
from lightning_app.storage.path import Path

_logger = logging.getLogger(__name__)


class PyTorchLightningScriptRunner(TracerPythonScript):
    def __init__(
        self,
        script_path: str,
        script_args: Optional[Union[list, str]] = None,
        node_rank: int = 1,
        num_nodes: int = 1,
        sanity_serving: bool = False,
        cloud_compute: Optional[CloudCompute] = None,
        parallel: bool = True,
        raise_exception: bool = True,
        env: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(
            script_path,
            script_args,
            raise_exception=raise_exception,
            parallel=parallel,
            cloud_compute=cloud_compute,
            **kwargs,
        )
        self.node_rank = node_rank
        self.num_nodes = num_nodes
        self.best_model_path = None
        self.best_model_score = None
        self.monitor = None
        self.sanity_serving = sanity_serving
        self.has_finished = False
        self.env = env

    def configure_tracer(self):
        from pytorch_lightning import Trainer

        tracer = super().configure_tracer()
        tracer.add_traced(Trainer, "__init__", pre_fn=self._trainer_init_pre_middleware)
        return tracer

    def run(self, internal_urls: Optional[List[Tuple[str, str]]] = None, **kwargs) -> None:
        if not internal_urls:
            # Note: This is called only once.
            _logger.info(f"The node {self.node_rank} started.")
            return None

        if self.env:
            os.environ.update(self.env)

        distributed_env_vars = {
            "MASTER_ADDR": internal_urls[0][0],  # the first node acts as the master
            "MASTER_PORT": str(internal_urls[0][1]),
            "NODE_RANK": str(self.node_rank),
            "PL_TRAINER_NUM_NODES": str(self.num_nodes),
            "PL_TRAINER_DEVICES": "auto",
            "PL_TRAINER_ACCELERATOR": "auto",
        }

        os.environ.update(distributed_env_vars)
        return super().run(**kwargs)

    def on_after_run(self, script_globals):
        from pytorch_lightning import Trainer
        from pytorch_lightning.cli import LightningCLI

        for v in script_globals.values():
            if isinstance(v, LightningCLI):
                trainer = v.trainer
                break
            elif isinstance(v, Trainer):
                trainer = v
                break
        else:
            raise RuntimeError("No trainer instance found.")

        self.monitor = trainer.checkpoint_callback.monitor

        if trainer.checkpoint_callback.best_model_score is not None:  # a score of 0.0 is still valid
            self.best_model_path = Path(trainer.checkpoint_callback.best_model_path)
            self.best_model_score = float(trainer.checkpoint_callback.best_model_score)
        else:
            self.best_model_path = Path(trainer.checkpoint_callback.last_model_path)

        self.has_finished = True

    def _trainer_init_pre_middleware(self, trainer, *args, **kwargs):
        if self.node_rank != 0:
            return {}, args, kwargs

        from pytorch_lightning.serve import ServableModuleValidator

        callbacks = kwargs.get("callbacks", [])
        if self.sanity_serving:
            callbacks = callbacks + [ServableModuleValidator()]
        kwargs["callbacks"] = callbacks
        return {}, args, kwargs

    @property
    def is_running_in_cloud(self) -> bool:
        return "LIGHTNING_APP_STATE_URL" in os.environ


class LightningTrainingComponent(LightningFlow):
    def __init__(
        self,
        script_path: str,
        script_args: Optional[Union[list, str]] = None,
        num_nodes: int = 1,
        cloud_compute: CloudCompute = CloudCompute("default"),
        sanity_serving: bool = False,
        script_runner: Type[TracerPythonScript] = PyTorchLightningScriptRunner,
        **script_runner_kwargs,
    ):
        """This component enables performing distributed multi-node multi-device training.

        Example::

            from lightning import LightningApp
            from lightning.app.components.training import LightningTrainingComponent
            from lightning.app.utilities.packaging.cloud_compute import CloudCompute

            app = LightningApp(
                LightningTrainingComponent(
                    "train.py",
                    num_nodes=2,
                    cloud_compute=CloudCompute("gpu"),
                ),
            )

        Arguments:
            script_path: Path to the script to be executed.
            script_args: The arguments to be passed to the script.
            num_nodes: Number of nodes.
            cloud_compute: The cloud compute object used in the cloud.
            sanity_serving: Whether to validate that the model correctly implements
                the ServableModule API.
        """
        super().__init__()
        self.ws = structures.List()
        self.has_initialized = False
        self.script_path = script_path
        self.script_args = script_args
        self.num_nodes = num_nodes
        self._cloud_compute = cloud_compute  # TODO: Add support for cloudCompute
        self.sanity_serving = sanity_serving
        self._script_runner = script_runner
        self._script_runner_kwargs = script_runner_kwargs

    def run(self, **run_kwargs):
        if not self.has_initialized:
            for node_rank in range(self.num_nodes):
                self.ws.append(
                    self._script_runner(
                        script_path=self.script_path,
                        script_args=self.script_args,
                        cloud_compute=self._cloud_compute,
                        node_rank=node_rank,
                        sanity_serving=self.sanity_serving,
                        num_nodes=self.num_nodes,
                        **self._script_runner_kwargs,
                    )
                )

            self.has_initialized = True

        for work in self.ws:
            if all(w.internal_ip for w in self.ws):
                internal_urls = [(w.internal_ip, w.port) for w in self.ws]
                work.run(internal_urls=internal_urls, **run_kwargs)
                if all(w.has_finished for w in self.ws):
                    for w in self.ws:
                        w.stop()
            else:
                work.run()

    @property
    def best_model_score(self) -> Optional[float]:
        return self.ws[0].best_model_score

    @property
    def best_model_paths(self) -> List[Optional[Path]]:
        return [self.ws[node_idx].best_model_path for node_idx in range(len(self.ws))]
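To make the environment handed to each node easier to inspect, here is a minimal sketch of the mapping that `PyTorchLightningScriptRunner.run` builds from `internal_urls`. The function name `build_distributed_env` is hypothetical and the sample addresses are made up; the keys and the choice of the first URL as the master are taken from the diff above.

```python
from typing import Dict, List, Tuple


def build_distributed_env(
    internal_urls: List[Tuple[str, int]], node_rank: int, num_nodes: int
) -> Dict[str, str]:
    """Hypothetical helper mirroring the env vars set by the runner in this PR."""
    if not internal_urls:
        raise ValueError("internal_urls must contain at least one (ip, port) pair")
    # Every node points at the first entry, so rank 0 hosts the rendezvous.
    master_addr, master_port = internal_urls[0]
    return {
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": str(master_port),
        "NODE_RANK": str(node_rank),
        "PL_TRAINER_NUM_NODES": str(num_nodes),
        "PL_TRAINER_DEVICES": "auto",
        "PL_TRAINER_ACCELERATOR": "auto",
    }


urls = [("10.0.0.1", 5000), ("10.0.0.2", 5001)]
env_rank0 = build_distributed_env(urls, node_rank=0, num_nodes=2)
env_rank1 = build_distributed_env(urls, node_rank=1, num_nodes=2)
print(env_rank1)
```

Only `NODE_RANK` differs between the two nodes; the master address and port are identical, which is what lets the processes find each other.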