Skip to content

Commit 47800d4

Browse files
committed
Merge branch 'master' of github.com:pytorch/vision
2 parents 4d7f70b + ad1dac4 commit 47800d4

File tree

8 files changed

+72
-22
lines changed

8 files changed

+72
-22
lines changed

README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ to be atleast 224.
280280
The images have to be loaded in to a range of [0, 1] and then
281281
normalized using `mean=[0.485, 0.456, 0.406]` and `std=[0.229, 0.224, 0.225]`
282282

283-
An example of such normalization can be found in `the imagenet example here` <https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101>
283+
An example of such normalization can be found in the imagenet example `here <https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101>`__
284284

285285
Transforms
286286
==========
@@ -410,7 +410,7 @@ computing the ``(min, max)`` over all images.
410410

411411
``pad_value=<float>`` sets the value for the padded pixels.
412412

413-
`Example usage is given in this notebook` <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>
413+
Example usage is given in this `notebook <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`__
414414

415415
``save_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)``
416416
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
requirements = [
1313
'numpy',
14-
'pillow',
14+
'pillow >= 4.1.1',
1515
'six',
1616
'torch',
1717
]

test/test_transforms.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,15 @@ def test_scale(self):
8686
owidth = random.randint(5, 12) * 2
8787
result = transforms.Compose([
8888
transforms.ToPILImage(),
89-
transforms.Scale((owidth, oheight)),
89+
transforms.Scale((oheight, owidth)),
9090
transforms.ToTensor(),
9191
])(img)
9292
assert result.size(1) == oheight
9393
assert result.size(2) == owidth
9494

9595
result = transforms.Compose([
9696
transforms.ToPILImage(),
97-
transforms.Scale([owidth, oheight]),
97+
transforms.Scale([oheight, owidth]),
9898
transforms.ToTensor(),
9999
])(img)
100100
assert result.size(1) == oheight
@@ -150,7 +150,7 @@ def test_pad_with_tuple_of_pad_values(self):
150150
assert output.size[0] == width + padding[0] + padding[2]
151151
assert output.size[1] == height + padding[1] + padding[3]
152152

153-
def test_pad_raises_with_invalide_pad_sequence_len(self):
153+
def test_pad_raises_with_invalid_pad_sequence_len(self):
154154
with self.assertRaises(ValueError):
155155
transforms.Pad(())
156156

@@ -264,6 +264,22 @@ def test_tensor_gray_to_pil_image(self):
264264
assert np.allclose(img_data_short.numpy(), to_tensor(img_short).numpy())
265265
assert np.allclose(img_data_int.numpy(), to_tensor(img_int).numpy())
266266

267+
def test_tensor_rgba_to_pil_image(self):
268+
trans = transforms.ToPILImage()
269+
to_tensor = transforms.ToTensor()
270+
271+
img_data = torch.Tensor(4, 4, 4).uniform_()
272+
img = trans(img_data)
273+
assert img.mode == 'RGBA'
274+
assert img.getbands() == ('R', 'G', 'B', 'A')
275+
r, g, b, a = img.split()
276+
277+
expected_output = img_data.mul(255).int().float().div(255)
278+
assert np.allclose(expected_output[0].numpy(), to_tensor(r).numpy())
279+
assert np.allclose(expected_output[1].numpy(), to_tensor(g).numpy())
280+
assert np.allclose(expected_output[2].numpy(), to_tensor(b).numpy())
281+
assert np.allclose(expected_output[3].numpy(), to_tensor(a).numpy())
282+
267283
def test_ndarray_to_pil_image(self):
268284
trans = transforms.ToPILImage()
269285
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()

test/test_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import torch
2+
import torchvision.utils as utils
3+
import unittest
4+
5+
6+
class Tester(unittest.TestCase):
7+
8+
def test_make_grid_not_inplace(self):
9+
t = torch.rand(5, 3, 10, 10)
10+
t_clone = t.clone()
11+
12+
utils.make_grid(t, normalize=False)
13+
assert torch.equal(t, t_clone), 'make_grid modified tensor in-place'
14+
15+
utils.make_grid(t, normalize=True, scale_each=False)
16+
assert torch.equal(t, t_clone), 'make_grid modified tensor in-place'
17+
18+
utils.make_grid(t, normalize=True, scale_each=True)
19+
assert torch.equal(t, t_clone), 'make_grid modified tensor in-place'
20+
21+
def test_make_grid_raises_with_variable(self):
22+
t = torch.autograd.Variable(torch.rand(3, 10, 10))
23+
with self.assertRaises(TypeError):
24+
utils.make_grid(t)
25+
26+
with self.assertRaises(TypeError):
27+
utils.make_grid([t, t, t, t])
28+
29+
30+
if __name__ == '__main__':
31+
unittest.main()

