## Description
This is a proposal. I'm not sure yet it's the best way to achieve this, so I'm putting this up for discussion.
## tl;dr

Have specific tensor subclasses for `BoxList` / `SegmentationMask` / `CombinedObjects`, etc., which inherit from `torch.Tensor` and override methods to properly dispatch to the relevant implementations. Depends on `__torch_function__` from pytorch/pytorch#22402 (implemented in pytorch/pytorch#27064).
## Background
For more than 2 years now, users have asked for ways of performing transformations on multiple inputs at the same time, for example for semantic segmentation or object detection (#9). The recommended solution is to use the functional transforms in those cases (#230, #1169), but for simple cases this is a bit verbose.
## Requirements
Ideally, we would want the following to be possible:
- work with a `Compose`-style interface for simple cases
- support more than a single input of each type (for example, two images and one segmentation mask)
- support joint rotations / rescaling with different hyperparameters for different input types (images can use bilinear interpolation, segmentation maps should use nearest interpolation)
- be simple and modular
## Proposed solution
We define new classes for each type of object, which all inherit from `torch.Tensor` and implement / override a few specific methods. It might depend on `__torch_function__` from pytorch/pytorch#22402 (implemented in pytorch/pytorch#27064).
### Work with a `Compose`-style interface
We propose to define a `CombinedObjects` class (better names welcome), which is a collection of arbitrary objects (potentially named, but that's not a requirement). Calling any method on it should dispatch to the corresponding method of each of its constituents. A basic example is below; I'll mention a few more points about it afterwards:
```python
class CombinedObjects:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def hflip(self):
        result = {}
        for name, value in self.kwargs.items():
            result[name] = value.hflip()
        return type(self)(**result)
```
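Writing one forwarding method per transform scales poorly as the set of transforms grows. One possible generalization (a sketch, not part of the proposal) is to dispatch any method call to every constituent via `__getattr__`; the `Point` class here is a toy stand-in for a real image or mask type:

```python
class CombinedObjects:
    """Collection of named objects; forwards any method call to all of them."""
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __getattr__(self, name):
        # called only for attributes not found on the instance,
        # so regular attributes like `kwargs` are unaffected
        def dispatch(*args, **kw):
            result = {k: getattr(v, name)(*args, **kw)
                      for k, v in self.kwargs.items()}
            return type(self)(**result)
        return dispatch

# hypothetical constituent type, for illustration only
class Point:
    def __init__(self, x):
        self.x = x

    def hflip(self):
        return Point(-self.x)

c = CombinedObjects(a=Point(1), b=Point(2))
flipped = c.hflip()
print(flipped.kwargs['a'].x, flipped.kwargs['b'].x)  # -1 -2
```

The trade-off is that `__getattr__` forwards *every* unknown method, so typos fail late; an explicit per-method protocol, as in the snippet above, is easier to reason about.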
In this way, if the underlying objects follow the same protocol (i.e., implement the required functions), this allows combining an arbitrary number of objects with a `Compose`-style API via:
```python
# example for flip
class RandomFlip:
    def __call__(self, x):
        # implementation stays almost the same,
        # but now `x` can be an Image or a CombinedObjects
        if random.random() > 0.5:
            x = x.hflip()
        return x

transforms = Compose([
    Resize(300),
    RandomFlip(),
    RandomColorAugment(),
])
inputs = CombinedObjects(img1=x, img2=y, mask=z)
output = transforms(inputs)
```
which satisfies points 1 and 2 above, and part of point 3 (except for the different transformation hyperparameters for images / segmentation masks, which I'll cover next).
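To make the interplay concrete, here is a self-contained toy version of this flow. `Compose`, `HFlip`, and `FakeImage` are minimal stand-ins, not the torchvision implementations, and the flip is deterministic so the result is predictable:

```python
class Compose:
    """Minimal stand-in: applies transforms in sequence."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x

class HFlip:
    # deterministic flip; the real RandomFlip would draw a coin
    def __call__(self, x):
        return x.hflip()

class FakeImage:
    """Toy image: a list of pixels, flipped by reversal."""
    def __init__(self, pixels):
        self.pixels = pixels

    def hflip(self):
        return FakeImage(self.pixels[::-1])

class CombinedObjects:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def hflip(self):
        return type(self)(**{k: v.hflip() for k, v in self.kwargs.items()})

# the same transform pipeline handles a bundle of inputs unchanged
transforms = Compose([HFlip()])
out = transforms(CombinedObjects(img=FakeImage([1, 2, 3]),
                                 mask=FakeImage([0, 0, 1])))
print(out.kwargs['img'].pixels)   # [3, 2, 1]
print(out.kwargs['mask'].pixels)  # [1, 0, 0]
```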
### Different behavior for masks / boxes / images
In the same vein as the `CombinedObjects` approach from the previous section, we would have subclasses of `torch.Tensor` for `BoxList` / `SegmentationMask` / etc. which would override the behavior of specific functions so that they work as expected. For example (using code snippets from pytorch/pytorch#25629), we can define a class for segmentation masks where rotation / interpolation / grid sample always use nearest interpolation:
```python
HANDLED_FUNCTIONS = {}

def implements(torch_function):
    "Register an implementation of a torch function for a Tensor-like object."
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

class SegmentationMask(torch.Tensor):
    def __torch_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        # Note: this allows subclasses that don't override
        # __torch_function__ to handle SegmentationMask objects
        if not all(issubclass(t, self.__class__) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

@implements(torch.nn.functional.interpolate)
def interpolate(...):
    # force nearest interpolation
    return torch.nn.functional.interpolate(..., mode='nearest')
```
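Since the snippet above depends on pytorch/pytorch#27064, here is a torch-free mimic of the same registry mechanism that runs today. `interpolate` and `Mask` are illustrative stand-ins, and the tuple return values just make the chosen code path visible:

```python
HANDLED_FUNCTIONS = {}

def implements(overridable):
    "Register an override, used when a Mask is passed to `overridable`."
    def decorator(func):
        HANDLED_FUNCTIONS[overridable] = func
        return func
    return decorator

class Mask:
    # stand-in for SegmentationMask: intercepts registered functions
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))

def interpolate(x, mode='bilinear'):
    # mimic of torch's dispatch: give tensor-like arguments a chance
    # to intercept the call before the default implementation runs
    if hasattr(x, '__torch_function__'):
        result = x.__torch_function__(interpolate, (type(x),),
                                      (x,), {'mode': mode})
        if result is not NotImplemented:
            return result
    return ('default', mode)

@implements(interpolate)
def interpolate_mask(x, mode='bilinear'):
    # masks ignore the requested mode and always use nearest
    return ('mask', 'nearest')

print(interpolate(object()))                 # ('default', 'bilinear')
print(interpolate(Mask(), mode='bilinear'))  # ('mask', 'nearest')
```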
And we can also give custom implementations for bounding boxes:
```python
class BoxList(torch.Tensor):
    # need to define height and width somewhere as an attribute
    ...

@implements(torch.nn.functional.interpolate)
def interpolate(...):
    return box * scale
```
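The idea behind `box * scale` is that resizing an image only rescales box coordinates; no pixel resampling is involved. A tiny illustration, with a hypothetical `resize_box` helper and plain tuples in place of a `BoxList`:

```python
def resize_box(box, scale):
    # (x1, y1, x2, y2) coordinates scale linearly with the image
    x1, y1, x2, y2 = box
    return (x1 * scale, y1 * scale, x2 * scale, y2 * scale)

# halving the image size halves every coordinate
print(resize_box((10, 20, 30, 40), 0.5))  # (5.0, 10.0, 15.0, 20.0)
```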
This would cover the remainder of point 3. Because they are subclasses of `torch.Tensor`, they behave as `Tensor` except in the particular cases where we override the behavior.
This would be used as follows:
```python
boxes = BoxList(torch.rand(10, 4), ...)
masks = SegmentationMask(torch.rand(1, 100, 200))
image = torch.rand(3, 100, 200)
# basically a dict of inputs
x = CombinedObjects(image=image, boxes=boxes, masks=masks)
transforms = Compose([
    RandomResizeCrop(224),
    RandomFlip(),
    ColorAugment(),
])
out = transforms(x)
# have an API for getting the elements back
image = out.get('image')
# or something like that
```
### Be simple and modular
This is up for discussion. The fact that we are implementing subclasses that do not behave exactly like tensors can be confusing and misleading. But it does seem to simplify a number of things, and makes it possible for users to leverage the same abstractions in torchvision for their own custom types, without having to modify anything in torchvision, which is nice.
## Related discussions
Some other proposals have been discussed in #230, #1169, and many other places.
cc @Noiredd @SebastienEske @pmeier for discussion