close streams in prototype datasets #6647
@@ -1,6 +1,7 @@
 import functools
 import io
 import pickle
+from collections import deque
 from pathlib import Path

 import pytest
@@ -11,10 +12,11 @@
 from torch.utils.data.graph import traverse
 from torch.utils.data.graph_settings import get_all_graph_pipes
 from torchdata.datapipes.iter import ShardingFilter, Shuffler
+from torchdata.datapipes.utils import StreamWrapper
 from torchvision._utils import sequence_to_str
-from torchvision.prototype import datasets, transforms
+from torchvision.prototype import datasets, features, transforms
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-from torchvision.prototype.features import Image, Label


 assert_samples_equal = functools.partial(
     assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
 )
@@ -25,6 +27,17 @@ def extract_datapipes(dp):
     return get_all_graph_pipes(traverse(dp, only_datapipe=True))


+def consume(iterator):
+    # Copied from the official itertools recipes: https://docs.python.org/3/library/itertools.html#itertools-recipes
+    deque(iterator, maxlen=0)
+
+
+def next_consume(iterator):
+    item = next(iterator)
+    consume(iterator)
+    return item
+
+
 @pytest.fixture(autouse=True)
 def test_home(mocker, tmp_path):
     mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
@@ -66,7 +79,7 @@ def test_sample(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)

         try:
-            sample = next(iter(dataset))
+            sample = next_consume(iter(dataset))
         except StopIteration:
             raise AssertionError("Unable to draw any sample.") from None
         except Exception as error:
@@ -84,22 +97,53 @@ def test_num_samples(self, dataset_mock, config):

         assert len(list(dataset)) == mock_info["num_samples"]

+    @pytest.fixture
+    def log_session_streams(self):
+        debug_unclosed_streams = StreamWrapper.debug_unclosed_streams
+        try:
+            StreamWrapper.debug_unclosed_streams = True
+            yield
+        finally:
+            StreamWrapper.debug_unclosed_streams = debug_unclosed_streams
+
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_vanilla_tensors(self, dataset_mock, config):
+    def test_stream_closing(self, log_session_streams, dataset_mock, config):
+        def make_msg_and_close(head):
+            unclosed_streams = []
+            for stream in StreamWrapper.session_streams.keys():
+                unclosed_streams.append(repr(stream.file_obj))
+                stream.close()
+            unclosed_streams = "\n".join(unclosed_streams)
+            return f"{head}\n\n{unclosed_streams}"
+
+        if StreamWrapper.session_streams:
+            raise pytest.UsageError(make_msg_and_close("A previous test did not close the following streams:"))
+
         dataset, _ = dataset_mock.load(config)

-        vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
-        if vanilla_tensors:
+        consume(iter(dataset))
+
+        if StreamWrapper.session_streams:
+            raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))
+
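As a side note on how the fixture above detects leaks: a minimal sketch, assuming StreamWrapper registers every wrapper in the class-level session_streams dict while debug_unclosed_streams is enabled and removes it again on close(). The attribute names come from the diff; the exact registration behavior is an assumption.

    import io

    from torchdata.datapipes.utils import StreamWrapper

    StreamWrapper.debug_unclosed_streams = True  # start tracking newly created wrappers

    wrapped = StreamWrapper(io.BytesIO(b"payload"))
    assert wrapped in StreamWrapper.session_streams  # tracked while the stream is open

    wrapped.close()
    assert wrapped not in StreamWrapper.session_streams  # dropped from the registry on close

    StreamWrapper.debug_unclosed_streams = False  # restore the default, like the fixture's finally block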
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_no_simple_tensors(self, dataset_mock, config):
+        dataset, _ = dataset_mock.load(config)
+
+        simple_tensors = {key for key, value in next_consume(iter(dataset)).items() if features.is_simple_tensor(value)}

Review comment (author): This is a drive-by since I was already touching the line: the term "vanilla" tensor is no longer used. In the prototype transforms we use "simple tensor" now and also have features.is_simple_tensor to check for it.

+        if simple_tensors:
             raise AssertionError(
                 f"The values of key(s) "
-                f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
+                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors."
             )
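For context on the check used above: a rough sketch of what features.is_simple_tensor plausibly does. Only the features.is_simple_tensor name is confirmed by the diff; the reimplementation below, including the _Feature base class, is an assumption for illustration.

    import torch

    from torchvision.prototype import features

    def is_simple_tensor_sketch(value):
        # a "simple" tensor is a plain torch.Tensor that is not one of the
        # prototype feature subclasses such as features.Image or features.Label
        return isinstance(value, torch.Tensor) and not isinstance(value, features._Feature)

    assert is_simple_tensor_sketch(torch.rand(3, 16, 16))
    assert not is_simple_tensor_sketch(features.Image(torch.rand(3, 16, 16)))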
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_transformable(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)

-        next(iter(dataset.map(transforms.Identity())))
+        dataset = dataset.map(transforms.Identity())
+
+        consume(iter(dataset))

     @pytest.mark.parametrize("only_datapipe", [False, True])
     @parametrize_dataset_mocks(DATASET_MOCKS)
@@ -132,7 +176,7 @@ def test_data_loader(self, dataset_mock, config, num_workers):
             collate_fn=self._collate_fn,
         )

-        next(iter(dl))
+        consume(dl)

         # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
         #       that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
@@ -149,7 +193,7 @@ def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
     def test_save_load(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)

-        sample = next(iter(dataset))
+        sample = next_consume(iter(dataset))

         with io.BytesIO() as buffer:
             torch.save(sample, buffer)
@@ -178,7 +222,7 @@ class TestQMNIST:
     def test_extra_label(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)

-        sample = next(iter(dataset))
+        sample = next_consume(iter(dataset))
         for key, type in (
             ("nist_hsf_series", int),
             ("nist_writer_id", int),
@@ -215,7 +259,7 @@ def test_sample_content(self, dataset_mock, config):
         assert "image" in sample
         assert "label" in sample

-        assert isinstance(sample["image"], Image)
-        assert isinstance(sample["label"], Label)
+        assert isinstance(sample["image"], features.Image)
+        assert isinstance(sample["label"], features.Label)

         assert sample["image"].shape == (1, 16, 16)
----------------------------------------------------------------------
@@ -103,8 +103,11 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
             buffer_size=INFINITE_BUFFER_SIZE,
             drop_none=True,
         )
-        if self._split == "train_noval":
-            split_dp = extra_split_dp
+        split_dp, to_be_closed_dp = (
+            (extra_split_dp, split_dp) if self._split == "train_noval" else (split_dp, extra_split_dp)
+        )
+        for _, file in to_be_closed_dp:
+            file.close()
Review comment (author): This is somewhat problematic in the […] Is there an idiom to "mark" a datapipe to be closed at runtime even if we don't return the datapipe? The only thing I came up with is changing the classifier function of the Demultiplexer […]

Review comment (reviewer): It is incorrect to consume a DataPipe at construction time:

    for _, file in to_be_closed_dp:
        file.close()

It will not affect the executed graph. Ideally, we want to have something like dp.close(), which would effectively remove dangling pieces of the graph. But for now you can either use a trick like […]

Review comment (author): I don't think that is fully correct for our case. At construction, we load a couple of files and weave them together into one datapipe. These files are always loaded unconditionally, but for some configurations not all of the files are needed. So we should be able to simply close them during the construction of the dataset datapipe, since they will never make it into the graph, correct? I agree, we shouldn't do this for datapipes that stem from a […]

Review comment (author): If possible, I think that is the better solution, since the other one still iterates over all items when the loop should actually be done. I guess that could be irritating as well. I implemented branching in afb0ec2. PTAL
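For illustration, a minimal sketch of the "close inside the classifier" idea floated in this thread, reusing the Demultiplexer/drop_none machinery the diff already shows. The helper names, toy data, and split logic are made up and do not mirror the actual commit:

    import io

    from torchdata.datapipes.iter import Demultiplexer, IterableWrapper

    def make_classifier(wanted_split):
        def classify(data):
            path, file = data
            if path.startswith(wanted_split):
                return 0  # route to the single branch that stays in the graph
            file.close()  # close the unwanted stream as soon as it is classified
            return None  # drop the item entirely via drop_none=True
        return classify

    # toy stand-in for the (path, stream) pairs loaded from the archive
    files = IterableWrapper([("train.txt", io.BytesIO(b"1\n")), ("val.txt", io.BytesIO(b"2\n"))])
    (split_dp,) = Demultiplexer(
        files,
        num_instances=1,
        classifier_fn=make_classifier("train"),
        drop_none=True,
        buffer_size=-1,  # -1 is the INFINITE_BUFFER_SIZE sentinel used in the diff
    )

Unlike closing at construction time, the unwanted streams here are only closed while the graph is actually iterated, which is what the reviewer's objection asks for.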

         split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
Review comment (author, on next_consume): This is useful since we often just want the first sample, but we need to make sure to still consume the rest to avoid dangling streams. list(iterator) would also do the trick, but it keeps everything in memory for no reason.
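To make the memory point concrete, a small self-contained example (names are illustrative):

    from collections import deque

    def next_consume(iterator):
        item = next(iterator)
        deque(iterator, maxlen=0)  # drain the remaining items without storing them
        return item

    first = next_consume(iter(range(1_000_000)))  # first == 0; the rest is discarded immediately

list(iterator) would buffer every remaining sample, and for datasets each sample can hold decoded images, so the zero-length deque keeps peak memory flat while still exhausting the underlying streams.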