From 2053e954df87e46301ed6d197976bb48073e9f04 Mon Sep 17 00:00:00 2001
From: Vitaly Fedyunin
Date: Wed, 29 Jun 2022 14:04:43 -0400
Subject: [PATCH 1/4] [DataPipe] Automatically close parent streams and discarded streams

[ghstack-poisoned]
---
 torchdata/__init__.py                             |  7 ++++++-
 torchdata/datapipes/iter/util/bz2fileloader.py    |  5 ++++-
 torchdata/datapipes/iter/util/combining.py        | 10 ++++++++++
 torchdata/datapipes/iter/util/decompressor.py     |  4 +++-
 torchdata/datapipes/iter/util/rararchiveloader.py |  4 +++-
 torchdata/datapipes/iter/util/tararchiveloader.py |  5 ++++-
 torchdata/datapipes/iter/util/xzfileloader.py     |  5 ++++-
 torchdata/datapipes/iter/util/ziparchiveloader.py |  5 ++++-
 torchdata/datapipes/utils/__init__.py             |  3 ++-
 torchdata/datapipes/utils/janitor.py              | 10 ++++++++++
 10 files changed, 50 insertions(+), 8 deletions(-)
 create mode 100644 torchdata/datapipes/utils/janitor.py

diff --git a/torchdata/__init__.py b/torchdata/__init__.py
index 2f0f3c381..d304f9a26 100644
--- a/torchdata/__init__.py
+++ b/torchdata/__init__.py
@@ -8,9 +8,14 @@
 
 from . import datapipes
 
+janitor = datapipes.utils.janitor
+
 try:
     from .version import __version__  # noqa: F401
 except ImportError:
     pass
 
-__all__ = ["datapipes"]
+__all__ = [
+    "datapipes",
+    "janitor",
+]
diff --git a/torchdata/datapipes/iter/util/bz2fileloader.py b/torchdata/datapipes/iter/util/bz2fileloader.py
index 93f4739f7..442df0392 100644
--- a/torchdata/datapipes/iter/util/bz2fileloader.py
+++ b/torchdata/datapipes/iter/util/bz2fileloader.py
@@ -54,10 +54,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
             try:
                 extracted_fobj = bz2.open(data_stream, mode="rb")  # type: ignore[call-overload]
                 new_pathname = pathname.rstrip(".bz2")
-                yield new_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                yield new_pathname, StreamWrapper(extracted_fobj, data_stream, name=new_pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted bzip2 stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:
diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py
index a7148d26f..ea07248a7 100644
--- a/torchdata/datapipes/iter/util/combining.py
+++ b/torchdata/datapipes/iter/util/combining.py
@@ -11,6 +11,8 @@
 from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
 from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
 
+from torchdata.datapipes.utils.janitor import janitor
+
 T_co = TypeVar("T_co", covariant=True)
 
@@ -109,6 +111,14 @@ def __iter__(self) -> Iterator:
             else:
                 yield res
 
+        for remaining in ref_it:
+            janitor(remaining)
+
+        # TODO(VItalyFedyunin): This should be Exception or warn when debug mode is enabled
+        if len(self.buffer) > 0:
+            for k, v in self.buffer.items():
+                janitor(v)
+
     def __len__(self) -> int:
         return len(self.source_datapipe)
diff --git a/torchdata/datapipes/iter/util/decompressor.py b/torchdata/datapipes/iter/util/decompressor.py
index 3cc949bde..30e153b42 100644
--- a/torchdata/datapipes/iter/util/decompressor.py
+++ b/torchdata/datapipes/iter/util/decompressor.py
@@ -97,7 +97,9 @@ def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
         for path, file in self.source_datapipe:
             file_type = self._detect_compression_type(path)
             decompressor = self._DECOMPRESSORS[file_type]
-            yield path, StreamWrapper(decompressor(file))
+            yield path, StreamWrapper(decompressor(file), file, name=path)
+            if isinstance(file, StreamWrapper):
+                file.autoclose()
 
 
 @functional_datapipe("extract")
diff --git a/torchdata/datapipes/iter/util/rararchiveloader.py b/torchdata/datapipes/iter/util/rararchiveloader.py
index 5b5546316..d7c805fc4 100644
--- a/torchdata/datapipes/iter/util/rararchiveloader.py
+++ b/torchdata/datapipes/iter/util/rararchiveloader.py
@@ -107,7 +107,9 @@ def __iter__(self) -> Iterator[Tuple[str, io.BufferedIOBase]]:
                 inner_path = os.path.join(path, info.filename)
                 file_obj = rar.open(info)
 
-                yield inner_path, StreamWrapper(file_obj)  # type: ignore[misc]
+                yield inner_path, StreamWrapper(file_obj, stream)  # type: ignore[misc]
+                if isinstance(stream, StreamWrapper):
+                    stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:
diff --git a/torchdata/datapipes/iter/util/tararchiveloader.py b/torchdata/datapipes/iter/util/tararchiveloader.py
index ccb9e0c0f..4c8adef19 100644
--- a/torchdata/datapipes/iter/util/tararchiveloader.py
+++ b/torchdata/datapipes/iter/util/tararchiveloader.py
@@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
                         warnings.warn(f"failed to extract file {tarinfo.name} from source tarfile {pathname}")
                         raise tarfile.ExtractError
                     inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
-                    yield inner_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                    yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted tarfile stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:
diff --git a/torchdata/datapipes/iter/util/xzfileloader.py b/torchdata/datapipes/iter/util/xzfileloader.py
index ac5c46beb..637560cce 100644
--- a/torchdata/datapipes/iter/util/xzfileloader.py
+++ b/torchdata/datapipes/iter/util/xzfileloader.py
@@ -56,10 +56,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
             try:
                 extracted_fobj = lzma.open(data_stream, mode="rb")  # type: ignore[call-overload]
                 new_pathname = pathname.rstrip(".xz")
-                yield new_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                yield new_pathname, StreamWrapper(extracted_fobj, data_stream)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted xz/lzma stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:
diff --git a/torchdata/datapipes/iter/util/ziparchiveloader.py b/torchdata/datapipes/iter/util/ziparchiveloader.py
index ae0c8f836..d61e061d9 100644
--- a/torchdata/datapipes/iter/util/ziparchiveloader.py
+++ b/torchdata/datapipes/iter/util/ziparchiveloader.py
@@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
                         continue
                     extracted_fobj = zips.open(zipinfo)
                     inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
-                    yield inner_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                    yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
             # We are unable to close 'data_stream' here, because it needs to be available to use later
 
     def __len__(self) -> int:
diff --git a/torchdata/datapipes/utils/__init__.py b/torchdata/datapipes/utils/__init__.py
index 889fd4d08..c74e8f702 100644
--- a/torchdata/datapipes/utils/__init__.py
+++ b/torchdata/datapipes/utils/__init__.py
@@ -7,5 +7,6 @@
 from torch.utils.data.datapipes.utils.common import StreamWrapper
 
 from ._visualization import to_graph
+from .janitor import janitor
 
-__all__ = ["StreamWrapper", "to_graph"]
+__all__ = ["StreamWrapper", "janitor", "to_graph"]
diff --git a/torchdata/datapipes/utils/janitor.py b/torchdata/datapipes/utils/janitor.py
new file mode 100644
index 000000000..bc649123f
--- /dev/null
+++ b/torchdata/datapipes/utils/janitor.py
@@ -0,0 +1,10 @@
+from torchdata.datapipes.utils import StreamWrapper
+
+
+def janitor(obj):
+    """
+    Invokes various `obj` cleanup procedures such as:
+        - Closing streams
+    """
+    # TODO(VitalyFedyunin): We can also release caching locks here to allow filtering
+    StreamWrapper.close_streams(obj)

From 5f743a38c7811eff5a91c41487052425ba07e9cf Mon Sep 17 00:00:00 2001
From: Vitaly Fedyunin
Date: Thu, 30 Jun 2022 12:32:04 -0400
Subject: [PATCH 2/4] Update on "[DataPipe] Automatically close parent streams and discarded streams"

[ghstack-poisoned]
---
 torchdata/__init__.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torchdata/__init__.py b/torchdata/__init__.py
index d304f9a26..43aa7b7b0 100644
--- a/torchdata/__init__.py
+++ b/torchdata/__init__.py
@@ -19,3 +19,6 @@
     "datapipes",
     "janitor",
 ]
+
+# Please keep this list sorted
+assert __all__ == sorted(__all__)

From 7cf380594f8c668fd6a16d4a5dc60c0badfeec8f Mon Sep 17 00:00:00 2001
From: Vitaly Fedyunin
Date: Thu, 30 Jun 2022 15:30:41 -0400
Subject: [PATCH 3/4] Update on "[DataPipe] Automatically close parent streams and discarded streams"

[ghstack-poisoned]
---
 torchdata/datapipes/iter/util/rararchiveloader.py | 2 +-
 torchdata/datapipes/iter/util/xzfileloader.py     | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchdata/datapipes/iter/util/rararchiveloader.py b/torchdata/datapipes/iter/util/rararchiveloader.py
index d7c805fc4..c657b1d90 100644
--- a/torchdata/datapipes/iter/util/rararchiveloader.py
+++ b/torchdata/datapipes/iter/util/rararchiveloader.py
@@ -107,7 +107,7 @@ def __iter__(self) -> Iterator[Tuple[str, io.BufferedIOBase]]:
                 inner_path = os.path.join(path, info.filename)
                 file_obj = rar.open(info)
 
-                yield inner_path, StreamWrapper(file_obj, stream)  # type: ignore[misc]
+                yield inner_path, StreamWrapper(file_obj, stream, name=path)  # type: ignore[misc]
                 if isinstance(stream, StreamWrapper):
                     stream.autoclose()
 
diff --git a/torchdata/datapipes/iter/util/xzfileloader.py b/torchdata/datapipes/iter/util/xzfileloader.py
index 637560cce..450ee05b8 100644
--- a/torchdata/datapipes/iter/util/xzfileloader.py
+++ b/torchdata/datapipes/iter/util/xzfileloader.py
@@ -56,7 +56,7 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
             try:
                 extracted_fobj = lzma.open(data_stream, mode="rb")  # type: ignore[call-overload]
                 new_pathname = pathname.rstrip(".xz")
-                yield new_pathname, StreamWrapper(extracted_fobj, data_stream)  # type: ignore[misc]
+                yield new_pathname, StreamWrapper(extracted_fobj, data_stream, name=pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted xz/lzma stream {pathname} due to: {e}, abort!")
                 raise e

From d0e51a59b283d49c53baa20612d60850d35a751b Mon Sep 17 00:00:00 2001
From: Vitaly Fedyunin
Date: Tue, 5 Jul 2022 16:19:55 -0400
Subject: [PATCH 4/4] Update on "[DataPipe] Automatically close parent streams and discarded streams"

[ghstack-poisoned]
---
 torchdata/datapipes/iter/util/combining.py | 69 +++++++++++-----------
 1 file changed, 35 insertions(+), 34 deletions(-)

diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py
index ea07248a7..e9419b8c1 100644
--- a/torchdata/datapipes/iter/util/combining.py
+++ b/torchdata/datapipes/iter/util/combining.py
@@ -83,41 +83,42 @@ def __init__(
     def __iter__(self) -> Iterator:
         ref_it = iter(self.ref_datapipe)
         warn_once_flag = True
-        for data in self.source_datapipe:
-            key = self.key_fn(data)
-            while key not in self.buffer:
-                try:
-                    ref_data = next(ref_it)
-                except StopIteration:
-                    raise BufferError(
-                        f"No matching key can be found from reference DataPipe for the data {data}. "
-                        "Please consider increasing the buffer size."
-                    )
-                ref_key = self.ref_key_fn(ref_data)
-                if ref_key in self.buffer:
-                    raise ValueError("Duplicate key is found in reference DataPipe")
-                if self.buffer_size is not None and len(self.buffer) > self.buffer_size:
-                    if warn_once_flag:
-                        warn_once_flag = False
-                        warnings.warn(
-                            "Buffer reaches the upper limit, so reference key-data pair begins to "
-                            "be removed from buffer in FIFO order. Please consider increase buffer size."
-                        )
-                    self.buffer.popitem(last=False)
-                self.buffer[ref_key] = ref_data
-            res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key))
-            if self.keep_key:
-                yield key, res
-            else:
-                yield res
-
-        for remaining in ref_it:
-            janitor(remaining)
+        try:
+            for data in self.source_datapipe:
+                key = self.key_fn(data)
+                while key not in self.buffer:
+                    try:
+                        ref_data = next(ref_it)
+                    except StopIteration:
+                        raise BufferError(
+                            f"No matching key can be found from reference DataPipe for the data {data}. "
+                            "Please consider increasing the buffer size."
+                        )
+                    ref_key = self.ref_key_fn(ref_data)
+                    if ref_key in self.buffer:
+                        raise ValueError("Duplicate key is found in reference DataPipe")
+                    if self.buffer_size is not None and len(self.buffer) > self.buffer_size:
+                        if warn_once_flag:
+                            warn_once_flag = False
+                            warnings.warn(
+                                "Buffer reaches the upper limit, so reference key-data pair begins to "
+                                "be removed from buffer in FIFO order. Please consider increase buffer size."
+                            )
+                        self.buffer.popitem(last=False)
+                    self.buffer[ref_key] = ref_data
+                res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key))
+                if self.keep_key:
+                    yield key, res
+                else:
+                    yield res
+        finally:
+            for remaining in ref_it:
+                janitor(remaining)
 
-        # TODO(VItalyFedyunin): This should be Exception or warn when debug mode is enabled
-        if len(self.buffer) > 0:
-            for k, v in self.buffer.items():
-                janitor(v)
+            # TODO(VItalyFedyunin): This should be Exception or warn when debug mode is enabled
+            if len(self.buffer) > 0:
+                for k, v in self.buffer.items():
+                    janitor(v)
 
     def __len__(self) -> int:
         return len(self.source_datapipe)
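
Taken together, the four patches make every archive/decompression loader hand the parent stream to the child StreamWrapper (so the parent can be released automatically via the autoclose() calls in the new finally blocks) and route unconsumed buffered items through the new janitor helper, which simply calls StreamWrapper.close_streams(). Below is a minimal usage sketch of the resulting behaviour; it is not part of the patches, it assumes a torchdata build that already includes them, and "archive.tar" is a placeholder path.

    # Hypothetical pipeline: "archive.tar" stands in for any tar file on disk.
    from torchdata.datapipes.iter import FileOpener, IterableWrapper
    from torchdata.datapipes.utils import janitor

    dp = IterableWrapper(["archive.tar"])
    dp = FileOpener(dp, mode="b")   # yields (path, parent archive stream)
    dp = dp.load_from_tar()         # yields (member path, StreamWrapper tied to its parent stream)

    for path, stream in dp:
        if path.endswith(".txt"):
            print(path, stream.read()[:80])
            stream.close()
        else:
            # Streams we do not consume are handed to the janitor, which closes
            # them (and anything they wrap) via StreamWrapper.close_streams().
            janitor(stream)

The exact bookkeeping behind autoclose() lives in StreamWrapper in PyTorch core rather than in this diff; the intent, per the commit title, is that a parent archive stream closes itself once the loader has finished iterating and the last child stream derived from it has been closed or discarded.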