Commit f8f2e3f

Debug pipes overflow
ghstack-source-id: 251e5cf
Pull Request resolved: #78952
1 parent f72d867 commit f8f2e3f

5 files changed: +73 −7 lines


torch/utils/data/datapipes/datapipe.pyi.in

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 # classes/objects here, even though we are not injecting extra code into them at the moment.
 
 from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
-from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, TypeVar
+from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, TypeVar, Union
 from torch.utils.data import Dataset, IterableDataset
 
 T_co = TypeVar('T_co', covariant=True)

torch/utils/data/datapipes/iter/combining.py

Lines changed: 14 additions & 3 deletions
@@ -5,7 +5,7 @@
 
 from torch.utils.data.datapipes._decorator import functional_datapipe
 from torch.utils.data.datapipes.datapipe import IterDataPipe
-from torch.utils.data.datapipes.utils.common import _check_lambda_fn
+from torch.utils.data.datapipes.utils.common import StreamWrapper, _check_lambda_fn
 
 __all__ = [
     "ConcaterIterDataPipe",

@@ -345,6 +345,7 @@ def _find_next(self, instance_id: int) -> T_co:
             value = next(self._datapipe_iterator)
             classification = self.classifier_fn(value)
             if classification is None and self.drop_none:
+                StreamWrapper.close_streams(value)
                 continue
             if classification is None or classification >= self.num_instances or classification < 0:
                 raise ValueError(f"Output of the classification fn should be between 0 and {self.num_instances - 1}. " +

@@ -510,8 +511,18 @@ def __init__(self, *datapipes: IterDataPipe):
         self.length = None
 
     def __iter__(self) -> Iterator[Tuple[T_co]]:
-        for data in zip(*self.datapipes):
-            yield data
+        iterators = [iter(datapipe) for datapipe in self.datapipes]
+        try:
+            for data in zip(*iterators):
+                yield data
+        finally:
+            unused = []
+            for iterator in iterators:
+                unused += list(iterator)
+
+            # TODO(VitalyFedyunin): This should be Exception or warning when torchdata.debug is enabled
+            for item in unused:
+                StreamWrapper.close_streams(item)
 
     def __len__(self) -> int:
         if self.length is not None:
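
For illustration, a minimal usage sketch (not part of this commit) of the new Zipper behavior: when the input pipes have unequal lengths, items left over on the longer pipe are drained in the finally block and any wrapped streams inside them are closed. IterableWrapper, Zipper, and the io.BytesIO handles below are assumed stand-ins for a real pipeline.

    import io
    from torch.utils.data.datapipes.iter import IterableWrapper, Zipper
    from torch.utils.data.datapipes.utils.common import StreamWrapper

    # deepcopy=False so the pipe yields the exact StreamWrapper objects created here
    streams = [StreamWrapper(io.BytesIO(b"payload")) for _ in range(3)]
    short_dp = IterableWrapper([0, 1], deepcopy=False)
    long_dp = IterableWrapper(streams, deepcopy=False)

    # zip stops after two pairs; the leftover third stream is drained and closed
    list(Zipper(short_dp, long_dp))
    print([s.file_obj.closed for s in streams])  # expected: [False, False, True]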

torch/utils/data/datapipes/iter/selecting.py

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,8 @@
 from torch.utils.data.datapipes.datapipe import IterDataPipe
 from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
 from torch.utils.data.datapipes.utils.common import _check_lambda_fn, _deprecation_warning
+from torch.utils.data.datapipes.utils.common import StreamWrapper
+
 
 __all__ = ["FilterIterDataPipe", ]
 

@@ -78,6 +80,8 @@ def __iter__(self) -> Iterator[T_co]:
             filtered = self._returnIfTrue(data)
             if self._isNonEmpty(filtered):
                 yield filtered
+            else:
+                StreamWrapper.close_streams(data)
 
     def _returnIfTrue(self, data):
         condition = self._apply_filter_fn(data)
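
Similarly, a hedged sketch (not part of this commit) of the Filter change: samples rejected by the filter function now have any streams inside them closed rather than being left open. The .filter functional form and IterableWrapper are used here purely for illustration.

    import io
    from torch.utils.data.datapipes.iter import IterableWrapper
    from torch.utils.data.datapipes.utils.common import StreamWrapper

    def keep_sample(sample):
        # keep only samples whose label is "keep"
        return sample[0] == "keep"

    samples = [("keep", StreamWrapper(io.BytesIO(b"a"))),
               ("drop", StreamWrapper(io.BytesIO(b"b")))]
    dp = IterableWrapper(samples, deepcopy=False).filter(keep_sample)

    kept = list(dp)                       # only the "keep" sample is yielded
    print(samples[1][1].file_obj.closed)  # stream of the dropped sample -> True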

torch/utils/data/datapipes/iter/streamreader.py

Lines changed: 1 addition & 0 deletions
@@ -32,5 +32,6 @@ def __iter__(self):
             while True:
                 d = stream.read(self.chunk)
                 if not d:
+                    stream.close()
                     break
                 yield (furl, d)

torch/utils/data/datapipes/utils/common.py

Lines changed: 53 additions & 3 deletions
@@ -3,7 +3,8 @@
 import warnings
 
 from io import IOBase
-from typing import Dict, Iterable, List, Tuple, Union, Optional
+from typing import Any, Dict, Iterable, List, Set, Tuple, Union, Optional
+
 
 from torch.utils.data._utils.serialization import DILL_AVAILABLE
 

@@ -179,13 +180,59 @@ class StreamWrapper:
     DataPipe operation like `FileOpener`. StreamWrapper would guarantee
     the wrapped file handler is closed when it's out of scope.
     '''
-    def __init__(self, file_obj):
+    session_streams: Set[Any] = set()
+
+    def __init__(self, file_obj, parent_stream=None, name=None):
         self.file_obj = file_obj
+        self.child_counter = 0
+        self.parent_stream = parent_stream
+        self.close_on_last_child = False
+        self.name = name
+        if parent_stream is not None:
+            if not isinstance(parent_stream, StreamWrapper):
+                raise RuntimeError('Parent stream should be StreamWrapper, {} was given'.format(type(parent_stream)))
+            parent_stream.child_counter += 1
+            self.parent_stream = parent_stream
+        StreamWrapper.session_streams.add(self)
+
+    @classmethod
+    def close_streams(cls, v, depth=0):
+        '''
+        Traverse the structure and attempt to close all found StreamWrappers on a best-effort basis.
+        '''
+        if depth > 10:
+            return
+        if isinstance(v, StreamWrapper):
+            v.close()
+        else:
+            # Traverse only simple structures
+            if isinstance(v, dict):
+                for kk, vv in v.items():
+                    cls.close_streams(vv, depth=depth + 1)
+            elif isinstance(v, list) or isinstance(v, tuple):
+                for vv in v:
+                    cls.close_streams(vv, depth=depth + 1)
 
     def __getattr__(self, name):
         file_obj = self.__dict__['file_obj']
         return getattr(file_obj, name)
 
+    def close(self, *args, **kwargs):
+        StreamWrapper.session_streams.remove(self)
+        if self.parent_stream is not None:
+            self.parent_stream.child_counter -= 1
+            if not self.parent_stream.child_counter and self.parent_stream.close_on_last_child:
+                self.parent_stream.close()
+        self.file_obj.close(*args, **kwargs)
+
+    def autoclose(self):
+        '''
+        Marks the stream to close automatically as soon as all child streams are closed.
+        '''
+        if self.child_counter == 0:
+            self.close()
+        self.close_on_last_child = True
+
     def __dir__(self):
         attrs = list(self.__dict__.keys()) + list(StreamWrapper.__dict__.keys())
         attrs += dir(self.file_obj)

@@ -205,7 +252,10 @@ def __next__(self):
         return next(self.file_obj)
 
     def __repr__(self):
-        return f"StreamWrapper<{self.file_obj!r}>"
+        if self.name is None:
+            return f"StreamWrapper<{self.file_obj!r}>"
+        else:
+            return f"StreamWrapper<{self.name},{self.file_obj!r}>"
 
     def __getstate__(self):
         return self.file_obj
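
A hedged usage sketch (not part of this commit) of the new StreamWrapper bookkeeping: a child wrapper (e.g. a decompressed member of an archive) registers with its parent, autoclose() marks the parent to close once its last child is closed, and close_streams() walks plain containers and closes every wrapper it finds. The io.BytesIO objects below stand in for real file handles.

    import io
    from torch.utils.data.datapipes.utils.common import StreamWrapper

    archive = StreamWrapper(io.BytesIO(b"outer"), name="archive.tar")
    member = StreamWrapper(io.BytesIO(b"inner"), parent_stream=archive, name="member.txt")

    archive.autoclose()             # a child is still open, so only close_on_last_child is set
    member.close()                  # closing the last child now closes the parent as well
    print(archive.file_obj.closed)  # True

    # close_streams traverses dicts/lists/tuples and closes any StreamWrapper it finds
    batch = {"meta": 1, "streams": [StreamWrapper(io.BytesIO(b"x")), StreamWrapper(io.BytesIO(b"y"))]}
    StreamWrapper.close_streams(batch)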
