close streams in prototype datasets #6647

Merged: 5 commits, Oct 6, 2022

28 changes: 15 additions & 13 deletions test/builtin_dataset_mocks.py
@@ -661,15 +661,15 @@ class SBDMockData:
_NUM_CATEGORIES = 20

@classmethod
def _make_split_files(cls, root_map):
ids_map = {
split: [f"2008_{idx:06d}" for idx in idcs]
for split, idcs in (
("train", [0, 1, 2]),
("train_noval", [0, 2]),
("val", [3]),
)
}
def _make_split_files(cls, root_map, *, split):
splits_and_idcs = [
("train", [0, 1, 2]),
("val", [3]),
]
if split == "train_noval":
splits_and_idcs.append(("train_noval", [0, 2]))

ids_map = {split: [f"2008_{idx:06d}" for idx in idcs] for split, idcs in splits_and_idcs}

for split, ids in ids_map.items():
with open(root_map[split] / f"{split}.txt", "w") as fh:
@@ -710,25 +710,27 @@ def _make_segmentation(cls, size):
return torch.randint(0, cls._NUM_CATEGORIES + 1, size=size, dtype=torch.uint8).numpy()

@classmethod
def generate(cls, root):
def generate(cls, root, *, split):
archive_folder = root / "benchmark_RELEASE"
dataset_folder = archive_folder / "dataset"
dataset_folder.mkdir(parents=True, exist_ok=True)

ids, num_samples_map = cls._make_split_files(defaultdict(lambda: dataset_folder, {"train_noval": root}))
ids, num_samples_map = cls._make_split_files(
defaultdict(lambda: dataset_folder, {"train_noval": root}), split=split
)
sizes = cls._make_anns_folder(dataset_folder, "cls", ids)
create_image_folder(
dataset_folder, "img", lambda idx: f"{ids[idx]}.jpg", num_examples=len(ids), size=lambda idx: sizes[idx]
)

make_tar(root, "benchmark.tgz", archive_folder, compression="gz")

return num_samples_map
return num_samples_map[split]


@register_mock(configs=combinations_grid(split=("train", "val", "train_noval")))
def sbd(root, config):
return SBDMockData.generate(root)[config["split"]]
return SBDMockData.generate(root, split=config["split"])


@register_mock(configs=[dict()])
70 changes: 57 additions & 13 deletions test/test_prototype_datasets_builtin.py
@@ -1,6 +1,7 @@
import functools
import io
import pickle
from collections import deque
from pathlib import Path

import pytest
@@ -11,10 +12,11 @@
from torch.utils.data.graph import traverse_dps
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper
from torchvision._utils import sequence_to_str
from torchvision.prototype import datasets, transforms
from torchvision.prototype import datasets, features, transforms
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
from torchvision.prototype.features import Image, Label


assert_samples_equal = functools.partial(
assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
@@ -25,6 +27,17 @@ def extract_datapipes(dp):
return get_all_graph_pipes(traverse_dps(dp))


def consume(iterator):
# Copied from the official itertools recipes: https://docs.python.org/3/library/itertools.html#itertools-recipes
deque(iterator, maxlen=0)


def next_consume(iterator):
Collaborator Author:
This is useful since we often just want the first sample, but need to make sure to still consume to avoid dangling streams. list(iterator) would also do the trick, but keeps everything in memory for no reason.
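
As an illustrative aside (a hypothetical snippet, not part of this diff), the trade-off between the options looks roughly like this:

sample = next(iter(dataset))          # first sample only; streams backing later samples can stay open
sample = next_consume(iter(dataset))  # first sample, remainder drained via deque(maxlen=0), nothing retained
samples = list(iter(dataset))         # also drains the pipeline, but keeps every sample in memory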

item = next(iterator)
consume(iterator)
return item


@pytest.fixture(autouse=True)
def test_home(mocker, tmp_path):
mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
@@ -66,7 +79,7 @@ def test_sample(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)

try:
sample = next(iter(dataset))
sample = next_consume(iter(dataset))
except StopIteration:
raise AssertionError("Unable to draw any sample.") from None
except Exception as error:
@@ -84,22 +97,53 @@ def test_num_samples(self, dataset_mock, config):

assert len(list(dataset)) == mock_info["num_samples"]

@pytest.fixture
def log_session_streams(self):
debug_unclosed_streams = StreamWrapper.debug_unclosed_streams
try:
StreamWrapper.debug_unclosed_streams = True
yield
finally:
StreamWrapper.debug_unclosed_streams = debug_unclosed_streams

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, dataset_mock, config):
def test_stream_closing(self, log_session_streams, dataset_mock, config):
def make_msg_and_close(head):
unclosed_streams = []
for stream in StreamWrapper.session_streams.keys():
unclosed_streams.append(repr(stream.file_obj))
stream.close()
unclosed_streams = "\n".join(unclosed_streams)
return f"{head}\n\n{unclosed_streams}"

if StreamWrapper.session_streams:
raise pytest.UsageError(make_msg_and_close("A previous test did not close the following streams:"))

dataset, _ = dataset_mock.load(config)

vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
if vanilla_tensors:
consume(iter(dataset))

if StreamWrapper.session_streams:
raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_simple_tensors(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)

simple_tensors = {key for key, value in next_consume(iter(dataset)).items() if features.is_simple_tensor(value)}
Collaborator Author:
This is a drive-by since I was already touching the line: the term "vanilla" tensor is no longer used. In the prototype transforms we use "simple tensor" now and also have features.is_simple_tensor to check for them.
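
A rough sketch of what such a check amounts to (hypothetical, for illustration only; the actual helper lives in torchvision.prototype.features and may use a different base-class check internally):

import torch
from torchvision.prototype import features

def looks_like_simple_tensor(value):
    # A plain torch.Tensor that is not wrapped in one of the richer feature
    # classes (e.g. features.Image, features.Label) counts as a "simple tensor".
    return isinstance(value, torch.Tensor) and not isinstance(value, (features.Image, features.Label))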

if simple_tensors:
raise AssertionError(
f"The values of key(s) "
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors."
)

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)

next(iter(dataset.map(transforms.Identity())))
dataset = dataset.map(transforms.Identity())

consume(iter(dataset))

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_traversable(self, dataset_mock, config):
@@ -131,7 +175,7 @@ def test_data_loader(self, dataset_mock, config, num_workers):
collate_fn=self._collate_fn,
)

next(iter(dl))
consume(dl)

# TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
# that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
@@ -148,7 +192,7 @@ def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
def test_save_load(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)

sample = next(iter(dataset))
sample = next_consume(iter(dataset))

with io.BytesIO() as buffer:
torch.save(sample, buffer)
@@ -177,7 +221,7 @@ class TestQMNIST:
def test_extra_label(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)

sample = next(iter(dataset))
sample = next_consume(iter(dataset))
for key, type in (
("nist_hsf_series", int),
("nist_writer_id", int),
@@ -214,7 +258,7 @@ def test_sample_content(self, dataset_mock, config):
assert "image" in sample
assert "label" in sample

assert isinstance(sample["image"], Image)
assert isinstance(sample["label"], Label)
assert isinstance(sample["image"], features.Image)
assert isinstance(sample["label"], features.Label)

assert sample["image"].shape == (1, 16, 16)
10 changes: 6 additions & 4 deletions torchvision/prototype/datasets/_builtin/celeba.py
@@ -30,24 +30,26 @@ def __init__(

def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
for _, file in self.datapipe:
file = (line.decode() for line in file)
lines = (line.decode() for line in file)

if self.fieldnames:
fieldnames = self.fieldnames
else:
# The first row is skipped, because it only contains the number of samples
next(file)
next(lines)

# Empty field names are filtered out, because some files have an extra white space after the header
# line, which is recognized as extra column
fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name]
fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name]
# Some files do not include a label for the image ID column
if fieldnames[0] != "image_id":
fieldnames.insert(0, "image_id")

for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"):
yield line.pop("image_id"), line

file.close()


NAME = "celeba"

4 changes: 3 additions & 1 deletion torchvision/prototype/datasets/_builtin/cifar.py
@@ -62,7 +62,9 @@ def _resources(self) -> List[OnlineResource]:

def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
_, file = data
return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
content = cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
file.close()
return content

def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
image_array, category_idx = data
2 changes: 2 additions & 0 deletions torchvision/prototype/datasets/_builtin/clevr.py
@@ -97,6 +97,8 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
buffer_size=INFINITE_BUFFER_SIZE,
)
else:
for _, file in scenes_dp:
file.close()
dp = Mapper(images_dp, self._add_empty_anns)

return Mapper(dp, self._prepare_sample)
2 changes: 2 additions & 0 deletions torchvision/prototype/datasets/_builtin/mnist.py
@@ -57,6 +57,8 @@ def __iter__(self) -> Iterator[torch.Tensor]:
for _ in range(stop - start):
yield read(dtype=dtype, count=count).reshape(shape)

file.close()


class _MNISTBase(Dataset):
_URL_BASE: Union[str, Sequence[str]]
2 changes: 2 additions & 0 deletions torchvision/prototype/datasets/_builtin/pcam.py
@@ -33,6 +33,8 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
data = data[self.key]
yield from data

handle.close()


_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))

68 changes: 39 additions & 29 deletions torchvision/prototype/datasets/_builtin/sbd.py
@@ -49,31 +49,35 @@ def __init__(
super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check)

def _resources(self) -> List[OnlineResource]:
archive = HttpResource(
"https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
)
extra_split = HttpResource(
"http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
)
return [archive, extra_split]
resources = [
HttpResource(
"https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
)
]
if self._split == "train_noval":
resources.append(
HttpResource(
"http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
)
)
return resources # type: ignore[return-value]

def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
parent, grandparent, *_ = path.parents

if parent.name == "dataset":
return 0
elif grandparent.name == "dataset":
if grandparent.name == "dataset":
if parent.name == "img":
return 1
return 0
elif parent.name == "cls":
return 2
else:
return None
else:
return None
return 1

if parent.name == "dataset" and self._split != "train_noval":
return 2

return None

def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
split_and_image_data, ann_data = data
@@ -93,18 +97,24 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
)

def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp, extra_split_dp = resource_dps

archive_dp = resource_dps[0]
split_dp, images_dp, anns_dp = Demultiplexer(
archive_dp,
3,
self._classify_archive,
buffer_size=INFINITE_BUFFER_SIZE,
drop_none=True,
)
if self._split == "train_noval":
split_dp = extra_split_dp
archive_dp, split_dp = resource_dps
images_dp, anns_dp = Demultiplexer(
archive_dp,
2,
self._classify_archive,
buffer_size=INFINITE_BUFFER_SIZE,
drop_none=True,
)
else:
archive_dp = resource_dps[0]

Reviewer: What will happen with resource_dps[1] in this case? Is it disconnected from the graph or does it remain unconsumed?

Collaborator Author: If self._split != "train_noval", we have only one element in resource_dps. Meaning, it will not be loaded at all and thus also does not need to be closed.

images_dp, anns_dp, split_dp = Demultiplexer(
archive_dp,
3,
self._classify_archive,
buffer_size=INFINITE_BUFFER_SIZE,
drop_none=True,
)

split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
split_dp = LineReader(split_dp, decode=True)
4 changes: 3 additions & 1 deletion torchvision/prototype/datasets/_builtin/voc.py
@@ -94,7 +94,9 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
return None

def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
return cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
ann = cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
buffer.close()
return ann

def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
anns = self._parse_detection_ann(buffer)
8 changes: 3 additions & 5 deletions torchvision/prototype/datasets/utils/_internal.py
@@ -8,7 +8,6 @@
import torch.distributed as dist
import torch.utils.data
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper
from torchvision.prototype.utils._internal import fromfile


@@ -40,10 +39,9 @@ def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:
except ImportError as error:
raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error

if isinstance(buffer, StreamWrapper):
buffer = buffer.file_obj

return sio.loadmat(buffer, **kwargs)
data = sio.loadmat(buffer, **kwargs)
buffer.close()
return data


class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):