### 🚀 The feature
(This was originally pitched in this long feedback thread; it was recommended that I open a separate issue.)
Enable the new transforms API to support the use of local generators to control RNG via the following modified API:

```python
from typing import Any, Dict, List

import torch.nn as nn
from torch import Generator, Tensor, default_generator


class Transform(nn.Module):
    def _get_params(self, flat_inputs: List[Any], *, generator: Generator) -> Dict[str, Any]:
        ...

    def forward(self, *inputs: Any, generator: Generator = default_generator) -> Any:
        ...
        # the only modification: thread the caller-supplied generator through
        params = self._get_params(flat_inputs, generator=generator)
        ...
```
Thus, transforms that implement `_get_params` would replace calls like

```python
angle = float(torch.empty(1).uniform_(0.0, 180.0).item())
```

with

```python
# specifying the device is, unfortunately, necessary:
# https://github.com/pytorch/pytorch/issues/79018
angle = float(
    torch.empty(1, device=generator.device)
    .uniform_(0.0, 180.0, generator=generator)
    .item()
)
```
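To make the pattern concrete, here is a minimal sketch of a stochastic transform built on the proposed API. `RandomRotate` and its parameterization are hypothetical, and it relies on the `Transform` base class sketched above:

```python
import torch
from typing import Any, Dict, List
from torch import Generator


class RandomRotate(Transform):  # hypothetical transform, for illustration only
    def __init__(self, degrees: float):
        super().__init__()
        self.degrees = degrees

    def _get_params(self, flat_inputs: List[Any], *, generator: Generator) -> Dict[str, Any]:
        # every random draw goes through the caller-supplied generator
        angle = float(
            torch.empty(1, device=generator.device)
            .uniform_(-self.degrees, self.degrees, generator=generator)
            .item()
        )
        return dict(angle=angle)
```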
A transform like `Compose` would have to be modified as well. Currently, it supports a sequence of callables that are assumed to accept a single positional argument. It could be assumed that only instances of `Transform` involve stochasticity and will be passed the random generator. In this case, `Compose` would look like:
```python
class Compose(Transform):
    # __init__ is unchanged

    def forward(self, *inputs: Any, generator: Generator = default_generator) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self.transforms:
            # only Transform instances receive the generator; plain callables
            # (e.g. legacy nn.Module transforms) are invoked as before
            if isinstance(transform, Transform):
                sample = transform(sample, generator=generator)
            else:
                sample = transform(sample)
        return sample
```
It would be straightforward to document this behavior for users – that only instances of `Transform` are passed the generator – so that they know how to opt in to having the generator passed to their custom transforms. And, again, this would remain compatible with the old `nn.Module` transforms.
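As a sketch of that opt-in behavior, a pipeline could freely mix `Transform` instances with plain callables, and only the former would see the generator. The `to_float` helper below and the `RandomRotate` transform from the earlier sketch are illustrative:

```python
import torch


def to_float(img):
    # a plain callable: never receives the generator
    return img.float() / 255.0


pipeline = Compose(
    [
        to_float,            # invoked as to_float(sample)
        RandomRotate(30.0),  # invoked as transform(sample, generator=generator)
    ]
)

# img is assumed to be a uint8 image tensor
out = pipeline(img, generator=torch.Generator().manual_seed(0))
```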
An example of this in practice would be:
```python
from torch import Generator

# assuming the new transforms namespace is imported as `T`,
# e.g. `import torchvision.transforms.v2 as T`
rng = Generator().manual_seed(0)  # manual_seed is an instance method

trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)

img, bboxes, labels = trans(img, bboxes, labels, generator=rng)
```
Another nice thing about this is that specific fail cases that occur during training/testing can be reproduced in an isolated way: `_get_params(dummy_img, generator=rng)` can be used to iterate the generator's state and "replay" a sequence of transformations without having to redo all of the compute. This would not work if the model and the transforms both affected and derived from global state.
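A minimal sketch of that replay workflow, using the hypothetical `RandomRotate` from above (the failing-sample index `N` and the dummy shape are illustrative):

```python
import torch

rng = torch.Generator().manual_seed(0)  # same seed as the original run
transform = RandomRotate(30.0)

dummy_img = torch.zeros(3, 480, 480)  # stand-in input; only its structure matters
N = 1234  # index of the sample whose transformation we want to replay

# advance the generator past the first N samples without any image compute
for _ in range(N):
    transform._get_params([dummy_img], generator=rng)

# this call now reproduces the exact params drawn for sample N
params = transform._get_params([dummy_img], generator=rng)
```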
### Motivation, pitch
In recent years, NumPy has completely revised its PRNG API to avoid global random state (here is a great post on good practices with NumPy's generators). JAX avoids mutable RNG objects altogether. PyTorch provides `torch.Generator` so that users can make randomness local and "non-spooky", but many libraries prevent users from utilizing this capability.
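For instance, a `torch.Generator` already gives fully local, reproducible draws at the tensor level, independent of anything that happens to the global RNG state:

```python
import torch

g1 = torch.Generator().manual_seed(0)
g2 = torch.Generator().manual_seed(0)

x = torch.rand(3, generator=g1)
torch.manual_seed(123)  # mutate the global RNG state in between...
torch.rand(1000)        # ...and consume a bunch of global entropy
y = torch.rand(3, generator=g2)

assert torch.equal(x, y)  # the local generators are unaffected
```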
I am proposing that `Transform` enable users to optionally pass a `Generator` to the forward pass so that torchvision transform pipelines can be isolated from global entropy and thus support more reproducible workflows. This reproducibility is especially useful in the context of testing and evaluation: the specific sequence of data transformations performed should be isolable from, e.g., whether or not a model is using dropout in its forward pass.
### Alternatives
No response
### Additional context
@pmeier already provided (positive) feedback on this proposal here.