close streams in prototype datasets #6647
@@ -1,6 +1,7 @@
 import functools
 import io
 import pickle
+from collections import deque
 from pathlib import Path

 import pytest
@@ -11,10 +12,11 @@
 from torch.utils.data.graph import traverse_dps
 from torch.utils.data.graph_settings import get_all_graph_pipes
 from torchdata.datapipes.iter import ShardingFilter, Shuffler
+from torchdata.datapipes.utils import StreamWrapper
 from torchvision._utils import sequence_to_str
-from torchvision.prototype import datasets, transforms
+from torchvision.prototype import datasets, features, transforms
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-from torchvision.prototype.features import Image, Label


 assert_samples_equal = functools.partial(
     assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
@@ -25,6 +27,17 @@ def extract_datapipes(dp):
     return get_all_graph_pipes(traverse_dps(dp))


+def consume(iterator):
+    # Copied from the official itertools recipes: https://docs.python.org/3/library/itertools.html#itertools-recipes
+    deque(iterator, maxlen=0)
+
+
+def next_consume(iterator):
+    item = next(iterator)
+    consume(iterator)
+    return item
+
+
 @pytest.fixture(autouse=True)
 def test_home(mocker, tmp_path):
     mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
@@ -66,7 +79,7 @@ def test_sample(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)

         try:
-            sample = next(iter(dataset))
+            sample = next_consume(iter(dataset))
         except StopIteration:
             raise AssertionError("Unable to draw any sample.") from None
         except Exception as error:
@@ -84,22 +97,53 @@ def test_num_samples(self, dataset_mock, config):

         assert len(list(dataset)) == mock_info["num_samples"]

+    @pytest.fixture
+    def log_session_streams(self):
+        debug_unclosed_streams = StreamWrapper.debug_unclosed_streams
+
+        try:
+            StreamWrapper.debug_unclosed_streams = True
+            yield
+        finally:
+            StreamWrapper.debug_unclosed_streams = debug_unclosed_streams
+
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_vanilla_tensors(self, dataset_mock, config):
+    def test_stream_closing(self, log_session_streams, dataset_mock, config):
+        def make_msg_and_close(head):
+            unclosed_streams = []
+            for stream in StreamWrapper.session_streams.keys():
+                unclosed_streams.append(repr(stream.file_obj))
+                stream.close()
+            unclosed_streams = "\n".join(unclosed_streams)
+            return f"{head}\n\n{unclosed_streams}"
+
+        if StreamWrapper.session_streams:
+            raise pytest.UsageError(make_msg_and_close("A previous test did not close the following streams:"))
+
         dataset, _ = dataset_mock.load(config)

-        vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
-        if vanilla_tensors:
+        consume(iter(dataset))
+
+        if StreamWrapper.session_streams:
+            raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))
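As context for the new fixture and test: with StreamWrapper.debug_unclosed_streams enabled, torchdata keeps every wrapped stream in StreamWrapper.session_streams until it is closed. A minimal sketch of that behavior, not part of the PR and assuming only the StreamWrapper attributes already used in the diff above:

```python
import io

from torchdata.datapipes.utils import StreamWrapper

StreamWrapper.debug_unclosed_streams = True

stream = StreamWrapper(io.BytesIO(b"payload"))
assert stream in StreamWrapper.session_streams  # tracked while open

stream.close()
assert stream not in StreamWrapper.session_streams  # removed once closed
```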
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_no_simple_tensors(self, dataset_mock, config):
+        dataset, _ = dataset_mock.load(config)
+
+        simple_tensors = {key for key, value in next_consume(iter(dataset)).items() if features.is_simple_tensor(value)}
+        if simple_tensors:
             raise AssertionError(
                 f"The values of key(s) "
-                f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
+                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors."
             )

This is a drive-by since I was already touching the line: the term "vanilla" tensor is no longer used. In the prototype transforms we use "simple tensor" now and also have features.is_simple_tensor.
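To illustrate the distinction the renamed test checks, a small sketch (not from the PR), assuming the features.is_simple_tensor helper referenced above:

```python
import torch

from torchvision.prototype import features

# A plain torch.Tensor counts as a "simple" tensor ...
assert features.is_simple_tensor(torch.rand(3, 16, 16))

# ... while a feature subclass such as features.Image does not.
assert not features.is_simple_tensor(features.Image(torch.rand(3, 16, 16)))
```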
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_transformable(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)

-        next(iter(dataset.map(transforms.Identity())))
+        dataset = dataset.map(transforms.Identity())
+
+        consume(iter(dataset))

     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_traversable(self, dataset_mock, config):
@@ -131,7 +175,7 @@ def test_data_loader(self, dataset_mock, config, num_workers):
             collate_fn=self._collate_fn,
         )

-        next(iter(dl))
+        consume(dl)

     # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
     # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
@@ -148,7 +192,7 @@ def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
     def test_save_load(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)

-        sample = next(iter(dataset))
+        sample = next_consume(iter(dataset))

         with io.BytesIO() as buffer:
             torch.save(sample, buffer)
@@ -177,7 +221,7 @@ class TestQMNIST:
     def test_extra_label(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)

-        sample = next(iter(dataset))
+        sample = next_consume(iter(dataset))
         for key, type in (
             ("nist_hsf_series", int),
             ("nist_writer_id", int),
@@ -214,7 +258,7 @@ def test_sample_content(self, dataset_mock, config):
         assert "image" in sample
         assert "label" in sample

-        assert isinstance(sample["image"], Image)
-        assert isinstance(sample["label"], Label)
+        assert isinstance(sample["image"], features.Image)
+        assert isinstance(sample["label"], features.Label)

         assert sample["image"].shape == (1, 16, 16)
@@ -49,31 +49,35 @@ def __init__(
         super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check)

     def _resources(self) -> List[OnlineResource]:
-        archive = HttpResource(
-            "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
-            sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
-        )
-        extra_split = HttpResource(
-            "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
-            sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
-        )
-        return [archive, extra_split]
+        resources = [
+            HttpResource(
+                "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
+                sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
+            )
+        ]
+        if self._split == "train_noval":
+            resources.append(
+                HttpResource(
+                    "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
+                    sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
+                )
+            )
+        return resources  # type: ignore[return-value]
     def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         path = pathlib.Path(data[0])
         parent, grandparent, *_ = path.parents

-        if parent.name == "dataset":
-            return 0
-        elif grandparent.name == "dataset":
+        if grandparent.name == "dataset":
             if parent.name == "img":
-                return 1
+                return 0
             elif parent.name == "cls":
-                return 2
-            else:
-                return None
-        else:
-            return None
+                return 1
+
+        if parent.name == "dataset" and self._split != "train_noval":
+            return 2
+
+        return None
     def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
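For reference, a rough standalone sketch (not part of the PR) of the routing that the new _classify_archive implements; the member paths below are hypothetical and only loosely modeled on the SBD archive layout:

```python
import pathlib


def classify(path_str, split="train"):
    # Mirrors the new logic: images -> 0, annotations -> 1, and split files -> 2,
    # but only when the split file is expected to come out of the archive.
    path = pathlib.Path(path_str)
    parent, grandparent, *_ = path.parents

    if grandparent.name == "dataset":
        if parent.name == "img":
            return 0
        elif parent.name == "cls":
            return 1

    if parent.name == "dataset" and split != "train_noval":
        return 2

    return None


assert classify("benchmark_RELEASE/dataset/img/2008_000002.jpg") == 0
assert classify("benchmark_RELEASE/dataset/cls/2008_000002.mat") == 1
assert classify("benchmark_RELEASE/dataset/train.txt") == 2
assert classify("benchmark_RELEASE/dataset/train.txt", split="train_noval") is None
```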
@@ -93,18 +97,24 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         )

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
-        archive_dp, extra_split_dp = resource_dps
-
-        archive_dp = resource_dps[0]
-        split_dp, images_dp, anns_dp = Demultiplexer(
-            archive_dp,
-            3,
-            self._classify_archive,
-            buffer_size=INFINITE_BUFFER_SIZE,
-            drop_none=True,
-        )
         if self._split == "train_noval":
-            split_dp = extra_split_dp
+            archive_dp, split_dp = resource_dps
+            images_dp, anns_dp = Demultiplexer(
+                archive_dp,
+                2,
+                self._classify_archive,
+                buffer_size=INFINITE_BUFFER_SIZE,
+                drop_none=True,
+            )
+        else:
+            archive_dp = resource_dps[0]
+            images_dp, anns_dp, split_dp = Demultiplexer(
+                archive_dp,
+                3,
+                self._classify_archive,
+                buffer_size=INFINITE_BUFFER_SIZE,
+                drop_none=True,
+            )

What will happen with resource_dps[1] in this case? Is it disconnected from the graph, or does it remain unconsumed?

If …

         split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
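And a minimal sketch (again not from the PR) of how a classifier like this drives Demultiplexer: each element is routed to the branch whose index the classifier returns, and elements mapped to None are dropped when drop_none=True. The paths are hypothetical stand-ins for archive members.

```python
from torchdata.datapipes.iter import IterableWrapper

paths = IterableWrapper(
    [
        "dataset/img/2008_000002.jpg",
        "dataset/cls/2008_000002.mat",
        "dataset/inst/2008_000002.mat",
    ]
)

images_dp, anns_dp = paths.demux(
    num_instances=2,
    classifier_fn=lambda path: 0 if "/img/" in path else 1 if "/cls/" in path else None,
    drop_none=True,
)

assert list(images_dp) == ["dataset/img/2008_000002.jpg"]
assert list(anns_dp) == ["dataset/cls/2008_000002.mat"]
```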
This is useful since we often just want the first sample, but need to make sure to still consume the rest of the iterator to avoid dangling streams. list(iterator) would also do the trick, but keeps everything in memory for no reason.
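As a usage note on the helpers added in the test file: next_consume hands back the first sample but still drains the iterator, so streams opened during iteration get closed, without holding every sample in memory the way list(iterator) would. A tiny sketch of the behavior, reusing the definitions from the diff above:

```python
from collections import deque


def consume(iterator):
    # Exhaust the iterator without keeping its items in memory.
    deque(iterator, maxlen=0)


def next_consume(iterator):
    item = next(iterator)
    consume(iterator)
    return item


it = iter(range(5))
assert next_consume(it) == 0  # the first item is returned ...
assert list(it) == []         # ... and the remainder has been drained
```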