diff --git a/torchvision/prototype/datapoints/_label.py b/torchvision/prototype/datapoints/_label.py
index 0ee2eb9f8e1..19f4c3d4c72 100644
--- a/torchvision/prototype/datapoints/_label.py
+++ b/torchvision/prototype/datapoints/_label.py
@@ -1,78 +1,57 @@
 from __future__ import annotations
 
-from typing import Any, Optional, Sequence, Type, TypeVar, Union
+from typing import Any, Optional, Union
 
 import torch
-from torch.utils._pytree import tree_map
 
 from ._datapoint import Datapoint
 
-L = TypeVar("L", bound="_LabelBase")
-
-
-class _LabelBase(Datapoint):
-    categories: Optional[Sequence[str]]
-
+class Label(Datapoint):
     @classmethod
-    def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
-        label_base = tensor.as_subclass(cls)
-        label_base.categories = categories
-        return label_base
+    def _wrap(cls, tensor: torch.Tensor) -> Label:
+        return tensor.as_subclass(cls)
 
     def __new__(
-        cls: Type[L],
+        cls,
         data: Any,
         *,
-        categories: Optional[Sequence[str]] = None,
         dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: Optional[bool] = None,
-    ) -> L:
+    ) -> Label:
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
-        return cls._wrap(tensor, categories=categories)
+        return cls._wrap(tensor)
 
     @classmethod
-    def wrap_like(cls: Type[L], other: L, tensor: torch.Tensor, *, categories: Optional[Sequence[str]] = None) -> L:
-        return cls._wrap(
-            tensor,
-            categories=categories if categories is not None else other.categories,
-        )
-
-    @classmethod
-    def from_category(
-        cls: Type[L],
-        category: str,
-        *,
-        categories: Sequence[str],
-        **kwargs: Any,
-    ) -> L:
-        return cls(categories.index(category), categories=categories, **kwargs)
-
-
-class Label(_LabelBase):
-    def to_categories(self) -> Any:
-        if self.categories is None:
-            raise RuntimeError("Label does not have categories")
+    def wrap_like(
+        cls,
+        other: Label,
+        tensor: torch.Tensor,
+    ) -> Label:
+        return cls._wrap(tensor)
 
-        return tree_map(lambda idx: self.categories[idx], self.tolist())
 
+class OneHotLabel(Datapoint):
+    @classmethod
+    def _wrap(cls, tensor: torch.Tensor) -> OneHotLabel:
+        return tensor.as_subclass(cls)
 
-class OneHotLabel(_LabelBase):
     def __new__(
         cls,
         data: Any,
         *,
-        categories: Optional[Sequence[str]] = None,
         dtype: Optional[torch.dtype] = None,
         device: Optional[Union[torch.device, str, int]] = None,
-        requires_grad: bool = False,
+        requires_grad: Optional[bool] = None,
     ) -> OneHotLabel:
-        one_hot_label = super().__new__(
-            cls, data, categories=categories, dtype=dtype, device=device, requires_grad=requires_grad
-        )
-
-        if categories is not None and len(categories) != one_hot_label.shape[-1]:
-            raise ValueError()
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
+        return cls._wrap(tensor)
 
-        return one_hot_label
+    @classmethod
+    def wrap_like(
+        cls,
+        other: OneHotLabel,
+        tensor: torch.Tensor,
+    ) -> OneHotLabel:
+        return cls._wrap(tensor)
diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py
index 55a77c1a920..e597bdedb8d 100644
--- a/torchvision/prototype/datasets/_builtin/caltech.py
+++ b/torchvision/prototype/datasets/_builtin/caltech.py
@@ -4,9 +4,15 @@
 import numpy as np
 
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
-from torchvision.prototype.datapoints import BoundingBox, Label
+from torchvision.prototype.datapoints import BoundingBox
 from torchvision.prototype.datapoints._datapoint import Datapoint
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    GDriveResource,
+    LabelWithCategories,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
     hint_shuffling,
@@ -106,7 +112,7 @@ def _prepare_sample(
         ann = read_mat(ann_buffer)
 
         return dict(
-            label=Label.from_category(category, categories=self._categories),
+            label=LabelWithCategories.from_category(category, categories=self._categories),
             image_path=image_path,
             image=image,
             ann_path=ann_path,
@@ -188,7 +194,9 @@ def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
         return dict(
             path=path,
             image=EncodedImage.from_file(buffer),
-            label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories),
+            label=LabelWithCategories(
+                int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories
+            ),
         )
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py
index 9050cf0b596..03d82d70a72 100644
--- a/torchvision/prototype/datasets/_builtin/celeba.py
+++ b/torchvision/prototype/datasets/_builtin/celeba.py
@@ -3,9 +3,15 @@
 from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union
 
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
-from torchvision.prototype.datapoints import BoundingBox, Label
+from torchvision.prototype.datapoints import BoundingBox
 from torchvision.prototype.datapoints._datapoint import Datapoint
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    GDriveResource,
+    LabelWithCategories,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
     hint_sharding,
@@ -141,7 +147,7 @@ def _prepare_sample(
         return dict(
             path=path,
             image=image,
-            identity=Label(int(identity["identity"])),
+            identity=LabelWithCategories(int(identity["identity"])),
             attributes={attr: value == "1" for attr, value in attributes.items()},
             bounding_box=BoundingBox(
                 [int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")],
diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py
index de87f46c8b1..cb038f0aee7 100644
--- a/torchvision/prototype/datasets/_builtin/cifar.py
+++ b/torchvision/prototype/datasets/_builtin/cifar.py
@@ -6,8 +6,8 @@
 import numpy as np
 
 from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
-from torchvision.prototype.datapoints import Image, Label
-from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
+from torchvision.prototype.datapoints import Image
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, LabelWithCategories, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
     hint_shuffling,
@@ -70,7 +70,7 @@ def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
         image_array, category_idx = data
         return dict(
             image=Image(image_array),
-            label=Label(category_idx, categories=self._categories),
+            label=LabelWithCategories(category_idx, categories=self._categories),
         )
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py
index e282635684e..812929dcfd8 100644
--- a/torchvision/prototype/datasets/_builtin/clevr.py
+++ b/torchvision/prototype/datasets/_builtin/clevr.py
@@ -2,8 +2,13 @@
 from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
 
 from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher
-from torchvision.prototype.datapoints import Label
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    HttpResource,
+    LabelWithCategories,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
     hint_sharding,
@@ -66,7 +71,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, A
         return dict(
             path=path,
             image=EncodedImage.from_file(buffer),
-            label=Label(len(scenes_data["objects"])) if scenes_data else None,
+            label=LabelWithCategories(len(scenes_data["objects"])) if scenes_data else None,
         )
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py
index fa68bf4dc6f..fd808be46c7 100644
--- a/torchvision/prototype/datasets/_builtin/coco.py
+++ b/torchvision/prototype/datasets/_builtin/coco.py
@@ -14,9 +14,15 @@
     Mapper,
     UnBatcher,
 )
-from torchvision.prototype.datapoints import BoundingBox, Label, Mask
+from torchvision.prototype.datapoints import BoundingBox, Mask
 from torchvision.prototype.datapoints._datapoint import Datapoint
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    HttpResource,
+    LabelWithCategories,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
     hint_sharding,
@@ -131,7 +137,7 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
                 format="xywh",
                 spatial_size=spatial_size,
             ),
-            labels=Label(labels, categories=self._categories),
+            labels=LabelWithCategories(labels, categories=self._categories),
             super_categories=[self._category_to_super_category[self._categories[label]] for label in labels],
             ann_ids=[ann["id"] for ann in anns],
         )
diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py
index 0f4b3d769dc..f7ca571b573 100644
--- a/torchvision/prototype/datasets/_builtin/country211.py
+++ b/torchvision/prototype/datasets/_builtin/country211.py
@@ -2,8 +2,14 @@
 from typing import Any, Dict, List, Tuple, Union
 
 from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
-from torchvision.prototype.datapoints import Label
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
+
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    HttpResource,
+    LabelWithCategories,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
     hint_shuffling,
@@ -53,7 +59,7 @@ def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
         path, buffer = data
         category = pathlib.Path(path).parent.name
         return dict(
-            label=Label.from_category(category, categories=self._categories),
+            label=LabelWithCategories.from_category(category, categories=self._categories),
             path=path,
             image=EncodedImage.from_file(buffer),
         )
diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py
index ea192baf650..ee54edb30be 100644
--- a/torchvision/prototype/datasets/_builtin/cub200.py
+++ b/torchvision/prototype/datasets/_builtin/cub200.py
@@ -14,9 +14,15 @@
     Mapper,
 )
 from torchdata.datapipes.map import IterToMapConverter
-from torchvision.prototype.datapoints import BoundingBox, Label
+from torchvision.prototype.datapoints import BoundingBox
 from torchvision.prototype.datapoints._datapoint import Datapoint
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    GDriveResource,
+    LabelWithCategories,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
     hint_sharding,
@@ -180,7 +186,7 @@ def _prepare_sample(
         return dict(
             prepare_ann_fn(anns_data, image.spatial_size),
             image=image,
-            label=Label(
+            label=LabelWithCategories(
                 int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]) - 1,
                 categories=self._categories,
             ),
diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py
index 6ddab2af79d..a85edd8fb9e 100644
--- a/torchvision/prototype/datasets/_builtin/dtd.py
+++ b/torchvision/prototype/datasets/_builtin/dtd.py
@@ -3,8 +3,14 @@
 from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
 
 from torchdata.datapipes.iter import CSVParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
-from torchvision.prototype.datapoints import Label
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
+
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    HttpResource,
+    LabelWithCategories,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
     hint_sharding,
@@ -89,7 +95,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO
 
         return dict(
             joint_categories={category for category in joint_categories if category},
-            label=Label.from_category(category, categories=self._categories),
+            label=LabelWithCategories.from_category(category, categories=self._categories),
             path=path,
             image=EncodedImage.from_file(buffer),
         )
diff --git a/torchvision/prototype/datasets/_builtin/eurosat.py b/torchvision/prototype/datasets/_builtin/eurosat.py
index 463eed79d70..69c6fe04dfd 100644
--- a/torchvision/prototype/datasets/_builtin/eurosat.py
+++ b/torchvision/prototype/datasets/_builtin/eurosat.py
@@ -2,8 +2,13 @@
 from typing import Any, Dict, List, Tuple, Union
 
 from torchdata.datapipes.iter import IterDataPipe, Mapper
-from torchvision.prototype.datapoints import Label
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    HttpResource,
+    LabelWithCategories,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
 
 from .._api import register_dataset, register_info
@@ -51,7 +56,7 @@ def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
         path, buffer = data
         category = pathlib.Path(path).parent.name
         return dict(
-            label=Label.from_category(category, categories=self._categories),
+            label=LabelWithCategories.from_category(category, categories=self._categories),
             path=path,
             image=EncodedImage.from_file(buffer),
         )
diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py
index 73c6184b6e7..1fa07b87a60 100644
--- a/torchvision/prototype/datasets/_builtin/fer2013.py
+++ b/torchvision/prototype/datasets/_builtin/fer2013.py
@@ -3,8 +3,8 @@
 
 import torch
 from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper
-from torchvision.prototype.datapoints import Image, Label
-from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource
+from torchvision.prototype.datapoints import Image
+from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, LabelWithCategories, OnlineResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
 
 from .._api import register_dataset, register_info
@@ -49,7 +49,7 @@ def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
         return dict(
             image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)),
-            label=Label(int(label_id), categories=self._categories) if label_id is not None else None,
+            label=LabelWithCategories(int(label_id), categories=self._categories) if label_id is not None else None,
         )
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py
index f3054d8fb13..364f50b3c5b 100644
--- a/torchvision/prototype/datasets/_builtin/food101.py
+++ b/torchvision/prototype/datasets/_builtin/food101.py
@@ -2,8 +2,14 @@
 from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
 
 from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
-from torchvision.prototype.datapoints import Label
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
+
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    HttpResource,
+    LabelWithCategories,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
     hint_sharding,
@@ -57,7 +63,7 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
     def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         id, (path, buffer) = data
         return dict(
-            label=Label.from_category(id.split("/", 1)[0], categories=self._categories),
+            label=LabelWithCategories.from_category(id.split("/", 1)[0], categories=self._categories),
             path=path,
             image=EncodedImage.from_file(buffer),
         )
diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py
index adcc31b277a..fcac7d06505 100644
--- a/torchvision/prototype/datasets/_builtin/gtsrb.py
+++ b/torchvision/prototype/datasets/_builtin/gtsrb.py
@@ -2,8 +2,14 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper
-from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
+from torchvision.prototype.datapoints import BoundingBox
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    HttpResource,
+    LabelWithCategories,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
     hint_shuffling,
@@ -84,7 +90,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[
 
         return {
             "path": path,
             "image": EncodedImage.from_file(buffer),
-            "label": Label(label, categories=self._categories),
+            "label": LabelWithCategories(label, categories=self._categories),
             "bounding_box": bounding_box,
         }
diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py
index 5e2db41e1d0..fbe3391fd7b 100644
--- a/torchvision/prototype/datasets/_builtin/imagenet.py
+++ b/torchvision/prototype/datasets/_builtin/imagenet.py
@@ -15,8 +15,13 @@
     TarArchiveLoader,
 )
 from torchdata.datapipes.map import IterToMapConverter
-from torchvision.prototype.datapoints import Label
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, ManualDownloadResource, OnlineResource
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    LabelWithCategories,
+    ManualDownloadResource,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
     hint_sharding,
@@ -118,10 +123,12 @@ def _resources(self) -> List[OnlineResource]:
 
     _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
 
-    def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
+    def _prepare_train_data(
+        self, data: Tuple[str, BinaryIO]
+    ) -> Tuple[Tuple[LabelWithCategories, str], Tuple[str, BinaryIO]]:
         path = pathlib.Path(data[0])
         wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"]
-        label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
+        label = LabelWithCategories.from_category(self._wnid_to_category[wnid], categories=self._categories)
         return (label, wnid), data
 
     def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
@@ -140,15 +147,15 @@ def _val_test_image_key(self, path: pathlib.Path) -> int:
 
     def _prepare_val_data(
         self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
-    ) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
+    ) -> Tuple[Tuple[LabelWithCategories, str], Tuple[str, BinaryIO]]:
         label_data, image_data = data
         _, wnid = label_data
-        label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
+        label = LabelWithCategories.from_category(self._wnid_to_category[wnid], categories=self._categories)
         return (label, wnid), image_data
 
     def _prepare_sample(
         self,
-        data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]],
+        data: Tuple[Optional[Tuple[LabelWithCategories, str]], Tuple[str, BinaryIO]],
     ) -> Dict[str, Any]:
         label_data, (path, buffer) = data
diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index 9364aa3ade9..35e3ba071ac 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -7,8 +7,8 @@
 
 import torch
 from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper
-from torchvision.prototype.datapoints import Image, Label
-from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
+from torchvision.prototype.datapoints import Image
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, LabelWithCategories, OnlineResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE
 from torchvision.prototype.utils._internal import fromfile
 
@@ -95,7 +95,7 @@ def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str,
         image, label = data
         return dict(
             image=Image(image),
-            label=Label(label, dtype=torch.int64, categories=self._categories),
+            label=LabelWithCategories(label, dtype=torch.int64, categories=self._categories),
         )
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
index fbc7d30c292..b85387b4366 100644
--- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
+++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
@@ -3,8 +3,14 @@
 from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
 
 from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, Mapper
-from torchvision.prototype.datapoints import Label
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
+
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    HttpResource,
+    LabelWithCategories,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
     hint_sharding,
@@ -78,7 +84,7 @@ def _prepare_sample(
         image_path, image_buffer = image_data
 
         return dict(
-            label=Label(int(classification_data["label"]) - 1, categories=self._categories),
+            label=LabelWithCategories(int(classification_data["label"]) - 1, categories=self._categories),
             species="cat" if classification_data["species"] == "1" else "dog",
             segmentation_path=segmentation_path,
             segmentation=EncodedImage.from_file(segmentation_buffer),
diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py
index 9de224b95f0..6f55f296787 100644
--- a/torchvision/prototype/datasets/_builtin/pcam.py
+++ b/torchvision/prototype/datasets/_builtin/pcam.py
@@ -4,8 +4,8 @@
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
-from torchvision.prototype.datapoints import Image, Label
-from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
+from torchvision.prototype.datapoints import Image
+from torchvision.prototype.datasets.utils import Dataset, GDriveResource, LabelWithCategories, OnlineResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
 
 from .._api import register_dataset, register_info
@@ -109,7 +109,7 @@ def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
 
         return {
             "image": Image(image.transpose(2, 0, 1)),
-            "label": Label(target.item(), categories=self._categories),
+            "label": LabelWithCategories(target.item(), categories=self._categories),
         }
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py
index 9ae2c17ab5d..bbe8d3efc48 100644
--- a/torchvision/prototype/datasets/_builtin/semeion.py
+++ b/torchvision/prototype/datasets/_builtin/semeion.py
@@ -3,8 +3,8 @@
 
 import torch
 from torchdata.datapipes.iter import CSVParser, IterDataPipe, Mapper
-from torchvision.prototype.datapoints import Image, OneHotLabel
-from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
+from torchvision.prototype.datapoints import Image
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, OneHotLabelWithCategories, OnlineResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
 
 from .._api import register_dataset, register_info
@@ -40,7 +40,7 @@ def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]:
 
         return dict(
             image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.float).reshape(16, 16)),
-            label=OneHotLabel([int(label) for label in label_data], categories=self._categories),
+            label=OneHotLabelWithCategories([int(label) for label in label_data], categories=self._categories),
         )
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py
index 02db37169c1..494c85a346a 100644
--- a/torchvision/prototype/datasets/_builtin/stanford_cars.py
+++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py
@@ -2,8 +2,14 @@
 from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union
 
 from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper
-from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
+from torchvision.prototype.datapoints import BoundingBox
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    HttpResource,
+    LabelWithCategories,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
     hint_shuffling,
@@ -88,7 +94,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Tuple[int, int, int,
 
         return dict(
             path=path,
             image=image,
-            label=Label(target[4] - 1, categories=self._categories),
+            label=LabelWithCategories(target[4] - 1, categories=self._categories),
             bounding_box=BoundingBox(target[:4], format="xyxy", spatial_size=image.spatial_size),
         )
diff --git a/torchvision/prototype/datasets/_builtin/svhn.py b/torchvision/prototype/datasets/_builtin/svhn.py
index d276298ca02..0b8871fc2fd 100644
--- a/torchvision/prototype/datasets/_builtin/svhn.py
+++ b/torchvision/prototype/datasets/_builtin/svhn.py
@@ -3,8 +3,8 @@
 
 import numpy as np
 from torchdata.datapipes.iter import IterDataPipe, Mapper, UnBatcher
-from torchvision.prototype.datapoints import Image, Label
-from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
+from torchvision.prototype.datapoints import Image
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, LabelWithCategories, OnlineResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, read_mat
 
 from .._api import register_dataset, register_info
@@ -64,7 +64,7 @@ def _prepare_sample(self, data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, Any]
 
         return dict(
             image=Image(image_array.transpose((2, 0, 1))),
-            label=Label(int(label_array) % 10, categories=self._categories),
+            label=LabelWithCategories(int(label_array) % 10, categories=self._categories),
         )
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py
index 7d1fed04e07..aad45212526 100644
--- a/torchvision/prototype/datasets/_builtin/usps.py
+++ b/torchvision/prototype/datasets/_builtin/usps.py
@@ -3,8 +3,8 @@
 
 import torch
 from torchdata.datapipes.iter import Decompressor, IterDataPipe, LineReader, Mapper
-from torchvision.prototype.datapoints import Image, Label
-from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
+from torchvision.prototype.datapoints import Image
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, LabelWithCategories, OnlineResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
 
 from .._api import register_dataset, register_info
@@ -55,7 +55,7 @@ def _prepare_sample(self, line: str) -> Dict[str, Any]:
         pixels = torch.tensor(values).add_(1).div_(2)
         return dict(
             image=Image(pixels.reshape(16, 16)),
-            label=Label(int(label) - 1, categories=self._categories),
+            label=LabelWithCategories(int(label) - 1, categories=self._categories),
         )
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py
index d14189132be..3fc3ff5e91c 100644
--- a/torchvision/prototype/datasets/_builtin/voc.py
+++ b/torchvision/prototype/datasets/_builtin/voc.py
@@ -6,8 +6,14 @@
 
 from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
 from torchvision.datasets import VOCDetection
-from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
+from torchvision.prototype.datapoints import BoundingBox
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    EncodedImage,
+    HttpResource,
+    LabelWithCategories,
+    OnlineResource,
+)
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
     hint_sharding,
@@ -110,7 +116,7 @@ def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
                 format="xyxy",
                 spatial_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))),
             ),
-            labels=Label(
+            labels=LabelWithCategories(
                 [self._categories.index(instance["name"]) for instance in instances], categories=self._categories
             ),
         )
diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py
index 0a37df03add..a0eb5150416 100644
--- a/torchvision/prototype/datasets/_folder.py
+++ b/torchvision/prototype/datasets/_folder.py
@@ -5,8 +5,7 @@
 from typing import Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Union
 
 from torchdata.datapipes.iter import FileLister, FileOpener, Filter, IterDataPipe, Mapper
-from torchvision.prototype.datapoints import Label
-from torchvision.prototype.datasets.utils import EncodedData, EncodedImage
+from torchvision.prototype.datasets.utils import EncodedData, EncodedImage, LabelWithCategories
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
 
 
@@ -29,7 +28,7 @@ def _prepare_sample(
     return dict(
         path=path,
         data=EncodedData.from_file(buffer),
-        label=Label.from_category(category, categories=categories),
+        label=LabelWithCategories.from_category(category, categories=categories),
    )
 
diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py
index 3fdb53eec43..908ab302b47 100644
--- a/torchvision/prototype/datasets/utils/__init__.py
+++ b/torchvision/prototype/datasets/utils/__init__.py
@@ -1,4 +1,4 @@
 from . import _internal  # usort: skip
+from ._datapoints import EncodedData, EncodedImage, LabelWithCategories, OneHotLabelWithCategories
 from ._dataset import Dataset
-from ._encoded import EncodedData, EncodedImage
 from ._resource import GDriveResource, HttpResource, KaggleDownloadResource, ManualDownloadResource, OnlineResource
diff --git a/torchvision/prototype/datasets/utils/_datapoints.py b/torchvision/prototype/datasets/utils/_datapoints.py
new file mode 100644
index 00000000000..53802e2e6e0
--- /dev/null
+++ b/torchvision/prototype/datasets/utils/_datapoints.py
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import os
+import sys
+
+from typing import Any, BinaryIO, Optional, Sequence, Tuple, Type, TypeVar, Union
+
+import PIL.Image
+
+import torch
+
+from torch.utils._pytree import tree_map
+from torchvision.prototype import datapoints
+
+from torchvision.prototype.datapoints._datapoint import Datapoint
+from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
+
+D = TypeVar("D", bound="EncodedData")
+
+
+class EncodedData(Datapoint):
+    @classmethod
+    def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
+        return tensor.as_subclass(cls)
+
+    def __new__(
+        cls,
+        data: Any,
+        *,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
+    ) -> EncodedData:
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
+        # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
+        return cls._wrap(tensor)
+
+    @classmethod
+    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
+        return cls._wrap(tensor)
+
+    @classmethod
+    def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
+        encoded_data = cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
+        file.close()
+        return encoded_data
+
+    @classmethod
+    def from_path(cls: Type[D], path: Union[str, os.PathLike], **kwargs: Any) -> D:
+        with open(path, "rb") as file:
+            return cls.from_file(file, **kwargs)
+
+
+class EncodedImage(EncodedData):
+    # TODO: Use @functools.cached_property if we can depend on Python 3.8
+    @property
+    def spatial_size(self) -> Tuple[int, int]:
+        if not hasattr(self, "_spatial_size"):
+            with PIL.Image.open(ReadOnlyTensorBuffer(self)) as image:
+                self._spatial_size = image.height, image.width
+
+        return self._spatial_size
+
+
+L = TypeVar("L", bound="_LabelWithCategoriesBase")
+
+
+class _LabelWithCategoriesBase(datapoints.Label):
+    categories: Optional[Sequence[str]]
+
+    @classmethod
+    def _wrap(  # type: ignore[override]
+        cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]
+    ) -> L:
+        label_base = tensor.as_subclass(cls)
+        label_base.categories = categories
+        return label_base
+
+    def __new__(
+        cls: Type[L],
+        data: Any,
+        *,
+        categories: Optional[Sequence[str]] = None,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: Optional[bool] = None,
+    ) -> L:
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
+        return cls._wrap(tensor, categories=categories)
+
+    @classmethod
+    def wrap_like(  # type: ignore[override]
+        cls: Type[L], other: L, tensor: torch.Tensor, *, categories: Optional[Sequence[str]] = None
+    ) -> L:
+        return cls._wrap(
+            tensor,
+            categories=categories if categories is not None else other.categories,
+        )
+
+    @classmethod
+    def from_category(
+        cls: Type[L],
+        category: str,
+        *,
+        categories: Sequence[str],
+        **kwargs: Any,
+    ) -> L:
+        return cls(categories.index(category), categories=categories, **kwargs)
+
+
+class LabelWithCategories(_LabelWithCategoriesBase):
+    def to_categories(self) -> Any:
+        if self.categories is None:
+            raise RuntimeError("Label does not have categories")
+
+        return tree_map(lambda idx: self.categories[idx], self.tolist())
+
+
+class OneHotLabelWithCategories(_LabelWithCategoriesBase):
+    def __new__(
+        cls,
+        data: Any,
+        *,
+        categories: Optional[Sequence[str]] = None,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
+    ) -> OneHotLabelWithCategories:
+        one_hot_label = super().__new__(
+            cls, data, categories=categories, dtype=dtype, device=device, requires_grad=requires_grad
+        )
+
+        if categories is not None and len(categories) != one_hot_label.shape[-1]:
+            raise ValueError()
+
+        return one_hot_label
diff --git a/torchvision/prototype/datasets/utils/_encoded.py b/torchvision/prototype/datasets/utils/_encoded.py
deleted file mode 100644
index 64cd9f7b951..00000000000
--- a/torchvision/prototype/datasets/utils/_encoded.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from __future__ import annotations
-
-import os
-import sys
-from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union
-
-import PIL.Image
-import torch
-
-from torchvision.prototype.datapoints._datapoint import Datapoint
-from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
-
-D = TypeVar("D", bound="EncodedData")
-
-
-class EncodedData(Datapoint):
-    @classmethod
-    def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
-        return tensor.as_subclass(cls)
-
-    def __new__(
-        cls,
-        data: Any,
-        *,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[Union[torch.device, str, int]] = None,
-        requires_grad: bool = False,
-    ) -> EncodedData:
-        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
-        # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
-        return cls._wrap(tensor)
-
-    @classmethod
-    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
-        return cls._wrap(tensor)
-
-    @classmethod
-    def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
-        encoded_data = cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
-        file.close()
-        return encoded_data
-
-    @classmethod
-    def from_path(cls: Type[D], path: Union[str, os.PathLike], **kwargs: Any) -> D:
-        with open(path, "rb") as file:
-            return cls.from_file(file, **kwargs)
-
-
-class EncodedImage(EncodedData):
-    # TODO: Use @functools.cached_property if we can depend on Python 3.8
-    @property
-    def spatial_size(self) -> Tuple[int, int]:
-        if not hasattr(self, "_spatial_size"):
-            with PIL.Image.open(ReadOnlyTensorBuffer(self)) as image:
-                self._spatial_size = image.height, image.width
-
-        return self._spatial_size
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 70ae972d9e2..b54e052c257 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -167,16 +167,16 @@ class FiveCrop(Transform):
     """
     Example:
         >>> class BatchMultiCrop(transforms.Transform):
-        ...     def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], datapoints.Label]):
+        ...     def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], datapoints.LabelWithCategories]):
         ...         images_or_videos, labels = sample
         ...         batch_size = len(images_or_videos)
         ...         image_or_video = images_or_videos[0]
         ...         images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
-        ...         labels = datapoints.Label.wrap_like(labels, labels.repeat(batch_size))
+        ...         labels = datapoints.LabelWithCategories.wrap_like(labels, labels.repeat(batch_size))
         ...         return images_or_videos, labels
         ...
         >>> image = datapoints.Image(torch.rand(3, 256, 256))
-        >>> label = datapoints.Label(0)
+        >>> label = datapoints.LabelWithCategories(0)
         >>> transform = transforms.Compose([transforms.FiveCrop(), BatchMultiCrop()])
         >>> images, labels = transform(image, label)
         >>> images.shape
@@ -847,7 +847,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
         ):
             raise TypeError(
                 f"If a BoundingBox is contained in the input sample, "
-                f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
+                f"{type(self).__name__}() also requires it to contain a LabelWithCategories or OneHotLabelWithCategories."
             )
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py
index c84aee62afe..fedd07110d8 100644
--- a/torchvision/prototype/transforms/_type_conversion.py
+++ b/torchvision/prototype/transforms/_type_conversion.py
@@ -21,10 +21,8 @@ def __init__(self, num_categories: int = -1):
 
     def _transform(self, inpt: datapoints.Label, params: Dict[str, Any]) -> datapoints.OneHotLabel:
         num_categories = self.num_categories
-        if num_categories == -1 and inpt.categories is not None:
-            num_categories = len(inpt.categories)
         output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
-        return datapoints.OneHotLabel(output, categories=inpt.categories)
+        return datapoints.OneHotLabel(output)
 
     def extra_repr(self) -> str:
         if self.num_categories == -1:
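
---

For reference, a minimal usage sketch of the relocated label datapoints introduced by this patch. This assumes a torchvision build that contains the prototype change; the category names are purely illustrative:

```python
import torch

from torchvision.prototype.datasets.utils import LabelWithCategories, OneHotLabelWithCategories

# `from_category` resolves the index of the category name within `categories`.
label = LabelWithCategories.from_category("dog", categories=["cat", "dog"])
assert int(label) == 1
assert label.to_categories() == "dog"

# `wrap_like` re-wraps a plain tensor as a label and, unless new categories
# are passed explicitly, carries the categories over from `other`.
batched = LabelWithCategories.wrap_like(label, label.repeat(4))
assert batched.categories == ["cat", "dog"]

# The one-hot variant validates that the trailing dimension matches the
# number of categories.
one_hot = OneHotLabelWithCategories([0, 1], categories=["cat", "dog"])
try:
    OneHotLabelWithCategories([1, 0, 0], categories=["cat", "dog"])  # 3 != 2
except ValueError:
    pass
```

Note that the plain `Label`/`OneHotLabel` datapoints kept in `_label.py` no longer carry `categories`; only the dataset-side subclasses above do, which is why `LabelToOneHot` in `_type_conversion.py` stops reading and propagating `categories`.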