
Proposal for extending transforms #230

Closed
@fmassa

Description

Up to now, we compose different transforms via the Compose transform, which works like nn.Sequential but for transforms. Each transform is applied independently to the input and the target.
While simple and efficient, there has been increasing demand for extending the available transforms to accept both input and target, see #9, #115, #221 for some examples. This would allow applying the same random transformations to both the input and the target, for tasks such as semantic segmentation.
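
To make the issue concrete, here is a minimal sketch of what happens today with independent pipelines (the image/mask pairing is just an example):

import torchvision.transforms as transforms

# the flip decision is sampled independently in each pipeline, so the
# image may end up flipped while its segmentation mask is not
transform = transforms.RandomHorizontalFlip()
target_transform = transforms.RandomHorizontalFlip()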

There are a few possible approaches, and I'll summarize the ones that have been mentioned here:

  1. provide a set of random transforms working on pairs (or triplets, etc.) of images, as proposed for example in [add] A PairRandomCrop for both input and target. #221. There are a couple of downsides to this approach: (a) it doesn't scale when we want to combine data coming from different domains (in object detection we have an image and bounding boxes, for example, and we would need to reimplement a new transform for each combination, introducing redundancy in the code); (b) it is hard-coded to work with two inputs (though it could be extended to n inputs without difficulty).
  2. factor out the randomness in the transforms, so that we can keep the functionality we currently have and just add a generate_seed call in __getitem__, as was proposed in Separate random generation from transforms #115. The drawback of this approach is that in some cases we need information from the input to perform the transformation on the target (imagine flipping a bounding box horizontally: it requires the image width, which is not available right away).
  3. provide a generic base class that samples the random transformation parameters and passes those parameters to the input arguments, as proposed by @chsasank in the slack channel (see the sketch after this list). While this is an improvement over both [add] A PairRandomCrop for both input and target. #221 and Separate random generation from transforms #115, it still suffers from the same limitation as Separate random generation from transforms #115, since we still pass only one input argument at a time.
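
For concreteness, a minimal sketch of what 3. could look like; the class and method names here are hypothetical, not an existing torchvision API:

import random

from PIL import Image

class ParamRandomHorizontalFlip(object):
    # sample the random parameters once, then apply the resulting
    # deterministic transform to each argument in turn
    def __init__(self, p=0.5):
        self.p = p
        self.do_flip = False

    def sample_params(self):
        self.do_flip = random.random() < self.p

    def __call__(self, img):
        # still called on one argument at a time, which is the
        # limitation shared with #115 mentioned above
        if self.do_flip:
            return img.transpose(Image.FLIP_LEFT_RIGHT)
        return img

# in __getitem__: flip = ParamRandomHorizontalFlip(); flip.sample_params()
# then: image, target = flip(image), flip(target)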

Also, all of those options handle each input independently. Furthermore, for both 1. and 3., the order of the operations is fixed in the dataset class (first input transforms, then target transforms, then joint transforms, for example). There might be cases where it would be convenient not to be restricted to such an ordering, and to alternate between single and joint transforms.

One possibility to address those issues would be to provide a set of functions like nn.Split, nn.Concat, nn.Select, etc., so that we can always pass all inputs to the transforms and let the user compose the transforms explicitly. This mimics the legacy nn behavior.
The downside of this approach is that writing some transformations gets very complicated, and it doesn't buy us much.
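
To illustrate, a rough sketch of what such container ops might look like; Select and ApplyTo are hypothetical names, not existing torchvision (or nn) classes:

class Select(object):
    # pick one element out of the tuple of inputs
    def __init__(self, index):
        self.index = index

    def __call__(self, inputs):
        return inputs[self.index]

class ApplyTo(object):
    # apply a single-input transform to one element, pass the rest through
    def __init__(self, index, transform):
        self.index = index
        self.transform = transform

    def __call__(self, inputs):
        out = list(inputs)
        out[self.index] = self.transform(out[self.index])
        return tuple(out)

# even "flip only the image" becomes a chain of plumbing ops, and joint
# transforms still need custom classes on top of these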

Instead, a simpler approach (which has already been advocated in the past by @colesbury) is to let the user directly subclass the dataset and implement their complex transforms there. This approach (4.) would look something like:

import numpy as np

class VOCDatasetSegmentation(VOCDataset):
    def __init__(self, flip=False, **kwargs):
        super(VOCDatasetSegmentation, self).__init__(**kwargs)
        self.flip = flip

    def __getitem__(self, idx):
        image, target = super(VOCDatasetSegmentation, self).__getitem__(idx)
        do_flip = np.random.random() > 0.5
        if self.flip and do_flip:
            pass  # flip image and bbox here
        return image, target

A downside is that we can't easily reuse those transforms in a different Dataset class (COCO, for example).

The question now is whether there is an intermediate design that keeps the simplicity of 4. without having to subclass the dataset and reimplement the same transforms every time.

Yet another possibility (5.) would be to let the user write their dataset code as follows:

class Dataset(object):
    def __init__(self, transforms=None):
        self.transforms = transforms

    def __getitem__(self, idx):
        # get image1, image2, bounding_box
        # the transforms callable takes all inputs into account
        if self.transforms:
            image1, image2, bounding_box = self.transforms(image1, image2, bounding_box)
        return image1, image2, bounding_box

from torchvision.transforms import random_horizontal_flip

class MyJointRandomFlipTransform(object):
    def __call__(self, image1, image2, bounding_box):
        # provide a functional interface for the current transforms
        # so that they can be easily reused, and expose the parameters
        # of the transformation if needed
        image1, params = random_horizontal_flip(image1, return_params=True)
        # reuse the same transformation, if wanted
        image2 = random_horizontal_flip(image2, params=params)
        # no transformation in torchvision for bounding_box, we have to do it
        # ourselves (here image1 is assumed CHW, so size(2) is the width, and
        # columns 1 and 3 of bounding_box hold the x coordinates)
        if params.flip:
            bounding_box[:, 1] = image1.size(2) - bounding_box[:, 1]
            bounding_box[:, 3] = image1.size(2) - bounding_box[:, 3]
        return image1, image2, bounding_box
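
As a usage sketch (assuming the Dataset class from above), the joint transform simply replaces the old transform/target_transform pair:

dataset = Dataset(transforms=MyJointRandomFlipTransform())
image1, image2, bounding_box = dataset[0]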

In this way, we keep the flexibility of subclassing the dataset, while being more modular and easier to implement.
There would be some differences in the way we write our datasets currently, but we could have a fallback implementation for backward compatibility, along the lines of:

class StandardTransform(object):
    def __init__(self, transform, target_transform):
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input, target):
        if self.transform:
            input = self.transform(input)
        if self.target_transform:
            target = self.target_transform(target)
        return input, target

and in the current datasets we would replace transform and target_transform with a single transforms argument, while keeping the old behavior:

class Dataset(object):
    def __init__(self, path, transforms=None, transform=None, target_transform=None):
        # only transforms or (transform, target_transform) can be set at a time
        has_separate = transform is not None or target_transform is not None
        assert not (transforms is not None and has_separate), \
            "pass either transforms or transform/target_transform, not both"
        if transforms is None:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    # __getitem__ only uses self.transforms from now on
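
Both construction styles would then keep working; a quick sketch (ToTensor stands in for any existing single-input transform):

from torchvision.transforms import ToTensor

# old style, routed through the StandardTransform fallback
dataset_old = Dataset(path, transform=ToTensor())
# new style, a single joint transforms callable
dataset_new = Dataset(path, transforms=MyJointRandomFlipTransform())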

What do you think? Do you see any drawbacks to using such an approach?
