Skip to content

Commit eb3a02d

Browse files
authored
Merge branch 'master' into fix/deepspeed_apx_lvl
2 parents 468c37d + aefb9ab commit eb3a02d

File tree

26 files changed

+604
-36
lines changed

26 files changed

+604
-36
lines changed

.actions/setup_tools.py

Lines changed: 98 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import tarfile
2121
import tempfile
2222
import urllib.request
23+
from datetime import datetime
2324
from importlib.util import module_from_spec, spec_from_file_location
2425
from itertools import groupby
2526
from types import ModuleType
@@ -150,6 +151,7 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
150151
... lines = [ln.rstrip() for ln in fp.readlines()]
151152
>>> lines = replace_vars_with_imports(lines, import_path)
152153
"""
154+
copied = []
153155
body, tracking, skip_offset = [], False, 0
154156
for ln in lines:
155157
offset = len(ln) - len(ln.lstrip())
@@ -160,8 +162,9 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
160162
if var:
161163
name = var.groups()[0]
162164
# skip private or apply white-list for allowed vars
163-
if not name.startswith("__") or name in ("__all__",):
165+
if name not in copied and (not name.startswith("__") or name in ("__all__",)):
164166
body.append(f"{' ' * offset}from {import_path} import {name} # noqa: F401")
167+
copied.append(name)
165168
tracking, skip_offset = True, offset
166169
continue
167170
if not tracking:
@@ -196,6 +199,31 @@ def prune_imports_callables(lines: List[str]) -> List[str]:
196199
return body
197200

198201

202+
def prune_func_calls(lines: List[str]) -> List[str]:
203+
"""Prune calling functions from a file, even multi-line.
204+
205+
>>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "__init__.py")
206+
>>> import_path = ".".join(["pytorch_lightning", "loggers"])
207+
>>> with open(py_file, encoding="utf-8") as fp:
208+
... lines = [ln.rstrip() for ln in fp.readlines()]
209+
>>> lines = prune_func_calls(lines)
210+
"""
211+
body, tracking, score = [], False, 0
212+
for ln in lines:
213+
# catching callable
214+
calling = re.match(r"^@?[\w_\d\.]+ *\(", ln.lstrip())
215+
if calling and " import " not in ln:
216+
tracking = True
217+
score = 0
218+
if tracking:
219+
score += ln.count("(") - ln.count(")")
220+
if score == 0:
221+
tracking = False
222+
else:
223+
body.append(ln)
224+
return body
225+
226+
199227
def prune_empty_statements(lines: List[str]) -> List[str]:
200228
"""Prune emprty if/else and try/except.
201229
@@ -270,6 +298,46 @@ def prune_comments_docstrings(lines: List[str]) -> List[str]:
270298
return body
271299

272300

