diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index 88f3277de58..eff273671c4 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -132,6 +132,9 @@ def _filter_images(self, data: Tuple[str, Any]) -> bool: return self._classify_archive(data) == DTDDemux.IMAGES def _generate_categories(self, root: pathlib.Path) -> List[str]: - dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) + resources = self.resources(self.default_config) + + dp = resources[0].load(root) dp = Filter(dp, self._filter_images) + return sorted({pathlib.Path(path).parent.name for path, _ in dp}) diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index 99b1f643b61..5e26f3ebd03 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -145,10 +145,13 @@ def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool: def _generate_categories(self, root: pathlib.Path) -> List[str]: config = self.default_config - dp = self.resources(config)[1].load(pathlib.Path(root) / self.name) + resources = self.resources(config) + + dp = resources[1].load(root) dp = Filter(dp, self._filter_split_and_classification_anns) dp = Filter(dp, path_comparator("name", f"{config.split}.txt")) dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ") + raw_categories_and_labels = {(data["image_id"].rsplit("_", 1)[0], data["label"]) for data in dp} raw_categories, _ = zip( *sorted(raw_categories_and_labels, key=lambda raw_category_and_label: int(raw_category_and_label[1]))