Skip to content

Commit c631c8c

Browse files
[DataPipe] Automatically close parent streams and discarded streams
ghstack-source-id: b0a21fa Pull Request resolved: #560
1 parent 242ec0d commit c631c8c

File tree

10 files changed

+53
-8
lines changed

10 files changed

+53
-8
lines changed

torchdata/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,17 @@
88

99
from . import datapipes
1010

11+
janitor = datapipes.utils.janitor
12+
1113
try:
1214
from .version import __version__ # noqa: F401
1315
except ImportError:
1416
pass
1517

16-
__all__ = ["datapipes"]
18+
__all__ = [
19+
"datapipes",
20+
"janitor",
21+
]
22+
23+
# Please keep this list sorted
24+
assert __all__ == sorted(__all__)

torchdata/datapipes/iter/util/bz2fileloader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
5454
try:
5555
extracted_fobj = bz2.open(data_stream, mode="rb") # type: ignore[call-overload]
5656
new_pathname = pathname.rstrip(".bz2")
57-
yield new_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
57+
yield new_pathname, StreamWrapper(extracted_fobj, data_stream, name=new_pathname) # type: ignore[misc]
5858
except Exception as e:
5959
warnings.warn(f"Unable to extract files from corrupted bzip2 stream {pathname} due to: {e}, abort!")
6060
raise e
61+
finally:
62+
if isinstance(data_stream, StreamWrapper):
63+
data_stream.autoclose()
6164

6265
def __len__(self) -> int:
6366
if self.length == -1:

torchdata/datapipes/iter/util/combining.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
1212
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
1313

14+
from torchdata.datapipes.utils.janitor import janitor
15+
1416
T_co = TypeVar("T_co", covariant=True)
1517

1618

@@ -109,6 +111,14 @@ def __iter__(self) -> Iterator:
109111
else:
110112
yield res
111113

114+
for remaining in ref_it:
115+
janitor(remaining)
116+
117+
# TODO(VitalyFedyunin): This should be Exception or warn when debug mode is enabled
118+
if len(self.buffer) > 0:
119+
for k, v in self.buffer.items():
120+
janitor(v)
121+
112122
def __len__(self) -> int:
113123
return len(self.source_datapipe)
114124

torchdata/datapipes/iter/util/decompressor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
9797
for path, file in self.source_datapipe:
9898
file_type = self._detect_compression_type(path)
9999
decompressor = self._DECOMPRESSORS[file_type]
100-
yield path, StreamWrapper(decompressor(file))
100+
yield path, StreamWrapper(decompressor(file), file, name=path)
101+
if isinstance(file, StreamWrapper):
102+
file.autoclose()
101103

102104

103105
@functional_datapipe("extract")

torchdata/datapipes/iter/util/rararchiveloader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ def __iter__(self) -> Iterator[Tuple[str, io.BufferedIOBase]]:
107107
inner_path = os.path.join(path, info.filename)
108108
file_obj = rar.open(info)
109109

110-
yield inner_path, StreamWrapper(file_obj) # type: ignore[misc]
110+
yield inner_path, StreamWrapper(file_obj, stream) # type: ignore[misc]
111+
if isinstance(stream, StreamWrapper):
112+
stream.autoclose()
111113

112114
def __len__(self) -> int:
113115
if self.length == -1:

torchdata/datapipes/iter/util/tararchiveloader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
6767
warnings.warn(f"failed to extract file {tarinfo.name} from source tarfile {pathname}")
6868
raise tarfile.ExtractError
6969
inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
70-
yield inner_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
70+
yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname) # type: ignore[misc]
7171
except Exception as e:
7272
warnings.warn(f"Unable to extract files from corrupted tarfile stream {pathname} due to: {e}, abort!")
7373
raise e
74+
finally:
75+
if isinstance(data_stream, StreamWrapper):
76+
data_stream.autoclose()
7477

7578
def __len__(self) -> int:
7679
if self.length == -1:

torchdata/datapipes/iter/util/xzfileloader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
5656
try:
5757
extracted_fobj = lzma.open(data_stream, mode="rb") # type: ignore[call-overload]
5858
new_pathname = pathname.rstrip(".xz")
59-
yield new_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
59+
yield new_pathname, StreamWrapper(extracted_fobj, data_stream) # type: ignore[misc]
6060
except Exception as e:
6161
warnings.warn(f"Unable to extract files from corrupted xz/lzma stream {pathname} due to: {e}, abort!")
6262
raise e
63+
finally:
64+
if isinstance(data_stream, StreamWrapper):
65+
data_stream.autoclose()
6366

6467
def __len__(self) -> int:
6568
if self.length == -1:

torchdata/datapipes/iter/util/ziparchiveloader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,13 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
6767
continue
6868
extracted_fobj = zips.open(zipinfo)
6969
inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
70-
yield inner_pathname, StreamWrapper(extracted_fobj) # type: ignore[misc]
70+
yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname) # type: ignore[misc]
7171
except Exception as e:
7272
warnings.warn(f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!")
7373
raise e
74+
finally:
75+
if isinstance(data_stream, StreamWrapper):
76+
data_stream.autoclose()
7477
# We are unable to close 'data_stream' here, because it needs to be available to use later
7578

7679
def __len__(self) -> int:

torchdata/datapipes/utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@
77
from torch.utils.data.datapipes.utils.common import StreamWrapper
88

99
from ._visualization import to_graph
10+
from .janitor import janitor
1011

11-
__all__ = ["StreamWrapper", "to_graph"]
12+
__all__ = ["StreamWrapper", "janitor", "to_graph"]

torchdata/datapipes/utils/janitor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from torchdata.datapipes.utils import StreamWrapper
2+
3+
4+
def janitor(obj):
5+
"""
6+
Invokes various `obj` cleanup procedures such as:
7+
- Closing streams
8+
"""
9+
# TODO(VitalyFedyunin): We can also release caching locks here to allow filtering
10+
StreamWrapper.close_streams(obj)

0 commit comments

Comments
 (0)