diff --git a/torchdata/__init__.py b/torchdata/__init__.py
index 2f0f3c381..43aa7b7b0 100644
--- a/torchdata/__init__.py
+++ b/torchdata/__init__.py
@@ -8,9 +8,17 @@
 
 from . import datapipes
 
+janitor = datapipes.utils.janitor
+
 try:
     from .version import __version__  # noqa: F401
 except ImportError:
     pass
 
-__all__ = ["datapipes"]
+__all__ = [
+    "datapipes",
+    "janitor",
+]
+
+# Please keep this list sorted
+assert __all__ == sorted(__all__)
diff --git a/torchdata/datapipes/iter/util/bz2fileloader.py b/torchdata/datapipes/iter/util/bz2fileloader.py
index 93f4739f7..442df0392 100644
--- a/torchdata/datapipes/iter/util/bz2fileloader.py
+++ b/torchdata/datapipes/iter/util/bz2fileloader.py
@@ -54,10 +54,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
             try:
                 extracted_fobj = bz2.open(data_stream, mode="rb")  # type: ignore[call-overload]
                 new_pathname = pathname.rstrip(".bz2")
-                yield new_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                yield new_pathname, StreamWrapper(extracted_fobj, data_stream, name=new_pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted bzip2 stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:
diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py
index a7148d26f..e9419b8c1 100644
--- a/torchdata/datapipes/iter/util/combining.py
+++ b/torchdata/datapipes/iter/util/combining.py
@@ -11,6 +11,8 @@
 from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
 from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
 
+from torchdata.datapipes.utils.janitor import janitor
+
 T_co = TypeVar("T_co", covariant=True)
 
 
@@ -81,33 +83,42 @@ def __init__(
     def __iter__(self) -> Iterator:
         ref_it = iter(self.ref_datapipe)
         warn_once_flag = True
-        for data in self.source_datapipe:
-            key = self.key_fn(data)
-            while key not in self.buffer:
-                try:
-                    ref_data = next(ref_it)
-                except StopIteration:
-                    raise BufferError(
-                        f"No matching key can be found from reference DataPipe for the data {data}. "
-                        "Please consider increasing the buffer size."
-                    )
-                ref_key = self.ref_key_fn(ref_data)
-                if ref_key in self.buffer:
-                    raise ValueError("Duplicate key is found in reference DataPipe")
-                if self.buffer_size is not None and len(self.buffer) > self.buffer_size:
-                    if warn_once_flag:
-                        warn_once_flag = False
-                        warnings.warn(
-                            "Buffer reaches the upper limit, so reference key-data pair begins to "
-                            "be removed from buffer in FIFO order. Please consider increase buffer size."
+        try:
+            for data in self.source_datapipe:
+                key = self.key_fn(data)
+                while key not in self.buffer:
+                    try:
+                        ref_data = next(ref_it)
+                    except StopIteration:
+                        raise BufferError(
+                            f"No matching key can be found from reference DataPipe for the data {data}. "
+                            "Please consider increasing the buffer size."
                         )
-                    self.buffer.popitem(last=False)
-                self.buffer[ref_key] = ref_data
-            res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key))
-            if self.keep_key:
-                yield key, res
-            else:
-                yield res
+                    ref_key = self.ref_key_fn(ref_data)
+                    if ref_key in self.buffer:
+                        raise ValueError("Duplicate key is found in reference DataPipe")
+                    if self.buffer_size is not None and len(self.buffer) > self.buffer_size:
+                        if warn_once_flag:
+                            warn_once_flag = False
+                            warnings.warn(
+                                "Buffer reaches the upper limit, so reference key-data pair begins to "
+                                "be removed from buffer in FIFO order. Please consider increase buffer size."
+                            )
+                        self.buffer.popitem(last=False)
+                    self.buffer[ref_key] = ref_data
+                res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key))
+                if self.keep_key:
+                    yield key, res
+                else:
+                    yield res
+        finally:
+            for remaining in ref_it:
+                janitor(remaining)
+
+            # TODO(VItalyFedyunin): This should be Exception or warn when debug mode is enabled
+            if len(self.buffer) > 0:
+                for k, v in self.buffer.items():
+                    janitor(v)
 
     def __len__(self) -> int:
         return len(self.source_datapipe)
diff --git a/torchdata/datapipes/iter/util/decompressor.py b/torchdata/datapipes/iter/util/decompressor.py
index 3cc949bde..30e153b42 100644
--- a/torchdata/datapipes/iter/util/decompressor.py
+++ b/torchdata/datapipes/iter/util/decompressor.py
@@ -97,7 +97,9 @@ def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
         for path, file in self.source_datapipe:
             file_type = self._detect_compression_type(path)
             decompressor = self._DECOMPRESSORS[file_type]
-            yield path, StreamWrapper(decompressor(file))
+            yield path, StreamWrapper(decompressor(file), file, name=path)
+            if isinstance(file, StreamWrapper):
+                file.autoclose()
 
 
 @functional_datapipe("extract")
diff --git a/torchdata/datapipes/iter/util/rararchiveloader.py b/torchdata/datapipes/iter/util/rararchiveloader.py
index 5b5546316..c657b1d90 100644
--- a/torchdata/datapipes/iter/util/rararchiveloader.py
+++ b/torchdata/datapipes/iter/util/rararchiveloader.py
@@ -107,7 +107,9 @@ def __iter__(self) -> Iterator[Tuple[str, io.BufferedIOBase]]:
                 inner_path = os.path.join(path, info.filename)
                 file_obj = rar.open(info)
 
-                yield inner_path, StreamWrapper(file_obj)  # type: ignore[misc]
+                yield inner_path, StreamWrapper(file_obj, stream, name=path)  # type: ignore[misc]
+                if isinstance(stream, StreamWrapper):
+                    stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:
diff --git a/torchdata/datapipes/iter/util/tararchiveloader.py b/torchdata/datapipes/iter/util/tararchiveloader.py
index ccb9e0c0f..4c8adef19 100644
--- a/torchdata/datapipes/iter/util/tararchiveloader.py
+++ b/torchdata/datapipes/iter/util/tararchiveloader.py
@@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
                         warnings.warn(f"failed to extract file {tarinfo.name} from source tarfile {pathname}")
                         raise tarfile.ExtractError
                     inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
-                    yield inner_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                    yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted tarfile stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:
diff --git a/torchdata/datapipes/iter/util/xzfileloader.py b/torchdata/datapipes/iter/util/xzfileloader.py
index ac5c46beb..450ee05b8 100644
--- a/torchdata/datapipes/iter/util/xzfileloader.py
+++ b/torchdata/datapipes/iter/util/xzfileloader.py
@@ -56,10 +56,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
             try:
                 extracted_fobj = lzma.open(data_stream, mode="rb")  # type: ignore[call-overload]
                 new_pathname = pathname.rstrip(".xz")
-                yield new_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                yield new_pathname, StreamWrapper(extracted_fobj, data_stream, name=pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted xz/lzma stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:
diff --git a/torchdata/datapipes/iter/util/ziparchiveloader.py b/torchdata/datapipes/iter/util/ziparchiveloader.py
index ae0c8f836..d61e061d9 100644
--- a/torchdata/datapipes/iter/util/ziparchiveloader.py
+++ b/torchdata/datapipes/iter/util/ziparchiveloader.py
@@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
                         continue
                     extracted_fobj = zips.open(zipinfo)
                     inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
-                    yield inner_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                    yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
             # We are unable to close 'data_stream' here, because it needs to be available to use later
 
     def __len__(self) -> int:
diff --git a/torchdata/datapipes/utils/__init__.py b/torchdata/datapipes/utils/__init__.py
index 889fd4d08..c74e8f702 100644
--- a/torchdata/datapipes/utils/__init__.py
+++ b/torchdata/datapipes/utils/__init__.py
@@ -7,5 +7,6 @@
 from torch.utils.data.datapipes.utils.common import StreamWrapper
 
 from ._visualization import to_graph
+from .janitor import janitor
 
-__all__ = ["StreamWrapper", "to_graph"]
+__all__ = ["StreamWrapper", "janitor", "to_graph"]
diff --git a/torchdata/datapipes/utils/janitor.py b/torchdata/datapipes/utils/janitor.py
new file mode 100644
index 000000000..bc649123f
--- /dev/null
+++ b/torchdata/datapipes/utils/janitor.py
@@ -0,0 +1,10 @@
+from torchdata.datapipes.utils import StreamWrapper
+
+
+def janitor(obj):
+    """
+    Invokes various `obj` cleanup procedures such as:
+    - Closing streams
+    """
+    # TODO(VitalyFedyunin): We can also release caching locks here to allow filtering
+    StreamWrapper.close_streams(obj)
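
A minimal usage sketch (not part of the patch) of the `janitor` helper this change exports at the top level: `janitor(obj)` simply calls `StreamWrapper.close_streams(obj)`, so any unconsumed sample that may contain `StreamWrapper` objects can be handed to it to close the underlying file handles, mirroring what the new `finally` block in combining.py does for leftover buffer entries. The directory name and the filtering condition below are hypothetical; `FileLister`, `FileOpener`, `load_from_tar`, and `torchdata.janitor` are the APIs involved.

import torchdata
from torchdata.datapipes.iter import FileLister, FileOpener

# Open every tar archive under a (hypothetical) local directory in binary mode
archives = FileOpener(FileLister("archives/"), mode="b")
samples = archives.load_from_tar()  # yields (inner_path, StreamWrapper) pairs

for path, stream in samples:
    if not path.endswith(".txt"):          # hypothetical filter: keep only text members
        torchdata.janitor((path, stream))  # close streams nested in the discarded sample
        continue
    print(path, stream.read(16))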