
Transforms V2 proposal: Enabling reproducible workflows via local RNGs #7027

Open
@rsokl


🚀 The feature

(This was originally pitched in this long feedback thread. It was recommended that I open a separate issue.)

Enable the new transforms API to accept a local generator for controlling randomness, via the following modified API:

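from typing import Any, Dict, List

import torch.nn as nn
from torch import Generator, default_generator
from torch.utils._pytree import tree_flatten

class Transform(nn.Module):
    def _get_params(self, flat_inputs: List[Any], *, generator: Generator) -> Dict[str, Any]:
        ...

    def forward(self, *inputs: Any, generator: Generator = default_generator) -> Any:
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        # the only modification: pass the generator on to _get_params
        params = self._get_params(flat_inputs, generator=generator)
        ...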
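To make the plumbing concrete, here is a minimal, self-contained sketch (assuming only that torch is installed). `Transform` is a bare stand-in that skips the input flattening of the real base class, and `RandomShift` is a hypothetical transform; it samples with the generator-aware pattern described next. Omitting `generator` falls back to `torch.default_generator`, so `torch.manual_seed` keeps working for existing code.

```python
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch import Generator, default_generator


class Transform(nn.Module):
    """Minimal stand-in for the proposed base class (input flattening elided)."""

    def _get_params(self, flat_inputs: List[Any], *, generator: Generator) -> Dict[str, Any]:
        return {}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return inpt

    def forward(self, *inputs: Any, generator: Generator = default_generator) -> Any:
        flat_inputs = list(inputs)
        # All randomness is drawn here, from the generator that was passed in.
        params = self._get_params(flat_inputs, generator=generator)
        outputs = [self._transform(inpt, params) for inpt in flat_inputs]
        return outputs[0] if len(outputs) == 1 else tuple(outputs)


class RandomShift(Transform):
    """Hypothetical transform: adds one random offset in [-magnitude, magnitude) to every input."""

    def __init__(self, magnitude: float) -> None:
        super().__init__()
        self.magnitude = magnitude

    def _get_params(self, flat_inputs: List[Any], *, generator: Generator) -> Dict[str, Any]:
        shift = torch.empty(1, device=generator.device).uniform_(
            -self.magnitude, self.magnitude, generator=generator
        )
        return {"shift": float(shift.item())}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return inpt + params["shift"]


t = RandomShift(1.0)
x = torch.zeros(3)

# Passing a local generator: the result depends only on that generator's seed.
a = t(x, generator=torch.Generator().manual_seed(0))
b = t(x, generator=torch.Generator().manual_seed(0))

# Omitting it falls back to the global default generator, as today.
torch.manual_seed(0)
c = t(x)
torch.manual_seed(0)
d = t(x)
```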

Thus, transforms that implement _get_params would replace calls like

angle = float(torch.empty(1).uniform_(0.0, 180.).item())

with

# specifying the device is, unfortunately, necessary: https://github.com/pytorch/pytorch/issues/79018
angle = float(torch.empty(1, device=generator.device).uniform_(0.0, 180., generator=generator).item())
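The replacement pattern can be wrapped in a small helper. The sketch below is illustrative only: `sample_uniform` is a hypothetical name, not part of torchvision. It allocates the scratch tensor on the generator's own device, as the snippet above does, and works equally with a local generator or with `torch.default_generator` passed explicitly.

```python
import torch
from torch import Generator


def sample_uniform(low: float, high: float, generator: Generator) -> float:
    """Draw one float from U[low, high) using only `generator`'s state."""
    # Allocate the scratch tensor on the generator's own device, mirroring the snippet above.
    scratch = torch.empty(1, device=generator.device)
    return float(scratch.uniform_(low, high, generator=generator).item())


# Two generators with the same seed yield the same angle.
angle_a = sample_uniform(0.0, 180.0, torch.Generator().manual_seed(42))
angle_b = sample_uniform(0.0, 180.0, torch.Generator().manual_seed(42))

# The global default generator can still be passed explicitly.
angle_global = sample_uniform(0.0, 180.0, torch.default_generator)
```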

A transform like Compose would also have to be modified. Currently, it accepts a sequence of callables that are each assumed to take a single positional argument. One option is to assume that only instances of Transform involve randomness, so only they are passed the generator. In that case, Compose would look like:

class Compose(Transform):
    # __init__ is unchanged

    def forward(self, *inputs: Any, generator: Generator = default_generator) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self.transforms:
            if isinstance(transform, Transform):
                sample = transform(sample, generator=generator)
            else:
                sample = transform(sample)
        return sample

It would be straightforward to document this behavior (only instances of Transform are passed the generator), so that users know how to opt in to having the generator passed to their custom transforms. This would also remain compatible with the old nn.Module transforms, which would simply be called without the generator.
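A runnable sketch of this opt-in dispatch, assuming torch is installed. `Transform` is a bare stand-in for the proposed base class, and `RandomShift` and `double` are hypothetical; `double` is a plain function, so Compose never hands it the generator.

```python
from typing import Any, Callable, List

import torch
import torch.nn as nn
from torch import Generator, default_generator


class Transform(nn.Module):
    """Minimal stand-in for the proposed base class."""

    def forward(self, *inputs: Any, generator: Generator = default_generator) -> Any:
        raise NotImplementedError


class RandomShift(Transform):
    """Hypothetical stochastic transform: adds a uniform offset drawn from `generator`."""

    def forward(self, x: torch.Tensor, generator: Generator = default_generator) -> torch.Tensor:
        shift = torch.empty(1, device=generator.device).uniform_(-1.0, 1.0, generator=generator)
        return x + shift


class Compose(Transform):
    def __init__(self, transforms: List[Callable[[Any], Any]]) -> None:
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any, generator: Generator = default_generator) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self.transforms:
            if isinstance(transform, Transform):
                # Opted-in transforms receive the generator...
                sample = transform(sample, generator=generator)
            else:
                # ...while plain callables (e.g. legacy transforms) are called as before.
                sample = transform(sample)
        return sample


def double(x: torch.Tensor) -> torch.Tensor:
    # A plain function: it never sees the generator.
    return 2 * x


pipeline = Compose([RandomShift(), double, RandomShift()])
x = torch.zeros(4)
a = pipeline(x, generator=torch.Generator().manual_seed(0))
b = pipeline(x, generator=torch.Generator().manual_seed(0))
```

Re-running the pipeline with an identically seeded generator reproduces the output exactly, even though one stage is an ordinary function.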

An example of this in practice would be:

from torch import Generator

rng = Generator().manual_seed(0)  # manual_seed returns the generator itself

trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)
img, bboxes, labels = trans(img, bboxes, labels, generator=rng)

Another benefit is that specific failure cases that occur during training/testing can be reproduced in isolation: calling _get_params([dummy_img], generator=rng) advances the generator's state, so a sequence of transformations can be "replayed" without having to redo all of the compute. This would not work if the model and the transforms both consumed and advanced the same global RNG state.
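A minimal sketch of such a replay, under two stated assumptions: every `_get_params` call consumes the same amount of randomness regardless of its input, and nothing else draws from the generator. `RandomShift` is hypothetical and subclasses `nn.Module` directly to keep the example short.

```python
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch import Generator, default_generator


class RandomShift(nn.Module):
    """Hypothetical transform following the proposed _get_params/forward split."""

    def _get_params(self, flat_inputs: List[Any], *, generator: Generator) -> Dict[str, Any]:
        # Exactly one draw per call, so generator consumption does not depend on the inputs.
        shift = torch.empty(1, device=generator.device).uniform_(-1.0, 1.0, generator=generator)
        return {"shift": float(shift.item())}

    def forward(self, x: torch.Tensor, generator: Generator = default_generator) -> torch.Tensor:
        params = self._get_params([x], generator=generator)
        return x + params["shift"]


t = RandomShift()

# Original run: every step is transformed (imagine expensive images here).
rng = torch.Generator().manual_seed(0)
images = [torch.full((2, 2), float(i)) for i in range(5)]
outputs = [t(img, generator=rng) for img in images]

# Replay step 3 only: advance a freshly seeded generator with cheap _get_params calls.
replay_rng = torch.Generator().manual_seed(0)
dummy_img = torch.empty(0)
for _ in range(3):
    t._get_params([dummy_img], generator=replay_rng)
replayed = t(images[3], generator=replay_rng)
```

If a transform's number of draws depended on its input (for example, rejection sampling on image size), the dummy input would have to match the real one for the replay to stay aligned.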

Motivation, pitch

In recent years, NumPy has completely revised its PRNG API to avoid global random state (here is a great post on good practices with NumPy's generators). JAX avoids mutable RNG objects altogether. PyTorch provides torch.Generator so that users can make randomness local and "non-spooky", but many libraries do not let users take advantage of this capability.

I am proposing that Transform let users optionally pass a Generator to the forward pass, so that torchvision transform pipelines can be isolated from global random state and thus support more reproducible workflows. This reproducibility is especially useful for testing & evaluation: the specific sequence of data transformations performed should not depend on whether or not a model uses dropout in its forward pass.
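The dropout interaction can be illustrated with a small sketch (assuming torch is installed; `sample_angle` and `evaluation_run` are hypothetical helpers). `nn.Dropout` in training mode draws from the global default generator, so when transforms share that generator, enabling dropout changes which angles later steps receive; with a local generator the angles are unaffected.

```python
from typing import List

import torch
import torch.nn as nn
from torch import Generator


def sample_angle(generator: Generator) -> float:
    # Same sampling pattern as the proposed _get_params.
    scratch = torch.empty(1, device=generator.device)
    return float(scratch.uniform_(0.0, 180.0, generator=generator).item())


def evaluation_run(model_uses_dropout: bool, generator: Generator) -> List[float]:
    """Sample three transform angles, with an optional dropout "forward pass" between them."""
    torch.manual_seed(0)  # global seed shared by the model (and, by default, the transforms)
    dropout = nn.Dropout(p=0.5)  # modules start in training mode, so dropout is active
    angles = []
    for _ in range(3):
        angles.append(sample_angle(generator))
        if model_uses_dropout:
            dropout(torch.ones(8))  # draws from the global default generator
    return angles


# Local generator: the angles do not depend on the model's use of dropout.
with_dropout = evaluation_run(True, torch.Generator().manual_seed(1))
without_dropout = evaluation_run(False, torch.Generator().manual_seed(1))

# Global generator: dropout consumes the same RNG stream, so angles after the first generally differ.
global_with = evaluation_run(True, torch.default_generator)
global_without = evaluation_run(False, torch.default_generator)
```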

Alternatives

No response

Additional context

@pmeier already provided (positive) feedback on this proposal here

cc @vfdev-5 @datumbox @bjuncek @pmeier
