20
20
import tarfile
21
21
import tempfile
22
22
import urllib .request
23
+ from datetime import datetime
23
24
from importlib .util import module_from_spec , spec_from_file_location
24
25
from itertools import groupby
25
26
from types import ModuleType
@@ -150,6 +151,7 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
150
151
... lines = [ln.rstrip() for ln in fp.readlines()]
151
152
>>> lines = replace_vars_with_imports(lines, import_path)
152
153
"""
154
+ copied = []
153
155
body , tracking , skip_offset = [], False , 0
154
156
for ln in lines :
155
157
offset = len (ln ) - len (ln .lstrip ())
@@ -160,8 +162,9 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
160
162
if var :
161
163
name = var .groups ()[0 ]
162
164
# skip private or apply white-list for allowed vars
163
- if not name .startswith ("__" ) or name in ("__all__" ,):
165
+ if name not in copied and ( not name .startswith ("__" ) or name in ("__all__" ,) ):
164
166
body .append (f"{ ' ' * offset } from { import_path } import { name } # noqa: F401" )
167
+ copied .append (name )
165
168
tracking , skip_offset = True , offset
166
169
continue
167
170
if not tracking :
@@ -196,6 +199,31 @@ def prune_imports_callables(lines: List[str]) -> List[str]:
196
199
return body
197
200
198
201
202
+ def prune_func_calls (lines : List [str ]) -> List [str ]:
203
+ """Prune calling functions from a file, even multi-line.
204
+
205
+ >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "__init__.py")
206
+ >>> import_path = ".".join(["pytorch_lightning", "loggers"])
207
+ >>> with open(py_file, encoding="utf-8") as fp:
208
+ ... lines = [ln.rstrip() for ln in fp.readlines()]
209
+ >>> lines = prune_func_calls(lines)
210
+ """
211
+ body , tracking , score = [], False , 0
212
+ for ln in lines :
213
+ # catching callable
214
+ calling = re .match (r"^@?[\w_\d\.]+ *\(" , ln .lstrip ())
215
+ if calling and " import " not in ln :
216
+ tracking = True
217
+ score = 0
218
+ if tracking :
219
+ score += ln .count ("(" ) - ln .count (")" )
220
+ if score == 0 :
221
+ tracking = False
222
+ else :
223
+ body .append (ln )
224
+ return body
225
+
226
+
199
227
def prune_empty_statements (lines : List [str ]) -> List [str ]:
200
228
"""Prune emprty if/else and try/except.
201
229
@@ -270,6 +298,46 @@ def prune_comments_docstrings(lines: List[str]) -> List[str]:
270
298
return body
271
299
272
300
301
+ def wrap_try_except (body : List [str ], pkg : str , ver : str ) -> List [str ]:
302
+ """Wrap the file with try/except for better traceability of import misalignment."""
303
+ not_empty = sum (1 for ln in body if ln )
304
+ if not_empty == 0 :
305
+ return body
306
+ body = ["try:" ] + [f" { ln } " if ln else "" for ln in body ]
307
+ body += [
308
+ "" ,
309
+ "except ImportError as err:" ,
310
+ "" ,
311
+ " from os import linesep" ,
312
+ f" from { pkg } import __version__" ,
313
+ f" msg = f'Your `lightning` package was built for `{ pkg } =={ ver } `," + " but you are running {__version__}'" ,
314
+ " raise type(err)(str(err) + linesep + msg)" ,
315
+ ]
316
+ return body
317
+
318
+
319
+ def parse_version_from_file (pkg_root : str ) -> str :
320
+ """Loading the package version from file."""
321
+ file_ver = os .path .join (pkg_root , "__version__.py" )
322
+ file_about = os .path .join (pkg_root , "__about__.py" )
323
+ if os .path .isfile (file_ver ):
324
+ ver = _load_py_module ("version" , file_ver ).version
325
+ elif os .path .isfile (file_about ):
326
+ ver = _load_py_module ("about" , file_about ).__version__
327
+ else : # this covers case you have build only meta-package so not additional source files are present
328
+ ver = ""
329
+ return ver
330
+
331
+
332
+ def prune_duplicate_lines (body ):
333
+ body_ = []
334
+ # drop duplicated lines
335
+ for ln in body :
336
+ if ln .lstrip () not in body_ or ln .lstrip () in (")" , "" ):
337
+ body_ .append (ln )
338
+ return body_
339
+
340
+
273
341
def create_meta_package (src_folder : str , pkg_name : str = "pytorch_lightning" , lit_name : str = "pytorch" ):
274
342
"""Parse the real python package and for each module create a mirroe version with repalcing all function and
275
343
class implementations by cross-imports to the true package.
@@ -279,6 +347,7 @@ class implementations by cross-imports to the true package.
279
347
>>> create_meta_package(os.path.join(_PROJECT_ROOT, "src"))
280
348
"""
281
349
package_dir = os .path .join (src_folder , pkg_name )
350
+ pkg_ver = parse_version_from_file (package_dir )
282
351
# shutil.rmtree(os.path.join(src_folder, "lightning", lit_name))
283
352
py_files = glob .glob (os .path .join (src_folder , pkg_name , "**" , "*.py" ), recursive = True )
284
353
for py_file in py_files :
@@ -298,41 +367,57 @@ class implementations by cross-imports to the true package.
298
367
logging .warning (f"unsupported file: { local_path } " )
299
368
continue
300
369
# ToDO: perform some smarter parsing - preserve Constants, lambdas, etc
301
- body = prune_comments_docstrings (lines )
370
+ body = prune_comments_docstrings ([ ln . rstrip () for ln in lines ] )
302
371
if fname not in ("__init__.py" , "__main__.py" ):
303
372
body = prune_imports_callables (body )
304
- body = replace_block_with_imports ([ln .rstrip () for ln in body ], import_path , "class" )
305
- body = replace_block_with_imports (body , import_path , "def" )
306
- body = replace_block_with_imports (body , import_path , "async def" )
373
+ for key_word in ("class" , "def" , "async def" ):
374
+ body = replace_block_with_imports (body , import_path , key_word )
375
+ # TODO: fix reimporting which is artefact after replacing var assignment with import;
376
+ # after fixing , update CI by remove F811 from CI/check pkg
307
377
body = replace_vars_with_imports (body , import_path )
378
+ if fname not in ("__main__.py" ,):
379
+ body = prune_func_calls (body )
308
380
body_len = - 1
309
381
# in case of several in-depth statements
310
382
while body_len != len (body ):
311
383
body_len = len (body )
384
+ body = prune_duplicate_lines (body )
312
385
body = prune_empty_statements (body )
313
- # TODO: add try/catch wrapper for whole body,
386
+ # add try/catch wrapper for whole body,
314
387
# so when import fails it tells you what is the package version this meta package was generated for...
388
+ body = wrap_try_except (body , pkg_name , pkg_ver )
315
389
316
390
# todo: apply pre-commit formatting
391
+ # clean to many empty lines
317
392
body = [ln for ln , _group in groupby (body )]
318
- lines = []
319
393
# drop duplicated lines
320
- for ln in body :
321
- if ln + os .linesep not in lines or ln in (")" , "" ):
322
- lines .append (ln + os .linesep )
394
+ body = prune_duplicate_lines (body )
323
395
# compose the target file name
324
396
new_file = os .path .join (src_folder , "lightning" , lit_name , local_path )
325
397
os .makedirs (os .path .dirname (new_file ), exist_ok = True )
326
398
with open (new_file , "w" , encoding = "utf-8" ) as fp :
327
- fp .writelines (lines )
399
+ fp .writelines ([ln + os .linesep for ln in body ])
400
+
401
+
402
+ def set_version_today (fpath : str ) -> None :
403
+ """Replace the template date with today."""
404
+ with open (fpath ) as fp :
405
+ lines = fp .readlines ()
406
+
407
+ def _replace_today (ln ):
408
+ today = datetime .now ()
409
+ return ln .replace ("YYYY.-M.-D" , f"{ today .year } .{ today .month } .{ today .day } " )
410
+
411
+ lines = list (map (_replace_today , lines ))
412
+ with open (fpath , "w" ) as fp :
413
+ fp .writelines (lines )
328
414
329
415
330
416
def _download_frontend (root : str = _PROJECT_ROOT ):
331
417
"""Downloads an archive file for a specific release of the Lightning frontend and extracts it to the correct
332
418
directory."""
333
419
334
420
try :
335
- build_dir = "build"
336
421
frontend_dir = pathlib .Path (root , "src" , "lightning_app" , "ui" )
337
422
download_dir = tempfile .mkdtemp ()
338
423
@@ -342,7 +427,7 @@ def _download_frontend(root: str = _PROJECT_ROOT):
342
427
file = tarfile .open (fileobj = response , mode = "r|gz" )
343
428
file .extractall (path = download_dir )
344
429
345
- shutil .move (os .path .join (download_dir , build_dir ), frontend_dir )
430
+ shutil .move (os .path .join (download_dir , "build" ), frontend_dir )
346
431
print ("The Lightning UI has successfully been downloaded!" )
347
432
348
433
# If installing from source without internet connection, we don't want to break the installation
0 commit comments