Description
🐛 Bug
According to its docstring, ColorJitter accepts an image of type PIL Image or Tensor, but in practice it can only be called on PIL Images: passing a Tensor raises a TypeError.
To Reproduce
Steps to reproduce the behavior:
- Load the data by composing a ToTensor() transformation followed by a ColorJitter() one.
- Create a DataLoader using that dataset.
- Try to loop through the loader.
Code example that reproduces this bug:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
root = 'path/to/cifar/data'
color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
ds_train = CIFAR10(root, download=True, train=True, transform=transforms.Compose([transforms.ToTensor(), color_jitter]))
train_loader = DataLoader(ds_train, batch_size=128, num_workers=4, drop_last=True, shuffle=False)
for (images, label) in train_loader:
    print("There is no Bug!")
Error message:
Traceback (most recent call last):
  File "train.py", line 11, in <module>
    for (images, label) in train_loader:
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 400, in __next__
    data = self._next_data()
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1032, in _next_data
    return self._process_data(data)
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1058, in _process_data
    data.reraise()
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/_utils.py", line 420, in reraise
    raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torchvision/datasets/cifar.py", line 120, in __getitem__
    img = self.transform(img)
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 54, in __call__
    img = t(img)
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 1092, in forward
    img = F.adjust_hue(img, hue_factor)
  File "/path/to/home/dir/.conda/envs/contrastive/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 728, in adjust_hue
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
TypeError: img should be PIL Image. Got <class 'torch.Tensor'>
Expected behavior
The loop should run without errors, since the ColorJitter docstring states that Tensor inputs are accepted.
Environment
PyTorch version: 1.7.0.dev20200807
Is debug build: No
CUDA used to build PyTorch: 10.2
OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti
GPU 3: GeForce RTX 2080 Ti
Nvidia driver version: 440.33.01
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] torch==1.7.0.dev20200807
[pip3] torchfile==0.1.0
[pip3] torchnet==0.0.4
[pip3] torchvision==0.8.0.dev20200807
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 hfd86e86_1
[conda] mkl 2020.1 217
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.1.0 py37h23d657b_0
[conda] mkl_random 1.1.1 py37h0573a6f_0
[conda] numpy 1.19.1 py37hbc911f0_0
[conda] numpy-base 1.19.1 py37hfa32c7d_0
[conda] pytorch 1.7.0.dev20200807 py3.7_cuda10.2.89_cudnn7.6.5_0 pytorch-nightly
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchnet 0.0.4 pypi_0 pypi
[conda] torchvision 0.8.0.dev20200807 py37_cu102 pytorch-nightly
Additional context
Looking at transforms/transforms.py and transforms/functional.py on the master branch of this repo, the docstring of forward at line 1064 (in class ColorJitter) of transforms.py says:
def forward(self, img):
    """
    Args:
        img (PIL Image or Tensor): Input image.
    Returns:
        PIL Image or Tensor: Color jittered image.
    """
This contradicts the fact that ColorJitter was placed in the "Transforms on PIL Image" category in the official documentation at https://pytorch.org/docs/stable/torchvision/transforms.html. Then forward proceeds to call, on lines 1077, 1082, 1087 and 1092, the following functions in this order: F.adjust_brightness(img, brightness_factor), F.adjust_contrast(img, contrast_factor), F.adjust_saturation(img, saturation_factor) and F.adjust_hue(img, hue_factor).
Now, if we look into transforms/functional.py, the first three functions handle both the Tensor and the PIL Image case. For example, at line 675, we have (in the case of adjust_brightness):
if not isinstance(img, torch.Tensor):
    return F_pil.adjust_brightness(img, brightness_factor)
return F_t.adjust_brightness(img, brightness_factor)
However, the last one, adjust_hue, does not seem to have been updated in the same way. Indeed, at line 742, in adjust_hue, we have:
if not isinstance(img, torch.Tensor):
    return F_pil.adjust_hue(img, hue_factor)
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
It would also be nice if other transformations, like RandomGrayscale, could accept both Tensor and PIL Image inputs (with the output type consistent with the input type).
Pull Request
I opened a pull request (#2566) to attempt to address the reported bug.