301+
def wrap_try_except(body: List[str], pkg: str, ver: str) -> List[str]:
302+
"""Wrap the file with try/except for better traceability of import misalignment."""
303+
not_empty = sum(1 for ln in body if ln)
304+
if not_empty == 0:
305+
return body
306+
body = ["try:"] + [f" {ln}" if ln else "" for ln in body]
307+
body += [
308+
"",
309+
"except ImportError as err:",
310+
"",
311+
" from os import linesep",
312+
f" from {pkg} import __version__",
313+
f" msg = f'Your `lightning` package was built for `{pkg}=={ver}`," + " but you are running {__version__}'",
314+
" raise type(err)(str(err) + linesep + msg)",
315+
]
316+
return body
317+
318+
319+
def parse_version_from_file(pkg_root: str) -> str:
320+
"""Loading the package version from file."""
321+
file_ver = os.path.join(pkg_root, "__version__.py")
322+
file_about = os.path.join(pkg_root, "__about__.py")
323+
if os.path.isfile(file_ver):
324+
ver = _load_py_module("version", file_ver).version
325+
elif os.path.isfile(file_about):
326+
ver = _load_py_module("about", file_about).__version__
327+
else: # this covers case you have build only meta-package so not additional source files are present
328+
ver = ""
329+
return ver
330+
331+
332+
def prune_duplicate_lines(body):
333+
body_ = []
334+
# drop duplicated lines
335+
for ln in body:
336+
if ln.lstrip() not in body_ or ln.lstrip() in (")", ""):
337+
body_.append(ln)
338+
return body_
339+
340+
273341
def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", lit_name: str = "pytorch"):
274342
"""Parse the real python package and for each module create a mirroe version with repalcing all function and
275343
class implementations by cross-imports to the true package.
@@ -279,6 +347,7 @@ class implementations by cross-imports to the true package.
279347
>>> create_meta_package(os.path.join(_PROJECT_ROOT, "src"))
280348
"""
281349
package_dir = os.path.join(src_folder, pkg_name)
350+
pkg_ver = parse_version_from_file(package_dir)
282351
# shutil.rmtree(os.path.join(src_folder, "lightning", lit_name))
283352
py_files = glob.glob(os.path.join(src_folder, pkg_name, "**", "*.py"), recursive=True)
284353
for py_file in py_files:
@@ -298,41 +367,57 @@ class implementations by cross-imports to the true package.
298367
logging.warning(f"unsupported file: {local_path}")
299368
continue
300369
# ToDO: perform some smarter parsing - preserve Constants, lambdas, etc
301-
body = prune_comments_docstrings(lines)
370+
body = prune_comments_docstrings([ln.rstrip() for ln in lines])
302371
if fname not in ("__init__.py", "__main__.py"):
303372
body = prune_imports_callables(body)
304-
body = replace_block_with_imports([ln.rstrip() for ln in body], import_path, "class")
305-
body = replace_block_with_imports(body, import_path, "def")
306-
body = replace_block_with_imports(body, import_path, "async def")
373+
for key_word in ("class", "def", "async def"):
374+
body = replace_block_with_imports(body, import_path, key_word)
375+
# TODO: fix reimporting which is artefact after replacing var assignment with import;
376+
# after fixing , update CI by remove F811 from CI/check pkg
307377
body = replace_vars_with_imports(body, import_path)
378+
if fname not in ("__main__.py",):
379+
body = prune_func_calls(body)
308380
body_len = -1
309381
# in case of several in-depth statements
310382
while body_len != len(body):
311383
body_len = len(body)
384+
body = prune_duplicate_lines(body)
312385
body = prune_empty_statements(body)
313-
# TODO: add try/catch wrapper for whole body,
386+
# add try/catch wrapper for whole body,
314387
# so when import fails it tells you what is the package version this meta package was generated for...
388+
body = wrap_try_except(body, pkg_name, pkg_ver)
315389

316390
# todo: apply pre-commit formatting
391+
# clean to many empty lines
317392
body = [ln for ln, _group in groupby(body)]
318-
lines = []
319393
# drop duplicated lines
320-
for ln in body:
321-
if ln + os.linesep not in lines or ln in (")", ""):
322-
lines.append(ln + os.linesep)
394+
body = prune_duplicate_lines(body)
323395
# compose the target file name
324396
new_file = os.path.join(src_folder, "lightning", lit_name, local_path)
325397
os.makedirs(os.path.dirname(new_file), exist_ok=True)
326398
with open(new_file, "w", encoding="utf-8") as fp:
327-
fp.writelines(lines)
399+
fp.writelines([ln + os.linesep for ln in body])
400+
401+
402+
def set_version_today(fpath: str) -> None:
403+
"""Replace the template date with today."""
404+
with open(fpath) as fp:
405+
lines = fp.readlines()
406+
407+
def _replace_today(ln):
408+
today = datetime.now()
409+
return ln.replace("YYYY.-M.-D", f"{today.year}.{today.month}.{today.day}")
410+
411+
lines = list(map(_replace_today, lines))
412+
with open(fpath, "w") as fp:
413+
fp.writelines(lines)
328414

329415

330416
def _download_frontend(root: str = _PROJECT_ROOT):
331417
"""Downloads an archive file for a specific release of the Lightning frontend and extracts it to the correct
332418
directory."""
333419

334420
try:
335-
build_dir = "build"
336421
frontend_dir = pathlib.Path(root, "src", "lightning_app", "ui")
337422
download_dir = tempfile.mkdtemp()
338423

