Skip to content

Closing streams to avoid testing issues #6128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
19 changes: 15 additions & 4 deletions test/test_prototype_datasets_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch.utils.data.graph import traverse
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper
from torchvision._utils import sequence_to_str
from torchvision.prototype import datasets, transforms
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
Expand Down Expand Up @@ -64,9 +65,9 @@ def test_smoke(self, dataset_mock, config):
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)

try:
sample = next(iter(dataset))
iterator = iter(dataset)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of sticking with the iterator pattern here, can't we just simply do

samples = list(dataset)

if not samples:
    raise AssertionError(...)

sample = samples[0]

...

sample = next(iterator)
except StopIteration:
raise AssertionError("Unable to draw any sample.") from None
except Exception as error:
Expand All @@ -78,23 +79,33 @@ def test_sample(self, dataset_mock, config):
if not sample:
raise AssertionError("Sample dictionary is empty.")

list(iterator) # Cleanups and closing streams in buffers

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_num_samples(self, dataset_mock, config):
dataset, mock_info = dataset_mock.load(config)

assert len(list(dataset)) == mock_info["num_samples"]

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, dataset_mock, config):
StreamWrapper.session_streams = {}
dataset, _ = dataset_mock.load(config)

vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
iterator = iter(dataset)
one_element = next(iterator)
Comment on lines +94 to +95
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.


vanilla_tensors = {key for key, value in one_element.items() if type(value) is torch.Tensor}
if vanilla_tensors:
raise AssertionError(
f"The values of key(s) "
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
)

list(iterator) # Cleanups and closing streams in buffers

if len(StreamWrapper.session_streams) > 0:
raise Exception(StreamWrapper.session_streams)
Comment on lines +106 to +107
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain what this does? Is StreamWrapper.session_streams just a counter for open streams? If yes, why are we only testing this here and not in the other tests? If this is something we should check in general, we can use a decorator like

def check_unclosed_streams(test_fn):
    @functools.wraps(test_fn)
    def wrapper(*args, **kwargs):
        if len(StreamWrapper.session_streams) > 0:
            raise pytest.UsageError("Some previous test didn't clean up")
        
        test_fn(*args, **kwargs)
        
        if len(StreamWrapper.session_streams) > 0:
            raise Assertion("This test didn't clean up")
        
    return wrapper