torchvision/datasets/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def check_integrity(fpath, md5):
1010
md5o = hashlib.md5()
1111
with open(fpath, 'rb') as f:
1212
# read in 1MB chunks
13-
for chunk in iter(lambda: f.read(1024 * 1024 * 1024), b''):
13+
for chunk in iter(lambda: f.read(1024 * 1024), b''):
1414
md5o.update(chunk)
1515
md5c = md5o.hexdigest()
1616
if md5c != md5:

torchvision/models/densenet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99

1010
model_urls = {
11-
'densenet121': 'https://download.pytorch.org/models/densenet121-241335ed.pth',
12-
'densenet169': 'https://download.pytorch.org/models/densenet169-6f0f7f60.pth',
13-
'densenet201': 'https://download.pytorch.org/models/densenet201-4c113574.pth',
14-
'densenet161': 'https://download.pytorch.org/models/densenet161-17b70270.pth',
11+
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
12+
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
13+
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
14+
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
1515
}
1616

1717

torchvision/transforms.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ def to_pil_image(pic):
109109
mode = 'I'
110110
elif npimg.dtype == np.float32:
111111
mode = 'F'
112+
elif npimg.shape[2] == 4:
113+
if npimg.dtype == np.uint8:
114+
mode = 'RGBA'
112115
else:
113116
if npimg.dtype == np.uint8:
114117
mode = 'RGB'

torchvision/utils.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ def make_grid(tensor, nrow=8, padding=2,
1010
Args:
1111
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
1212
or a list of images all of the same size.
13-
nrow (int, optional): Number of rows in grid. Final grid size is
14-
(B / nrow, nrow). Default is 8.
13+
nrow (int, optional): Number of images displayed in each row of the grid.
14+
The Final grid size is (B / nrow, nrow). Default is 8.
1515
padding (int, optional): amount of padding. Default is 2.
1616
normalize (bool, optional): If True, shift the image to the range (0, 1),
1717
by subtracting the minimum and dividing by the maximum pixel value.
@@ -26,14 +26,13 @@ def make_grid(tensor, nrow=8, padding=2,
2626
See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
2727
2828
"""
29+
if not (torch.is_tensor(tensor) or
30+
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
31+
raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))
32+
2933
# if list of tensors, convert to a 4D mini-batch Tensor
3034
if isinstance(tensor, list):
31-
tensorlist = tensor
32-
numImages = len(tensorlist)
33-
size = torch.Size(torch.Size([numImages]) + tensorlist[0].size())
34-
tensor = tensorlist[0].new(size)
35-
for i in irange(numImages):
36-
tensor[i].copy_(tensorlist[i])
35+
tensor = torch.stack(tensor, dim=0)
3736

3837
if tensor.dim() == 2: # single image H x W
3938
tensor = tensor.view(1, tensor.size(0), tensor.size(1))
@@ -45,6 +44,7 @@ def make_grid(tensor, nrow=8, padding=2,
4544
tensor = torch.cat((tensor, tensor, tensor), 1)
4645

4746
if normalize is True:
47+
tensor = tensor.clone() # avoid modifying tensor in-place
4848
if range is not None:
4949
assert isinstance(range, tuple), \
5050
"range has to be a tuple (min, max) if specified. min and max are numbers"
@@ -70,14 +70,14 @@ def norm_range(t, range):
7070
xmaps = min(nrow, nmaps)
7171
ymaps = int(math.ceil(float(nmaps) / xmaps))
7272
height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
73-
grid = tensor.new(3, height * ymaps + 1 + padding // 2, width * xmaps + 1 + padding // 2).fill_(pad_value)
73+
grid = tensor.new(3, height * ymaps + padding, width * xmaps + padding).fill_(pad_value)
7474
k = 0
7575
for y in irange(ymaps):
7676
for x in irange(xmaps):
7777
if k >= nmaps:
7878
break
79-
grid.narrow(1, y * height + 1 + padding // 2, height - padding)\
80-
.narrow(2, x * width + 1 + padding // 2, width - padding)\
79+
grid.narrow(1, y * height + padding, height - padding)\
80+
.narrow(2, x * width + padding, width - padding)\
8181
.copy_(tensor[k])
8282
k = k + 1
8383
return grid

0 commit comments

Comments
 (0)