@@ -342,7 +427,7 @@ def _download_frontend(root: str = _PROJECT_ROOT):
342427
file = tarfile.open(fileobj=response, mode="r|gz")
343428
file.extractall(path=download_dir)
344429

345-
shutil.move(os.path.join(download_dir, build_dir), frontend_dir)
430+
shutil.move(os.path.join(download_dir, "build"), frontend_dir)
346431
print("The Lightning UI has successfully been downloaded!")
347432

348433
# If installing from source without internet connection, we don't want to break the installation

.github/actions/pkg-check/action.yml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,19 @@ runs:
1414
run: pip install "twine==4.0.1" setuptools wheel flake8
1515
shell: bash
1616

17-
- name: Create package
17+
- name: Source check
1818
env:
1919
PACKAGE_NAME: ${{ inputs.pkg-name }}
2020
run: |
2121
python setup.py check --metadata --strict
22-
flake8 src/lightning/ --ignore E402,F401,E501,W391,E303
23-
python setup.py sdist bdist_wheel
22+
# TODO: fix reimporting (F811) which is aftefact after rplacing var assigne with import in meta package
23+
flake8 src/lightning/ --ignore E402,F401,E501,W391,E303,F811
24+
shell: bash
25+
26+
- name: Create package
27+
env:
28+
PACKAGE_NAME: ${{ inputs.pkg-name }}
29+
run: python setup.py sdist bdist_wheel
2430
shell: bash
2531

2632
- name: Check package

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,4 @@ src/lightning_app/ui/*
163163
*examples/template_react_ui*
164164
hars*
165165
artifacts/*
166+
*docs/examples*

docs/source-app/api_reference/components.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@ ___________________
2020

2121
~python.popen.PopenPythonScript
2222
~python.tracer.TracerPythonScript
23+
~training.LightningTrainingComponent
2324
~serve.gradio.ServeGradio
2425
~serve.serve.ModelInferenceAPI

examples/app_multi_node/app.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from lightning import LightningApp
2+
from lightning.app.components.training import LightningTrainingComponent
3+
from lightning.app.utilities.packaging.cloud_compute import CloudCompute
4+
5+
app = LightningApp(
6+
LightningTrainingComponent(
7+
"train.py",
8+
num_nodes=2,
9+
cloud_compute=CloudCompute("gpu-fast-multi"),
10+
),
11+
)

examples/app_multi_node/train.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from lightning.pytorch import Trainer
2+
from lightning.pytorch.demos.boring_classes import BoringModel
3+
4+
if __name__ == "__main__":
5+
model = BoringModel()
6+
trainer = Trainer(max_epochs=1)
7+
trainer.fit(model)

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@
5959
# https://packaging.python.org/guides/single-sourcing-package-version/
6060
# http://blog.ionelmc.ro/2014/05/25/python-packaging/
6161
_PATH_ROOT = os.path.dirname(__file__)
62-
_PATH_SETUP = os.path.join(_PATH_ROOT, "src", _REAL_PKG_NAME or "lightning", "__setup__.py")
62+
_PATH_SRC = os.path.join(_PATH_ROOT, "src")
63+
_PATH_SETUP = os.path.join(_PATH_SRC, _REAL_PKG_NAME or "lightning", "__setup__.py")
6364

6465

6566
# Hardcode the env variable from time of package creation, otherwise it fails during installation
@@ -88,6 +89,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
8889
# engineer specific practices
8990
if __name__ == "__main__":
9091
_SETUP_TOOLS = _load_py_module(name="setup_tools", location=os.path.join(".actions", "setup_tools.py"))
92+
_SETUP_TOOLS.set_version_today(os.path.join(_PATH_SRC, "lightning", "__version__.py"))
9193
for lit_name, pkg_name in _PACKAGE_MAPPING.items():
9294
# fixme: if we run creation of meta pkg against stable we shall pull the source
9395
_SETUP_TOOLS.create_meta_package(os.path.join(_PATH_ROOT, "src"), pkg_name, lit_name)

0 commit comments

Comments
 (0)