
Commit afb0ec2 (1 parent: 6e6c31e)

refactor prototype SBD to avoid closing demux streams at construction time

File tree: 2 files changed (+55, -46 lines)
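Before this change, `_resources` always registered both the benchmark archive and the extra train_noval.txt split file, and `_datapipe` demultiplexed the archive into three streams; whichever split stream was not needed for the requested split was then iterated at construction time purely to close its file handles. The refactor makes both the resource list and the demux fan-out depend on `self._split`, so every stream that gets created is actually consumed during iteration. A minimal toy sketch of the anti-pattern being removed (a standalone example assuming torchdata's `Demultiplexer`, not code from this commit):

from torchdata.datapipes.iter import Demultiplexer, IterableWrapper

source = IterableWrapper(range(6))
evens, odds = Demultiplexer(source, 2, lambda x: x % 2)

# Anti-pattern: draining one branch while the datapipe graph is still being
# built forces the whole source to be read immediately, with every item
# destined for the other branch piling up in the demultiplexer's buffer.
for _ in odds:
    pass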

test/builtin_dataset_mocks.py (15 additions, 13 deletions)
@@ -661,15 +661,15 @@ class SBDMockData:
     _NUM_CATEGORIES = 20
 
     @classmethod
-    def _make_split_files(cls, root_map):
-        ids_map = {
-            split: [f"2008_{idx:06d}" for idx in idcs]
-            for split, idcs in (
-                ("train", [0, 1, 2]),
-                ("train_noval", [0, 2]),
-                ("val", [3]),
-            )
-        }
+    def _make_split_files(cls, root_map, *, split):
+        splits_and_idcs = [
+            ("train", [0, 1, 2]),
+            ("val", [3]),
+        ]
+        if split == "train_noval":
+            splits_and_idcs.append(("train_noval", [0, 2]))
+
+        ids_map = {split: [f"2008_{idx:06d}" for idx in idcs] for split, idcs in splits_and_idcs}
 
         for split, ids in ids_map.items():
             with open(root_map[split] / f"{split}.txt", "w") as fh:
@@ -710,25 +710,27 @@ def _make_segmentation(cls, size):
         return torch.randint(0, cls._NUM_CATEGORIES + 1, size=size, dtype=torch.uint8).numpy()
 
     @classmethod
-    def generate(cls, root):
+    def generate(cls, root, *, split):
         archive_folder = root / "benchmark_RELEASE"
         dataset_folder = archive_folder / "dataset"
         dataset_folder.mkdir(parents=True, exist_ok=True)
 
-        ids, num_samples_map = cls._make_split_files(defaultdict(lambda: dataset_folder, {"train_noval": root}))
+        ids, num_samples_map = cls._make_split_files(
+            defaultdict(lambda: dataset_folder, {"train_noval": root}), split=split
+        )
         sizes = cls._make_anns_folder(dataset_folder, "cls", ids)
         create_image_folder(
             dataset_folder, "img", lambda idx: f"{ids[idx]}.jpg", num_examples=len(ids), size=lambda idx: sizes[idx]
         )
 
         make_tar(root, "benchmark.tgz", archive_folder, compression="gz")
 
-        return num_samples_map
+        return num_samples_map[split]
 
 
 @register_mock(configs=combinations_grid(split=("train", "val", "train_noval")))
 def sbd(root, config):
-    return SBDMockData.generate(root)[config["split"]]
+    return SBDMockData.generate(root, split=config["split"])
 
 
 @register_mock(configs=[dict()])
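With this change the mock only writes the split files that a given config needs, and `generate` returns the sample count for the requested split directly instead of a map over all splits. A standalone sketch of the selection logic (hypothetical helper, not part of the test suite):

def split_files_for(split):
    # "train.txt" and "val.txt" always exist inside the archive; the extra
    # "train_noval.txt" is only produced when that split is requested.
    splits = ["train", "val"]
    if split == "train_noval":
        splits.append("train_noval")
    return [f"{s}.txt" for s in splits]

assert split_files_for("val") == ["train.txt", "val.txt"]
assert split_files_for("train_noval") == ["train.txt", "val.txt", "train_noval.txt"]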

torchvision/prototype/datasets/_builtin/sbd.py (40 additions, 33 deletions)
@@ -49,31 +49,35 @@ def __init__(
         super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check)
 
     def _resources(self) -> List[OnlineResource]:
-        archive = HttpResource(
-            "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
-            sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
-        )
-        extra_split = HttpResource(
-            "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
-            sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
-        )
-        return [archive, extra_split]
+        resources = [
+            HttpResource(
+                "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
+                sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
+            )
+        ]
+        if self._split == "train_noval":
+            resources.append(
+                HttpResource(
+                    "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
+                    sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
+                )
+            )
+        return resources
 
     def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         path = pathlib.Path(data[0])
         parent, grandparent, *_ = path.parents
 
-        if parent.name == "dataset":
-            return 0
-        elif grandparent.name == "dataset":
+        if grandparent.name == "dataset":
             if parent.name == "img":
-                return 1
+                return 0
             elif parent.name == "cls":
-                return 2
-            else:
-                return None
-        else:
-            return None
+                return 1
+
+        if parent.name == "dataset" and self._split != "train_noval":
+            return 2
+
+        return None
 
     def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
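The stream indices in `_classify_archive` now double as the demux fan-out contract: the classifier may only return an index smaller than the number of demux instances, or `None`. For the two-way `train_noval` demux the in-archive split files must therefore fall through to `None` and be discarded via `drop_none=True`. A minimal usage sketch of that contract (toy paths, not repository code):

from torchdata.datapipes.iter import Demultiplexer, IterableWrapper

paths = IterableWrapper(["dataset/img/a.jpg", "dataset/cls/a.mat", "dataset/train.txt"])

def classify(path):
    if "/img/" in path:
        return 0
    if "/cls/" in path:
        return 1
    return None  # split files are dropped lazily; no manual close needed

images_dp, anns_dp = Demultiplexer(paths, 2, classify, drop_none=True)
assert list(images_dp) == ["dataset/img/a.jpg"]
assert list(anns_dp) == ["dataset/cls/a.mat"]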
@@ -93,21 +97,24 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         )
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
-        archive_dp, extra_split_dp = resource_dps
-
-        archive_dp = resource_dps[0]
-        split_dp, images_dp, anns_dp = Demultiplexer(
-            archive_dp,
-            3,
-            self._classify_archive,
-            buffer_size=INFINITE_BUFFER_SIZE,
-            drop_none=True,
-        )
-        split_dp, to_be_closed_dp = (
-            (extra_split_dp, split_dp) if self._split == "train_noval" else (split_dp, extra_split_dp)
-        )
-        for _, file in to_be_closed_dp:
-            file.close()
+        if self._split == "train_noval":
+            archive_dp, split_dp = resource_dps
+            images_dp, anns_dp = Demultiplexer(
+                archive_dp,
+                2,
+                self._classify_archive,
+                buffer_size=INFINITE_BUFFER_SIZE,
+                drop_none=True,
+            )
+        else:
+            archive_dp = resource_dps[0]
+            images_dp, anns_dp, split_dp = Demultiplexer(
+                archive_dp,
+                3,
+                self._classify_archive,
+                buffer_size=INFINITE_BUFFER_SIZE,
+                drop_none=True,
+            )
 
         split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
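A note on the shared tail: in the three-way case the demuxed split stream still carries every in-archive split file, so the Filter on f"{self._split}.txt" selects the right one; in the train_noval case the separately downloaded file is already named train_noval.txt and passes the same Filter untouched. Either way, LineReader then yields the sample ids, and no stream ever has to be drained or closed by hand.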
