close streams in prototype datasets #6647

Merged, 5 commits, Oct 6, 2022. Showing changes from 1 commit.
70 changes: 57 additions & 13 deletions test/test_prototype_datasets_builtin.py
@@ -1,6 +1,7 @@
 import functools
 import io
 import pickle
+from collections import deque
 from pathlib import Path
 
 import pytest
@@ -11,10 +12,11 @@
 from torch.utils.data.graph import traverse
 from torch.utils.data.graph_settings import get_all_graph_pipes
 from torchdata.datapipes.iter import ShardingFilter, Shuffler
+from torchdata.datapipes.utils import StreamWrapper
 from torchvision._utils import sequence_to_str
-from torchvision.prototype import datasets, transforms
+from torchvision.prototype import datasets, features, transforms
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-from torchvision.prototype.features import Image, Label
 
 
 assert_samples_equal = functools.partial(
     assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
@@ -25,6 +27,17 @@ def extract_datapipes(dp):
     return get_all_graph_pipes(traverse(dp, only_datapipe=True))
 
 
+def consume(iterator):
+    # Copied from the official itertools recipes: https://docs.python.org/3/library/itertools.html#itertools-recipes
+    deque(iterator, maxlen=0)
+
+
+def next_consume(iterator):
Collaborator Author:
This is useful since we often just want the first sample, but need to make sure to still consume the rest of the iterator to avoid dangling streams. list(iterator) would also do the trick, but would keep everything in memory for no reason.

+    item = next(iterator)
+    consume(iterator)
+    return item
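As a standalone illustration (outside the diff): these helpers drain an iterator, with next_consume also handing back the first item.

    it = iter(range(3))
    first = next_consume(it)
    assert first == 0
    assert next(it, None) is None  # the rest was consumed, nothing is left dangling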


 @pytest.fixture(autouse=True)
 def test_home(mocker, tmp_path):
     mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
@@ -66,7 +79,7 @@ def test_sample(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
 
         try:
-            sample = next(iter(dataset))
+            sample = next_consume(iter(dataset))
         except StopIteration:
             raise AssertionError("Unable to draw any sample.") from None
         except Exception as error:
@@ -84,22 +97,53 @@ def test_num_samples(self, dataset_mock, config):
 
         assert len(list(dataset)) == mock_info["num_samples"]
 
+    @pytest.fixture
+    def log_session_streams(self):
+        debug_unclosed_streams = StreamWrapper.debug_unclosed_streams
+        try:
+            StreamWrapper.debug_unclosed_streams = True
+            yield
+        finally:
+            StreamWrapper.debug_unclosed_streams = debug_unclosed_streams
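For context, a rough standalone sketch of the mechanism this fixture toggles (assuming that, while debug_unclosed_streams is enabled, StreamWrapper registers each wrapped stream in the class-level StreamWrapper.session_streams mapping and drops it again on close):

    import io
    from torchdata.datapipes.utils import StreamWrapper

    StreamWrapper.debug_unclosed_streams = True
    stream = StreamWrapper(io.BytesIO(b"payload"))
    assert stream in StreamWrapper.session_streams      # tracked while open
    stream.close()
    assert stream not in StreamWrapper.session_streams  # gone after close
    StreamWrapper.debug_unclosed_streams = False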

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_vanilla_tensors(self, dataset_mock, config):
+    def test_stream_closing(self, log_session_streams, dataset_mock, config):
+        def make_msg_and_close(head):
+            unclosed_streams = []
+            for stream in StreamWrapper.session_streams.keys():
+                unclosed_streams.append(repr(stream.file_obj))
+                stream.close()
+            unclosed_streams = "\n".join(unclosed_streams)
+            return f"{head}\n\n{unclosed_streams}"
+
+        if StreamWrapper.session_streams:
+            raise pytest.UsageError(make_msg_and_close("A previous test did not close the following streams:"))
+
         dataset, _ = dataset_mock.load(config)
 
-        vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
-        if vanilla_tensors:
+        consume(iter(dataset))
+
+        if StreamWrapper.session_streams:
+            raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))
+
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_no_simple_tensors(self, dataset_mock, config):
+        dataset, _ = dataset_mock.load(config)
+
+        simple_tensors = {key for key, value in next_consume(iter(dataset)).items() if features.is_simple_tensor(value)}
Collaborator Author:
This is a drive-by since I was already touching the line: the term "vanilla" tensor is no longer used. In the prototype transforms we use "simple tensor" now and also have features.is_simple_tensor to check for them.
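For illustration, a minimal stand-in for such a check, modeled on the old type(value) is torch.Tensor test above (the real features.is_simple_tensor may differ in detail):

    import torch

    def is_simple_tensor(value):
        # A "simple" tensor is a plain torch.Tensor rather than a feature
        # subclass such as features.Image or features.Label.
        return type(value) is torch.Tensor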

+        if simple_tensors:
             raise AssertionError(
                 f"The values of key(s) "
-                f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
+                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors."
             )
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_transformable(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
 
-        next(iter(dataset.map(transforms.Identity())))
+        dataset = dataset.map(transforms.Identity())
+
+        consume(iter(dataset))
 
     @pytest.mark.parametrize("only_datapipe", [False, True])
     @parametrize_dataset_mocks(DATASET_MOCKS)
@@ -132,7 +176,7 @@ def test_data_loader(self, dataset_mock, config, num_workers):
             collate_fn=self._collate_fn,
         )
 
-        next(iter(dl))
+        consume(dl)
 
     # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
     # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
@@ -149,7 +193,7 @@ def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
     def test_save_load(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
 
-        sample = next(iter(dataset))
+        sample = next_consume(iter(dataset))
 
         with io.BytesIO() as buffer:
             torch.save(sample, buffer)
@@ -178,7 +222,7 @@ class TestQMNIST:
     def test_extra_label(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
 
-        sample = next(iter(dataset))
+        sample = next_consume(iter(dataset))
         for key, type in (
             ("nist_hsf_series", int),
             ("nist_writer_id", int),
@@ -215,7 +259,7 @@ def test_sample_content(self, dataset_mock, config):
         assert "image" in sample
         assert "label" in sample
 
-        assert isinstance(sample["image"], Image)
-        assert isinstance(sample["label"], Label)
+        assert isinstance(sample["image"], features.Image)
+        assert isinstance(sample["label"], features.Label)
 
         assert sample["image"].shape == (1, 16, 16)
10 changes: 6 additions & 4 deletions torchvision/prototype/datasets/_builtin/celeba.py
@@ -30,24 +30,26 @@ def __init__(
 
     def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
         for _, file in self.datapipe:
-            file = (line.decode() for line in file)
+            lines = (line.decode() for line in file)
 
             if self.fieldnames:
                 fieldnames = self.fieldnames
             else:
                 # The first row is skipped, because it only contains the number of samples
-                next(file)
+                next(lines)
 
                 # Empty field names are filtered out, because some files have an extra white space after the header
                 # line, which is recognized as extra column
-                fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name]
+                fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name]
                 # Some files do not include a label for the image ID column
                 if fieldnames[0] != "image_id":
                     fieldnames.insert(0, "image_id")
 
-            for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
+            for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"):
                 yield line.pop("image_id"), line
 
+            file.close()


NAME = "celeba"

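The rename from file to lines here is not just cosmetic: the old code rebound the name file to a generator, so the underlying handle was no longer reachable for an explicit close(). A minimal standalone illustration of the fixed pattern (hypothetical path):

    file = open("list_attr_celeba.txt", "rb")   # hypothetical path
    lines = (line.decode() for line in file)    # keep both names
    header = next(lines)
    file.close()                                # the handle is still reachable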
4 changes: 3 additions & 1 deletion torchvision/prototype/datasets/_builtin/cifar.py
@@ -62,7 +62,9 @@ def _resources(self) -> List[OnlineResource]:
 
     def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
         _, file = data
-        return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
+        content = cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
+        file.close()
+        return content
 
     def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
         image_array, category_idx = data
2 changes: 2 additions & 0 deletions torchvision/prototype/datasets/_builtin/clevr.py
@@ -97,6 +97,8 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
                 buffer_size=INFINITE_BUFFER_SIZE,
             )
         else:
+            for _, file in scenes_dp:
+                file.close()
             dp = Mapper(images_dp, self._add_empty_anns)
 
         return Mapper(dp, self._prepare_sample)
2 changes: 2 additions & 0 deletions torchvision/prototype/datasets/_builtin/mnist.py
@@ -57,6 +57,8 @@ def __iter__(self) -> Iterator[torch.Tensor]:
         for _ in range(stop - start):
             yield read(dtype=dtype, count=count).reshape(shape)
 
+        file.close()
 
 
 class _MNISTBase(Dataset):
     _URL_BASE: Union[str, Sequence[str]]
2 changes: 2 additions & 0 deletions torchvision/prototype/datasets/_builtin/pcam.py
@@ -33,6 +33,8 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
             data = data[self.key]
             yield from data
 
+        handle.close()
 
 
 _Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
 
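One subtlety with close() calls placed after the loop in __iter__, as in the MNIST and PCAM changes above: they only run once the generator is exhausted, which is exactly what the consume helper in the test suite guarantees. A standalone sketch of that behavior:

    class FakeStream:
        closed = False

        def close(self):
            self.closed = True

    def iterate(stream):
        for item in [1, 2, 3]:
            yield item
        stream.close()  # reached only when the generator is fully drained

    stream = FakeStream()
    it = iterate(stream)
    next(it)                  # partial iteration: close() has not run yet
    assert not stream.closed
    for _ in it:              # drain the rest
        pass
    assert stream.closed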
7 changes: 5 additions & 2 deletions torchvision/prototype/datasets/_builtin/sbd.py
@@ -103,8 +103,11 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
             buffer_size=INFINITE_BUFFER_SIZE,
             drop_none=True,
         )
-        if self._split == "train_noval":
-            split_dp = extra_split_dp
+        split_dp, to_be_closed_dp = (
+            (extra_split_dp, split_dp) if self._split == "train_noval" else (split_dp, extra_split_dp)
+        )
+        for _, file in to_be_closed_dp:
+            file.close()
Collaborator Author:
This is somewhat problematic in the split == "train_noval" case. We want to close all file handles that come from split_dp. Unfortunately, split_dp comes from a Demultiplexer, so by fully iterating over it here, we are loading everything into the demux buffer.

Is there an idiom to "mark" a datapipe to be closed at runtime even if we don't return the datapipe? The only thing I came up with is changing the classifier function of the Demultiplexer to drop the samples that would go into split_dp if split == "train_noval".

Reviewer:

It is incorrect to consume a DataPipe at construction time.

        for _, file in to_be_closed_dp:
            file.close()

It will not affect the executed graph.

Ideally, we want to have something like dp.close(), which would effectively remove dangling pieces of the graph.

But for now you can either use a trick like split_dp = split_dp.concat(to_be_closed_dp.filter(lambda x: False)) or do the code branching before the Demultiplexer and avoid creating dangling pieces.
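For illustration, a minimal standalone sketch of the suggested trick (hypothetical datapipes): the unused branch stays attached to the graph, but the always-false filter keeps its samples out of the output.

    from torchdata.datapipes.iter import IterableWrapper

    split_dp = IterableWrapper(["kept sample"])
    to_be_closed_dp = IterableWrapper(["dangling sample"])

    # Attach the otherwise-dangling branch, but drop all of its samples.
    split_dp = split_dp.concat(to_be_closed_dp.filter(lambda x: False))

    assert list(split_dp) == ["kept sample"]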

Collaborator Author:
> It is incorrect to consume a DataPipe at construction time.

I don't think that is fully correct for our case. At construction, we load a couple of files and weave them together into one datapipe. These files are always loaded unconditionally, but for some configurations not all of them are needed. So we should be able to simply close the unneeded ones during the construction of the dataset datapipe, since they will never make it into the graph, correct?

I agree we shouldn't do this for datapipes that stem from a Demultiplexer if the other parts make it into the final graph.

Collaborator Author:
> do code branching before Demux and avoid creating dangling pieces.

If possible, I think that is the better solution, since the other one still iterates over all items even though the loop should already be done. I guess that could be irritating as well.

I implemented branching in afb0ec2. PTAL
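A rough sketch of the classifier-based branching idea from the discussion (hypothetical, not the actual afb0ec2 diff). With drop_none=True on the Demultiplexer, returning None from the classifier discards the sample, so the unused branch is never created:

    # SPLIT_IDX is a hypothetical constant naming the demux branch that
    # would feed split_dp.
    def classify(self, data):
        idx = self._classify_archive(data)
        if self._split == "train_noval" and idx == SPLIT_IDX:
            return None  # dropped entirely thanks to drop_none=True
        return idx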


         split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
4 changes: 3 additions & 1 deletion torchvision/prototype/datasets/_builtin/voc.py
@@ -94,7 +94,9 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         return None
 
     def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
-        return cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
+        ann = cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
+        buffer.close()
+        return ann
 
     def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
         anns = self._parse_detection_ann(buffer)
8 changes: 3 additions & 5 deletions torchvision/prototype/datasets/utils/_internal.py
@@ -8,7 +8,6 @@
 import torch.distributed as dist
 import torch.utils.data
 from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler
-from torchdata.datapipes.utils import StreamWrapper
 from torchvision.prototype.utils._internal import fromfile


@@ -40,10 +39,9 @@ def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:
     except ImportError as error:
         raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error
 
-    if isinstance(buffer, StreamWrapper):
-        buffer = buffer.file_obj
-
-    return sio.loadmat(buffer, **kwargs)
+    data = sio.loadmat(buffer, **kwargs)
+    buffer.close()
+    return data
 
 
 class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
4 changes: 3 additions & 1 deletion torchvision/prototype/features/_encoded.py
@@ -27,7 +27,9 @@ def __new__(
 
     @classmethod
     def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
-        return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
+        encoded_data = cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
+        file.close()
+        return encoded_data
 
     @classmethod
     def from_path(cls: Type[D], path: Union[str, os.PathLike], **kwargs: Any) -> D:
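Note that from_file now takes ownership of the handle it is given. A rough usage sketch (hypothetical path; assuming EncodedImage is one of the concrete subclasses):

    from torchvision.prototype.features import EncodedImage

    file = open("img.jpg", "rb")        # hypothetical path
    data = EncodedImage.from_file(file)
    assert file.closed                  # from_file closed it for us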