[DataPipe] Automatically close parent streams and discarded streams #560

Closed
10 changes: 9 additions & 1 deletion torchdata/__init__.py
@@ -8,9 +8,17 @@

from . import datapipes

janitor = datapipes.utils.janitor

try:
from .version import __version__ # noqa: F401
except ImportError:
pass

__all__ = ["datapipes"]
__all__ = [
"datapipes",
"janitor",
]

# Please keep this list sorted
assert __all__ == sorted(__all__)
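
With the re-export above, `janitor` becomes reachable from the package root. A tiny usage sketch (the `BytesIO` payload is just a stand-in for a discarded stream):

```python
import io

import torchdata
from torchdata.datapipes.utils import StreamWrapper

# Hypothetical leftover value that still holds an open stream.
leftover = StreamWrapper(io.BytesIO(b"unused payload"))
torchdata.janitor(leftover)  # closes the wrapped stream
```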
5 changes: 4 additions & 1 deletion torchdata/datapipes/iter/util/bz2fileloader.py
@@ -54,10 +54,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
try:
extracted_fobj = bz2.open(data_stream, mode="rb") # type: ignore[call-overload]
new_pathname = pathname.rstrip(".bz2")
yield new_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
yield new_pathname, StreamWrapper(extracted_fobj, data_stream, name=new_pathname) # type: ignore[misc]
except Exception as e:
warnings.warn(f"Unable to extract files from corrupted bzip2 stream {pathname} due to: {e}, abort!")
raise e
finally:
if isinstance(data_stream, StreamWrapper):
data_stream.autoclose()

def __len__(self) -> int:
if self.length == -1:
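The same parent-tracking pattern is applied to the tar, xz, zip, and rar loaders and to the decompressor below: the extracted file object is wrapped with its parent stream recorded, and the parent's `autoclose()` is requested once the loader is done handing out children. A minimal sketch of the idea outside of any DataPipe (the `open_and_extract` helper and the plain `open` call are made up for illustration; `StreamWrapper` and `autoclose()` are the utilities used in this diff):

```python
import bz2

from torch.utils.data.datapipes.utils.common import StreamWrapper


def open_and_extract(pathname):
    # Hypothetical helper mirroring the loader logic above.
    data_stream = StreamWrapper(open(pathname, "rb"))  # parent stream
    try:
        extracted = bz2.open(data_stream, mode="rb")
        # The child records `data_stream` as its parent, so the parent can be
        # released automatically once the child is closed.
        yield pathname.rstrip(".bz2"), StreamWrapper(extracted, data_stream, name=pathname)
    finally:
        # Request that the parent close itself as soon as all of its
        # child streams have been closed.
        data_stream.autoclose()
```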
63 changes: 37 additions & 26 deletions torchdata/datapipes/iter/util/combining.py
@@ -11,6 +11,8 @@
from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

from torchdata.datapipes.utils.janitor import janitor

T_co = TypeVar("T_co", covariant=True)


@@ -81,33 +83,42 @@ def __init__(
def __iter__(self) -> Iterator:
ref_it = iter(self.ref_datapipe)
warn_once_flag = True
for data in self.source_datapipe:
key = self.key_fn(data)
while key not in self.buffer:
try:
ref_data = next(ref_it)
except StopIteration:
raise BufferError(
f"No matching key can be found from reference DataPipe for the data {data}. "
"Please consider increasing the buffer size."
)
ref_key = self.ref_key_fn(ref_data)
if ref_key in self.buffer:
raise ValueError("Duplicate key is found in reference DataPipe")
if self.buffer_size is not None and len(self.buffer) > self.buffer_size:
if warn_once_flag:
warn_once_flag = False
warnings.warn(
"Buffer reaches the upper limit, so reference key-data pair begins to "
"be removed from buffer in FIFO order. Please consider increase buffer size."
try:
for data in self.source_datapipe:
key = self.key_fn(data)
while key not in self.buffer:
try:
ref_data = next(ref_it)
except StopIteration:
raise BufferError(
f"No matching key can be found from reference DataPipe for the data {data}. "
"Please consider increasing the buffer size."
)
self.buffer.popitem(last=False)
self.buffer[ref_key] = ref_data
res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key))
if self.keep_key:
yield key, res
else:
yield res
ref_key = self.ref_key_fn(ref_data)
if ref_key in self.buffer:
raise ValueError("Duplicate key is found in reference DataPipe")
if self.buffer_size is not None and len(self.buffer) > self.buffer_size:
if warn_once_flag:
warn_once_flag = False
warnings.warn(
"Buffer reaches the upper limit, so reference key-data pair begins to "
"be removed from buffer in FIFO order. Please consider increase buffer size."
)
self.buffer.popitem(last=False)
self.buffer[ref_key] = ref_data
res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key))
if self.keep_key:
yield key, res
else:
yield res
finally:
for remaining in ref_it:
janitor(remaining)

# TODO(VItalyFedyunin): This should be Exception or warn when debug mode is enabled
if len(self.buffer) > 0:
for k, v in self.buffer.items():
janitor(v)

def __len__(self) -> int:
return len(self.source_datapipe)
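For context, a hedged usage sketch of `IterKeyZipper` (functional form `zip_with_iter`); the file names and key functions here are invented. The new `try`/`finally` above means that reference items which are buffered or never matched are handed to `janitor` when iteration ends, so any streams they wrap get closed instead of leaking:

```python
from torchdata.datapipes.iter import IterableWrapper

source = IterableWrapper([("a.png", b"img-a"), ("b.png", b"img-b")])
ref = IterableWrapper([("a.cls", 0), ("b.cls", 1), ("c.cls", 2)])

dp = source.zip_with_iter(
    ref_datapipe=ref,
    key_fn=lambda item: item[0].split(".")[0],
    ref_key_fn=lambda item: item[0].split(".")[0],
    keep_key=False,
    buffer_size=10,
)

for pair in dp:
    print(pair)  # (("a.png", b"img-a"), ("a.cls", 0)), then the matching "b" pair

# "c.cls" never matches a source key; it is now passed to janitor() once the
# loop above finishes, so any StreamWrapper it contained would be closed.
```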
4 changes: 3 additions & 1 deletion torchdata/datapipes/iter/util/decompressor.py
@@ -97,7 +97,9 @@ def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
for path, file in self.source_datapipe:
file_type = self._detect_compression_type(path)
decompressor = self._DECOMPRESSORS[file_type]
yield path, StreamWrapper(decompressor(file))
yield path, StreamWrapper(decompressor(file), file, name=path)
if isinstance(file, StreamWrapper):
file.autoclose()


@functional_datapipe("extract")
4 changes: 3 additions & 1 deletion torchdata/datapipes/iter/util/rararchiveloader.py
@@ -107,7 +107,9 @@ def __iter__(self) -> Iterator[Tuple[str, io.BufferedIOBase]]:
inner_path = os.path.join(path, info.filename)
file_obj = rar.open(info)

yield inner_path, StreamWrapper(file_obj) # type: ignore[misc]
yield inner_path, StreamWrapper(file_obj, stream, name=path) # type: ignore[misc]
if isinstance(stream, StreamWrapper):
stream.autoclose()

def __len__(self) -> int:
if self.length == -1:
5 changes: 4 additions & 1 deletion torchdata/datapipes/iter/util/tararchiveloader.py
@@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
warnings.warn(f"failed to extract file {tarinfo.name} from source tarfile {pathname}")
raise tarfile.ExtractError
inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
yield inner_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname) # type: ignore[misc]
except Exception as e:
warnings.warn(f"Unable to extract files from corrupted tarfile stream {pathname} due to: {e}, abort!")
raise e
finally:
if isinstance(data_stream, StreamWrapper):
data_stream.autoclose()

def __len__(self) -> int:
if self.length == -1:
5 changes: 4 additions & 1 deletion torchdata/datapipes/iter/util/xzfileloader.py
@@ -56,10 +56,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
try:
extracted_fobj = lzma.open(data_stream, mode="rb") # type: ignore[call-overload]
new_pathname = pathname.rstrip(".xz")
yield new_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
yield new_pathname, StreamWrapper(extracted_fobj, data_stream, name=pathname) # type: ignore[misc]
except Exception as e:
warnings.warn(f"Unable to extract files from corrupted xz/lzma stream {pathname} due to: {e}, abort!")
raise e
finally:
if isinstance(data_stream, StreamWrapper):
data_stream.autoclose()

def __len__(self) -> int:
if self.length == -1:
5 changes: 4 additions & 1 deletion torchdata/datapipes/iter/util/ziparchiveloader.py
@@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
continue
extracted_fobj = zips.open(zipinfo)
inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
yield inner_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname) # type: ignore[misc]
except Exception as e:
warnings.warn(f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!")
raise e
finally:
if isinstance(data_stream, StreamWrapper):
data_stream.autoclose()
# We are unable to close 'data_stream' here, because it needs to be available to use later

def __len__(self) -> int:
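The comment kept above still holds: `data_stream` cannot simply be closed in the loader, because the member streams read from it lazily. `autoclose()` resolves this by deferring the close until every child stream derived from the parent has been closed. A small sketch of that behaviour with an in-memory archive (the archive contents are made up; the deferred-close semantics are an assumption based on how the loaders in this diff use `autoclose()`):

```python
import io
import zipfile

from torch.utils.data.datapipes.utils.common import StreamWrapper

# Build a hypothetical in-memory archive with two members.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("a.txt", "first")
    zf.writestr("b.txt", "second")
buf.seek(0)

data_stream = StreamWrapper(buf)
zips = zipfile.ZipFile(data_stream)
children = [
    StreamWrapper(zips.open(info), data_stream, name=info.filename)
    for info in zips.infolist()
]

data_stream.autoclose()  # does not close yet: children are still open
for child in children:
    print(child.read())  # parent is still readable here
    child.close()        # after the last close, the parent closes itself
```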
3 changes: 2 additions & 1 deletion torchdata/datapipes/utils/__init__.py
@@ -7,5 +7,6 @@
from torch.utils.data.datapipes.utils.common import StreamWrapper

from ._visualization import to_graph
from .janitor import janitor

__all__ = ["StreamWrapper", "to_graph"]
__all__ = ["StreamWrapper", "janitor", "to_graph"]
10 changes: 10 additions & 0 deletions torchdata/datapipes/utils/janitor.py
@@ -0,0 +1,10 @@
from torchdata.datapipes.utils import StreamWrapper


def janitor(obj):
"""
Invokes various `obj` cleanup procedures such as:
- Closing streams
"""
# TODO(VitalyFedyunin): We can also release caching locks here to allow filtering
StreamWrapper.close_streams(obj)
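
A short usage sketch of the new helper (assuming, as in the current `StreamWrapper.close_streams`, that common containers such as dicts, lists, and tuples are traversed):

```python
import io

from torchdata.datapipes.utils import StreamWrapper, janitor

raw = io.BytesIO(b"payload")
sample = {"name": "a.txt", "stream": StreamWrapper(raw)}

janitor(sample)    # delegates to StreamWrapper.close_streams(sample)
print(raw.closed)  # True: the wrapped stream inside the dict was closed
```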