@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
Expand Down
7 changes: 5 additions & 2 deletions torchvision/prototype/datasets/_builtin/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def _prepare_sample(
ann_path, ann_buffer = ann_data

image = EncodedImage.from_file(image_buffer)
image_buffer.close()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The errors we have seen in our test suite have never been with these files, but only with archives.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests complain that archive stream is not closed. This is because child (unpacked file stream) also remains open and referencing parent. In pytorch/pytorch#78952 and pytorch/data#560 we introduced mechanism to close parent steams when every child is closed.

ann = read_mat(ann_buffer)
ann_buffer.close()
Comment on lines +105 to +107
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of doing that in every dataset individually, can't we just do it in

and

def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:

? I think so far we don't have a case where we need to read from the same file handle twice. Plus, that would only work if the stream is seekable, which I don't know if we can guarantee.


return dict(
label=Label.from_category(category, categories=self._categories),
Expand Down Expand Up @@ -181,10 +183,11 @@ def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:

def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
path, buffer = data

image = EncodedImage.from_file(buffer)
buffer.close()
return dict(
path=path,
image=EncodedImage.from_file(buffer),
image=image,
label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories),
)

Expand Down
7 changes: 5 additions & 2 deletions torchvision/prototype/datasets/_builtin/celeba.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def __init__(
self.fieldnames = fieldnames

def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
for _, file in self.datapipe:
file = (line.decode() for line in file)
for _, fh in self.datapipe:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm ok with the closing here, but why the rename? Can you revert that?

file = (line.decode() for line in fh)

if self.fieldnames:
fieldnames = self.fieldnames
Expand All @@ -48,6 +48,8 @@ def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
yield line.pop("image_id"), line

fh.close()


NAME = "celeba"

Expand Down Expand Up @@ -132,6 +134,7 @@ def _prepare_sample(
path, buffer = image_data

image = EncodedImage.from_file(buffer)
buffer.close()
(_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = ann_data

return dict(
Expand Down
4 changes: 3 additions & 1 deletion torchvision/prototype/datasets/_builtin/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def _resources(self) -> List[OnlineResource]:

def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
_, file = data
return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
result = pickle.load(file, encoding="latin1")
file.close()
return cast(Dict[str, Any], result)

def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
image_array, category_idx = data
Expand Down
7 changes: 6 additions & 1 deletion torchvision/prototype/datasets/_builtin/clevr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union

from torchdata import janitor
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
Expand Down Expand Up @@ -62,10 +63,12 @@ def _add_empty_anns(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[str, Binary
def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, Any]]]) -> Dict[str, Any]:
image_data, scenes_data = data
path, buffer = image_data
image = EncodedImage.from_file(buffer)
buffer.close()

return dict(
path=path,
image=EncodedImage.from_file(buffer),
image=image,
label=Label(len(scenes_data["objects"])) if scenes_data else None,
)

Expand Down Expand Up @@ -97,6 +100,8 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
buffer_size=INFINITE_BUFFER_SIZE,
)
else:
for i in scenes_dp:
janitor(i)
Comment on lines +103 to +104
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Can we make the loop variable more expressive?
  2. Can we use torchdata.janitor instead to make it more clear where this is coming from?
Suggested change
for i in scenes_dp:
janitor(i)
for _, file in scenes_dp:
janitor(file)

Plus, do we even need to use torchdata.janitor here? Would just .close() by sufficient?

Suggested change
for i in scenes_dp:
janitor(i)
for _, file in scenes_dp:
file.close()

dp = Mapper(images_dp, self._add_empty_anns)

return Mapper(dp, self._prepare_sample)
Expand Down
6 changes: 5 additions & 1 deletion torchvision/prototype/datasets/_builtin/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
close_buffer,
getitem,
hint_sharding,
hint_shuffling,
Expand Down Expand Up @@ -169,9 +170,10 @@ def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:

def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
path, buffer = data
image = close_buffer(EncodedImage.from_file, buffer)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If EncodedImage.from_file closes automatically we also don't need this wrapper.

return dict(
path=path,
image=EncodedImage.from_file(buffer),
image=image,
)

def _prepare_sample(
Expand All @@ -182,9 +184,11 @@ def _prepare_sample(
anns, image_meta = ann_data

sample = self._prepare_image(image_data)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you revert the formatting changes?

# this method is only called if we have annotations
annotations = cast(str, self._annotations)
sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))

return sample

def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
Expand Down
4 changes: 3 additions & 1 deletion torchvision/prototype/datasets/_builtin/country211.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ def _resources(self) -> List[OnlineResource]:

def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
path, buffer = data
image = EncodedImage.from_file(buffer)
buffer.close()
category = pathlib.Path(path).parent.name
return dict(
label=Label.from_category(category, categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
image=image,
)

def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool:
Expand Down
6 changes: 5 additions & 1 deletion torchvision/prototype/datasets/_builtin/cub200.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,14 @@ def _2011_prepare_ann(
) -> Dict[str, Any]:
_, (bounding_box_data, segmentation_data) = data
segmentation_path, segmentation_buffer = segmentation_data
segmentation = EncodedImage.from_file(segmentation_buffer)
segmentation_buffer.close()
return dict(
bounding_box=BoundingBox(
[float(part) for part in bounding_box_data[1:]], format="xywh", image_size=image_size
),
segmentation_path=segmentation_path,
segmentation=EncodedImage.from_file(segmentation_buffer),
segmentation=segmentation,
)

def _2010_split_key(self, data: str) -> str:
Expand All @@ -152,6 +154,7 @@ def _2010_anns_key(self, data: Tuple[str, BinaryIO]) -> Tuple[str, Tuple[str, Bi
def _2010_prepare_ann(self, data: Tuple[str, Tuple[str, BinaryIO]], image_size: Tuple[int, int]) -> Dict[str, Any]:
_, (path, buffer) = data
content = read_mat(buffer)
buffer.close()
return dict(
ann_path=path,
bounding_box=BoundingBox(
Expand All @@ -173,6 +176,7 @@ def _prepare_sample(
path, buffer = image_data

image = EncodedImage.from_file(buffer)
buffer.close()

return dict(
prepare_ann_fn(anns_data, image.image_size),
Expand Down
4 changes: 3 additions & 1 deletion torchvision/prototype/datasets/_builtin/dtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,16 @@ def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO
(_, joint_categories_data), image_data = data
_, *joint_categories = joint_categories_data
path, buffer = image_data
image = EncodedImage.from_file(buffer)
buffer.close()

category = pathlib.Path(path).parent.name

return dict(
joint_categories={category for category in joint_categories if category},
label=Label.from_category(category, categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
image=image,
)

def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
Expand Down
4 changes: 3 additions & 1 deletion torchvision/prototype/datasets/_builtin/eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ def _resources(self) -> List[OnlineResource]:
def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
path, buffer = data
category = pathlib.Path(path).parent.name
image = EncodedImage.from_file(buffer)
buffer.close()
return dict(
label=Label.from_category(category, categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
image=image,
)

def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
Expand Down
4 changes: 3 additions & 1 deletion torchvision/prototype/datasets/_builtin/food101.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:

def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]:
id, (path, buffer) = data
image = EncodedImage.from_file(buffer)
buffer.close()
return dict(
label=Label.from_category(id.split("/", 1)[0], categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
image=image,
)

def _image_key(self, data: Tuple[str, Any]) -> str:
Expand Down
4 changes: 3 additions & 1 deletion torchvision/prototype/datasets/_builtin/gtsrb.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[
format="xyxy",
image_size=(int(csv_info["Height"]), int(csv_info["Width"])),
)
image = EncodedImage.from_file(buffer)
buffer.close()

return {
"path": path,
"image": EncodedImage.from_file(buffer),
"image": image,
"label": Label(label, categories=self._categories),
"bounding_box": bounding_box,
}
Expand Down
7 changes: 5 additions & 2 deletions torchvision/prototype/datasets/_builtin/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@ def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[st
return None, data

def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
name, binary_io = data
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you revert this, since binary_io doesn't seem to be used.

return {
"meta.mat": ImageNetDemux.META,
"ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
}.get(pathlib.Path(data[0]).name)
}.get(pathlib.Path(name).name)

_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")

Expand All @@ -151,11 +152,13 @@ def _prepare_sample(
data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]],
) -> Dict[str, Any]:
label_data, (path, buffer) = data
image = EncodedImage.from_file(buffer)
buffer.close()

return dict(
dict(zip(("label", "wnid"), label_data if label_data else (None, None))),
path=path,
image=EncodedImage.from_file(buffer),
image=image,
)

def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
Expand Down
2 changes: 2 additions & 0 deletions torchvision/prototype/datasets/_builtin/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def __iter__(self) -> Iterator[torch.Tensor]:
for _ in range(stop - start):
yield read(dtype=dtype, count=count).reshape(shape)

file.close()


class _MNISTBase(Dataset):
_URL_BASE: Union[str, Sequence[str]]
Expand Down
8 changes: 6 additions & 2 deletions torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,18 @@ def _prepare_sample(
classification_data, segmentation_data = ann_data
segmentation_path, segmentation_buffer = segmentation_data
image_path, image_buffer = image_data
segmentation = EncodedImage.from_file(segmentation_buffer)
segmentation_buffer.close()
image = EncodedImage.from_file(image_buffer)
image_buffer.close()

return dict(
label=Label(int(classification_data["label"]) - 1, categories=self._categories),
species="cat" if classification_data["species"] == "1" else "dog",
segmentation_path=segmentation_path,
segmentation=EncodedImage.from_file(segmentation_buffer),
segmentation=segmentation,
image_path=image_path,
image=EncodedImage.from_file(image_buffer),
image=image,
)

def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/pcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
if self.key is not None:
data = data[self.key]
yield from data
handle.close()


_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
Expand Down
10 changes: 8 additions & 2 deletions torchvision/prototype/datasets/_builtin/sbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

NAME = "sbd"

from torchdata import janitor


@register_info(NAME)
def _info() -> Dict[str, Any]:
Expand Down Expand Up @@ -82,10 +84,12 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[st
ann_path, ann_buffer = ann_data

anns = read_mat(ann_buffer, squeeze_me=True)["GTcls"]

ann_buffer.close()
image = EncodedImage.from_file(image_buffer)
image_buffer.close()
return dict(
image_path=image_path,
image=EncodedImage.from_file(image_buffer),
image=image,
ann_path=ann_path,
# the boundaries are stored in sparse CSC format, which is not supported by PyTorch
boundaries=_Feature(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])),
Expand All @@ -104,6 +108,8 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
drop_none=True,
)
if self._split == "train_noval":
for i in split_dp:
janitor(i)
Comment on lines +111 to +112
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Plus, don't we need to do the same on extra_split_dp in the else branch?

split_dp = extra_split_dp

split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
Expand Down
Loading