[RFC] Abstractions for segmentation / detection transforms #1406

Open
@fmassa

Description

This is a proposal. I'm not sure yet it's the best way to achieve this, so I'm putting this up for discussion.

tl;dr

Have specific tensor subclasses for BoxList / SegmentationMask / CombinedObjects, etc., which inherit from torch.Tensor and override methods to properly dispatch to the relevant implementations. Depends on __torch_function__ from pytorch/pytorch#22402 (implemented in pytorch/pytorch#27064).

Background

For more than 2 years now, users have asked for ways of performing transformations on multiple inputs at the same time, for example for semantic segmentation or object detection (#9).

The recommended solution in this case is to use the functional transforms (#230, #1169), but even for simple cases this is a bit verbose.

Requirements

Ideally, we would want the following to be possible:

  1. work with a Compose-style interface for simple cases
  2. support more than a single input of each type (for example, two images and one segmentation mask)
  3. support joint rotations / rescale with different hyperparameters for different input types (images can do bilinear interpolation, segmentation maps should do nearest interpolation)
  4. be simple and modular

Proposed solution

We define new classes for each type of object, which should all inherit from torch.Tensor, and implement / override a few specific methods. It might depend on __torch_function__ from pytorch/pytorch#22402 (implemented in pytorch/pytorch#27064)

Work with Compose-style

We propose to define a CombinedObjects (better names welcome), which is a collection of arbitrary objects (potentially named, but that's not a requirement).
Calling any of the methods on it should dispatch to the corresponding methods of its constituents. A basic example is below (I'll mention a few more points about it afterwards):

class CombinedObjects(object):
    def __init__(self, **kwargs):
        # arbitrary named constituents, e.g. images, masks, boxes
        self.kwargs = kwargs

    def hflip(self):
        # dispatch to the corresponding method of every constituent
        result = {}
        for name, value in self.kwargs.items():
            result[name] = value.hflip()
        return type(self)(**result)

In this way, if the underlying objects follow the same protocol (i.e., implement the required functions), this allows combining an arbitrary number of objects with a Compose API:

# example for flip
class RandomFlip(object):
    def __call__(self, x):
        # implementation stays almost the same
        # but now, `x` can be an Image, or a CombinedObject
        if random.random() > 0.5:
            x = x.hflip()
        return x

transforms = Compose([
    Resize(300),
    RandomFlip(),
    RandomColorAugment()
])

inputs = CombinedObjects(img1=x, img2=y, mask=z)
output = transforms(inputs)

This satisfies points 1 and 2 above, and part of point 3 (except for the different transformation hyperparameters for image / segmentation mask, which I'll cover next).
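Stripped of any torch dependency, the whole dispatch chain can be exercised end to end. `Flippable` below is a hypothetical stand-in for an image or mask type that follows the protocol:

```python
import random

class CombinedObjects:
    """Dispatch each transform method to every wrapped object."""
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def hflip(self):
        return type(self)(**{name: value.hflip()
                             for name, value in self.kwargs.items()})

    def get(self, name):
        return self.kwargs[name]

class Flippable:
    """Hypothetical stand-in for an Image / SegmentationMask."""
    def __init__(self, data):
        self.data = data

    def hflip(self):
        return Flippable(self.data[::-1])

class RandomFlip:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, x):
        # `x` can be a single object or a CombinedObjects
        if random.random() < self.p:
            x = x.hflip()
        return x

inputs = CombinedObjects(img=Flippable([1, 2, 3]), mask=Flippable([9, 8, 7]))
out = RandomFlip(p=1.0)(inputs)   # p=1.0 so the flip always fires
print(out.get('img').data, out.get('mask').data)
```

Anything that implements `hflip` can participate, which is what makes the scheme composable.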

Different behavior for mask / boxes / images

In the same vein as the CombinedObjects approach from the previous section, we would have subclasses of torch.Tensor for BoxList / SegmentationMask / etc., which would override the behavior of specific functions so that they work as expected.

For example (and using code snippets from pytorch/pytorch#25629), we can define a class for segmentation masks where rotation / interpolation / grid sample always behave with nearest interpolation:

HANDLED_FUNCTIONS = {}

def implements(torch_function):
    "Register an implementation of a torch function for a Tensor-like object."
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

class SegmentationMask(torch.Tensor):
    def __torch_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        # Note: this allows subclasses that don't override
        # __torch_function__ to handle SegmentationMask objects.
        if not all(issubclass(t, self.__class__) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

@implements(torch.nn.functional.interpolate)
def interpolate(...):
    # force nearest interpolation
    return torch.nn.functional.interpolate(..., mode='nearest')

and we can also give custom implementations for bounding boxes:

class BoxList(torch.Tensor):
    # need to define height and width somewhere as an attribute
    ...

@implements(torch.nn.functional.interpolate)
def interpolate(...):
    return box * scale

This would cover the remainder of point 3. Because they are subclasses of torch.Tensor, they behave as Tensors except in the particular cases where we override the behavior.
This would be used as follows:

boxes = BoxList(torch.rand(10, 4), ...)
masks = SegmentationMask(torch.rand(1, 100, 200))
image = torch.rand(3, 100, 200)
# basically a dict of inputs
x = CombinedObjects(image=image, boxes=boxes, masks=masks)
transforms = Compose([
    RandomResizeCrop(224),
    RandomFlip(),
    ColorAugment(),
])
out = transforms(x)
# have an API for getting the elements back
image = out.get('image')
# or something like that
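As a concrete illustration of why boxes carry their image size (the attribute flagged in the BoxList sketch above), here is a plain-Python version; the class layout and method names are hypothetical:

```python
class BoxList:
    """Minimal sketch: (x1, y1, x2, y2) boxes plus the image size they live in.

    Illustrative only -- the real class would subclass torch.Tensor.
    """
    def __init__(self, boxes, width, height):
        self.boxes = boxes          # list of (x1, y1, x2, y2) tuples
        self.width = width
        self.height = height

    def resize(self, new_width, new_height):
        # boxes rescale with the image; no interpolation mode is involved
        sx, sy = new_width / self.width, new_height / self.height
        boxes = [(x1 * sx, y1 * sy, x2 * sx, y2 * sy)
                 for x1, y1, x2, y2 in self.boxes]
        return BoxList(boxes, new_width, new_height)

    def hflip(self):
        # a horizontal flip needs the image width, hence the attribute
        boxes = [(self.width - x2, y1, self.width - x1, y2)
                 for x1, y1, x2, y2 in self.boxes]
        return BoxList(boxes, self.width, self.height)

b = BoxList([(10, 20, 50, 60)], width=200, height=100)
print(b.hflip().boxes)            # [(150, 20, 190, 60)]
print(b.resize(100, 100).boxes)   # [(5.0, 20.0, 25.0, 60.0)]
```

A horizontal flip maps x1/x2 to width - x2 / width - x1, which is impossible without knowing the width, while resizing only rescales coordinates.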

Be simple and modular

This is up for discussion. The fact that we are implementing subclasses that do not behave exactly like tensors can be confusing and misleading. But it does seem to simplify a number of things, and makes it possible for users to leverage the same abstractions in torchvision for their own custom types, without having to modify anything in torchvision, which is nice.
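For what it's worth, the snippets above use the instance-method form of `__torch_function__` from the original proposal; PyTorch as released expects a classmethod with signature `(cls, func, types, args=(), kwargs=None)`. A minimal end-to-end sketch under that protocol follows; the nearest-neighbour upsampling via `repeat_interleave` is an illustrative stand-in for the real interpolation, not the actual implementation:

```python
import torch
import torch.nn.functional as F

HANDLED_FUNCTIONS = {}

class SegmentationMask(torch.Tensor):
    # released PyTorch expects the classmethod form of the protocol
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **kwargs)
        # fall back to plain Tensor behaviour for everything else
        return super().__torch_function__(func, types, args, kwargs)

def mask_interpolate(input, scale_factor=None, **kwargs):
    # masks must never be interpolated smoothly: ignore the requested
    # mode and do nearest-neighbour upsampling by an integer factor,
    # written with repeat_interleave to avoid re-entering this handler
    f = int(scale_factor)
    return input.repeat_interleave(f, dim=-1).repeat_interleave(f, dim=-2)

HANDLED_FUNCTIONS[F.interpolate] = mask_interpolate

mask = torch.arange(4.).reshape(1, 1, 2, 2).as_subclass(SegmentationMask)
# the caller asks for bilinear, but the mask subclass forces nearest
out = F.interpolate(mask, scale_factor=2, mode='bilinear')
```

The caller never has to know it is holding a mask, which is the point of the abstraction.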

Related discussions

Some other proposals have been discussed in #230, #1169, and many other places.

cc @Noiredd @SebastienEske @pmeier for discussion
