
Input checks for interpolation parameter #7192

@pmeier


After #7176, we no longer warn about integer inputs for interpolation or coerce them into our InterpolationMode enum. This means that passing an int will simply fail further down the line:

import torch
from torchvision import transforms

transform = transforms.Resize(size=(32, 64), interpolation=0)

transform(torch.rand(3, 256, 128))
Traceback (most recent call last):
  File "/home/philip/git/pytorch/vision/main.py", line 6, in <module>
    transform(torch.rand(3, 256, 128))
  File "/home/philip/.conda/envs/pytorch-vision-dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philip/git/pytorch/vision/torchvision/transforms/transforms.py", line 336, in forward
    return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
  File "/home/philip/git/pytorch/vision/torchvision/transforms/functional.py", line 430, in resize
    raise TypeError("Argument interpolation should be a InterpolationMode")
TypeError: Argument interpolation should be a InterpolationMode

(Note that it fails in the functional rather than in the transform itself.)
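One way to surface the mistake earlier is to validate in the transform's constructor rather than waiting for the functional call. The following is a minimal sketch under stated assumptions: ResizeSketch is a hypothetical class, not torchvision's Resize, and the InterpolationMode defined here is a small stand-in for a subset of torchvision's enum so the snippet runs without torchvision installed.

```python
from enum import Enum


# Stand-in for a subset of torchvision's InterpolationMode (assumption: only for this sketch).
class InterpolationMode(Enum):
    NEAREST = "nearest"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"


class ResizeSketch:
    """Hypothetical transform that validates `interpolation` when it is constructed."""

    def __init__(self, size, interpolation=InterpolationMode.BILINEAR):
        # Fail here, at construction time, instead of deep inside the functional call.
        if not isinstance(interpolation, InterpolationMode):
            raise TypeError(
                "Argument interpolation should be an InterpolationMode, "
                f"but got {type(interpolation).__name__}"
            )
        self.size = size
        self.interpolation = interpolation


try:
    ResizeSketch(size=(32, 64), interpolation=0)
except TypeError as exc:
    print(exc)  # Argument interpolation should be an InterpolationMode, but got int
```

With this design, the traceback points at the line that built the transform, which is where the user actually passed the bad value.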

However, this only holds for transforms v1. Transforms v2 has no such check:

import torch
from torchvision.prototype import transforms

transform = transforms.Resize(size=(32, 64), interpolation=0)

transform(torch.rand(3, 256, 128))
Traceback (most recent call last):
  File "/home/philip/git/pytorch/vision/main.py", line 6, in <module>
    transform(torch.rand(3, 256, 128))
  File "/home/philip/.conda/envs/pytorch-vision-dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philip/git/pytorch/vision/torchvision/prototype/transforms/_transform.py", line 40, in forward
    flat_outputs = [
  File "/home/philip/git/pytorch/vision/torchvision/prototype/transforms/_transform.py", line 41, in <listcomp>
    self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs
  File "/home/philip/git/pytorch/vision/torchvision/prototype/transforms/_geometry.py", line 69, in _transform
    return F.resize(
  File "/home/philip/git/pytorch/vision/torchvision/prototype/transforms/functional/_geometry.py", line 243, in resize
    return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
  File "/home/philip/git/pytorch/vision/torchvision/prototype/transforms/functional/_geometry.py", line 170, in resize_image_tensor
    mode=interpolation.value,
AttributeError: 'int' object has no attribute 'value'

We won't get silent bugs here, but it is still worth having an expressive error message, especially since ints were allowed before.
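Because ints used to be accepted, the error message could also tell the user which enum member replaces their old value. Here is a minimal sketch, assuming a hypothetical helper check_interpolation and a local stand-in for torchvision's InterpolationMode. The integer-to-member mapping follows PIL's resampling constants (0 = nearest, 1 = lanczos, 2 = bilinear, 3 = bicubic, 4 = box, 5 = hamming).

```python
from enum import Enum


# Stand-in for torchvision's InterpolationMode (assumption: only for this sketch).
class InterpolationMode(Enum):
    NEAREST = "nearest"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
    BOX = "box"
    HAMMING = "hamming"
    LANCZOS = "lanczos"


# PIL's integer resampling codes and the enum members they correspond to.
_PIL_CODES = {
    0: InterpolationMode.NEAREST,
    1: InterpolationMode.LANCZOS,
    2: InterpolationMode.BILINEAR,
    3: InterpolationMode.BICUBIC,
    4: InterpolationMode.BOX,
    5: InterpolationMode.HAMMING,
}


def check_interpolation(interpolation):
    """Hypothetical helper: return `interpolation` unchanged or raise an expressive TypeError."""
    if isinstance(interpolation, InterpolationMode):
        return interpolation
    if isinstance(interpolation, int) and interpolation in _PIL_CODES:
        # Previously accepted input: point the user at the replacement enum member.
        raise TypeError(
            "Argument interpolation should be an InterpolationMode, "
            f"but got the integer {interpolation}. Integer inputs are no longer supported; "
            f"use InterpolationMode.{_PIL_CODES[interpolation].name} instead."
        )
    raise TypeError(
        "Argument interpolation should be an InterpolationMode, "
        f"but got {type(interpolation).__name__}"
    )
```

For the reproducer above, check_interpolation(0) would raise a TypeError suggesting InterpolationMode.NEAREST. Such a check could be called from both the transform constructors and the functionals.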
