Commit f14d6d4

VitalyFedyunin authored and facebook-github-bot committed
Automatically close parent streams and discarded streams (#560)
Summary:
Pull Request resolved: #560

Test Plan: Imported from OSS

Reviewed By: bearzx

Differential Revision: D37625298

Pulled By: VitalyFedyunin

fbshipit-source-id: b0636eb9fc6fc32bffb166912341d3dc90a4e056
1 parent 7abb164 commit f14d6d4

File tree: 10 files changed (+86 −34 lines)


torchdata/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -8,9 +8,17 @@
 
 from . import datapipes
 
+janitor = datapipes.utils.janitor
+
 try:
     from .version import __version__  # noqa: F401
 except ImportError:
     pass
 
-__all__ = ["datapipes"]
+__all__ = [
+    "datapipes",
+    "janitor",
+]
+
+# Please keep this list sorted
+assert __all__ == sorted(__all__)
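
With this change, janitor is also importable straight from the top-level package. A minimal usage sketch (the torchdata.janitor name comes from this diff; the closed-state check is an assumption about StreamWrapper's bookkeeping):

import io

import torchdata
from torchdata.datapipes.utils import StreamWrapper

stream = StreamWrapper(io.BytesIO(b"discarded payload"))
torchdata.janitor(stream)  # forwards to StreamWrapper.close_streams (see janitor.py below)
print(stream.closed)       # expected to print True once the wrapped stream is closed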

torchdata/datapipes/iter/util/bz2fileloader.py

Lines changed: 4 additions & 1 deletion
@@ -54,10 +54,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
             try:
                 extracted_fobj = bz2.open(data_stream, mode="rb")  # type: ignore[call-overload]
                 new_pathname = pathname.rstrip(".bz2")
-                yield new_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                yield new_pathname, StreamWrapper(extracted_fobj, data_stream, name=new_pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted bzip2 stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:
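
The same wrapping pattern shows up in every loader touched below (decompressor, rar, tar, xz, zip): the extracted stream is wrapped together with its parent data_stream, and the parent is marked with autoclose() so it is released once all child wrappers are closed. A minimal standalone sketch of that behavior, assuming StreamWrapper's two-argument parent form and autoclose() work as this commit relies on them to:

import bz2
import io

from torchdata.datapipes.utils import StreamWrapper

# Parent: the (possibly remote) archive stream; child: the decompressed view of it.
parent = StreamWrapper(io.BytesIO(bz2.compress(b"payload")), name="archive.bz2")
child = StreamWrapper(bz2.open(parent, mode="rb"), parent, name="archive")

parent.autoclose()   # parent may close automatically, but only after its children are done
print(child.read())  # the child (and therefore the parent) is still readable here
child.close()        # closing the last child is expected to close the parent as well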

torchdata/datapipes/iter/util/combining.py

Lines changed: 37 additions & 26 deletions
@@ -11,6 +11,8 @@
 from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
 from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
 
+from torchdata.datapipes.utils.janitor import janitor
+
 T_co = TypeVar("T_co", covariant=True)
 
 
@@ -81,33 +83,42 @@ def __init__(
     def __iter__(self) -> Iterator:
         ref_it = iter(self.ref_datapipe)
         warn_once_flag = True
-        for data in self.source_datapipe:
-            key = self.key_fn(data)
-            while key not in self.buffer:
-                try:
-                    ref_data = next(ref_it)
-                except StopIteration:
-                    raise BufferError(
-                        f"No matching key can be found from reference DataPipe for the data {data}. "
-                        "Please consider increasing the buffer size."
-                    )
-                ref_key = self.ref_key_fn(ref_data)
-                if ref_key in self.buffer:
-                    raise ValueError("Duplicate key is found in reference DataPipe")
-                if self.buffer_size is not None and len(self.buffer) > self.buffer_size:
-                    if warn_once_flag:
-                        warn_once_flag = False
-                        warnings.warn(
-                            "Buffer reaches the upper limit, so reference key-data pair begins to "
-                            "be removed from buffer in FIFO order. Please consider increase buffer size."
-                        )
-                    self.buffer.popitem(last=False)
-                self.buffer[ref_key] = ref_data
-            res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key))
-            if self.keep_key:
-                yield key, res
-            else:
-                yield res
+        try:
+            for data in self.source_datapipe:
+                key = self.key_fn(data)
+                while key not in self.buffer:
+                    try:
+                        ref_data = next(ref_it)
+                    except StopIteration:
+                        raise BufferError(
+                            f"No matching key can be found from reference DataPipe for the data {data}. "
+                            "Please consider increasing the buffer size."
+                        )
+                    ref_key = self.ref_key_fn(ref_data)
+                    if ref_key in self.buffer:
+                        raise ValueError("Duplicate key is found in reference DataPipe")
+                    if self.buffer_size is not None and len(self.buffer) > self.buffer_size:
+                        if warn_once_flag:
+                            warn_once_flag = False
+                            warnings.warn(
+                                "Buffer reaches the upper limit, so reference key-data pair begins to "
+                                "be removed from buffer in FIFO order. Please consider increase buffer size."
+                            )
+                        self.buffer.popitem(last=False)
+                    self.buffer[ref_key] = ref_data
+                res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key))
+                if self.keep_key:
+                    yield key, res
+                else:
+                    yield res
+        finally:
+            for remaining in ref_it:
+                janitor(remaining)
+
+            # TODO(VItalyFedyunin): This should be Exception or warn when debug mode is enabled
+            if len(self.buffer) > 0:
+                for k, v in self.buffer.items():
+                    janitor(v)
 
     def __len__(self) -> int:
         return len(self.source_datapipe)
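
End to end, the new try/finally means reference entries that are never matched no longer leak open streams. A small usage sketch, assuming the zip_with_iter functional form of IterKeyZipper and sample data invented for illustration:

import io

from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.utils import StreamWrapper

source = IterableWrapper(["a.txt"])
reference = IterableWrapper(
    [
        ("a.txt", StreamWrapper(io.BytesIO(b"A"))),
        ("b.txt", StreamWrapper(io.BytesIO(b"B"))),  # never requested by the source pipe
    ]
)

zipped = source.zip_with_iter(
    ref_datapipe=reference,
    key_fn=lambda name: name,
    ref_key_fn=lambda item: item[0],
)
for name, (ref_name, stream) in zipped:
    print(name, stream.read())

# Once iteration finishes, the finally block above hands the unmatched
# ("b.txt", ...) entry to janitor, so its stream is expected to be closed.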

torchdata/datapipes/iter/util/decompressor.py

Lines changed: 3 additions & 1 deletion
@@ -97,7 +97,9 @@ def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
         for path, file in self.source_datapipe:
             file_type = self._detect_compression_type(path)
             decompressor = self._DECOMPRESSORS[file_type]
-            yield path, StreamWrapper(decompressor(file))
+            yield path, StreamWrapper(decompressor(file), file, name=path)
+            if isinstance(file, StreamWrapper):
+                file.autoclose()
 
 
 @functional_datapipe("extract")

torchdata/datapipes/iter/util/rararchiveloader.py

Lines changed: 3 additions & 1 deletion
@@ -107,7 +107,9 @@ def __iter__(self) -> Iterator[Tuple[str, io.BufferedIOBase]]:
                 inner_path = os.path.join(path, info.filename)
                 file_obj = rar.open(info)
 
-                yield inner_path, StreamWrapper(file_obj)  # type: ignore[misc]
+                yield inner_path, StreamWrapper(file_obj, stream, name=path)  # type: ignore[misc]
+            if isinstance(stream, StreamWrapper):
+                stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:

torchdata/datapipes/iter/util/tararchiveloader.py

Lines changed: 4 additions & 1 deletion
@@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
                         warnings.warn(f"failed to extract file {tarinfo.name} from source tarfile {pathname}")
                         raise tarfile.ExtractError
                     inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
-                    yield inner_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                    yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted tarfile stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:

torchdata/datapipes/iter/util/xzfileloader.py

Lines changed: 4 additions & 1 deletion
@@ -56,10 +56,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
             try:
                 extracted_fobj = lzma.open(data_stream, mode="rb")  # type: ignore[call-overload]
                 new_pathname = pathname.rstrip(".xz")
-                yield new_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                yield new_pathname, StreamWrapper(extracted_fobj, data_stream, name=pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted xz/lzma stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
 
     def __len__(self) -> int:
         if self.length == -1:

torchdata/datapipes/iter/util/ziparchiveloader.py

Lines changed: 4 additions & 1 deletion
@@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
                         continue
                     extracted_fobj = zips.open(zipinfo)
                     inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
-                    yield inner_pathname, StreamWrapper(extracted_fobj)  # type: ignore[misc]
+                    yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!")
                 raise e
+            finally:
+                if isinstance(data_stream, StreamWrapper):
+                    data_stream.autoclose()
             # We are unable to close 'data_stream' here, because it needs to be available to use later
 
     def __len__(self) -> int:

torchdata/datapipes/utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -7,5 +7,6 @@
 from torch.utils.data.datapipes.utils.common import StreamWrapper
 
 from ._visualization import to_graph
+from .janitor import janitor
 
-__all__ = ["StreamWrapper", "to_graph"]
+__all__ = ["StreamWrapper", "janitor", "to_graph"]

torchdata/datapipes/utils/janitor.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchdata.datapipes.utils import StreamWrapper
+
+
+def janitor(obj):
+    """
+    Invokes various `obj` cleanup procedures such as:
+        - Closing streams
+    """
+    # TODO(VitalyFedyunin): We can also release caching locks here to allow filtering
+    StreamWrapper.close_streams(obj)
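
janitor is a thin wrapper over StreamWrapper.close_streams, which the IterKeyZipper cleanup above relies on to reach into simple containers (tuples, lists, dicts) and close any wrapped streams it finds. A small sketch of that use, with the sample tuple invented for illustration:

import io

from torchdata.datapipes.utils import StreamWrapper, janitor

# A discarded sample: a (path, stream) pair that a pipeline never consumed.
sample = ("archive/member.txt", StreamWrapper(io.BytesIO(b"unused bytes")))

janitor(sample)          # close_streams traverses the tuple and closes the wrapper
print(sample[1].closed)  # expected to print True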
