Skip to content

Commit 9b19f0f

Browse files
authored
Fix test_draw_boxes (#3631)
* new image * avoid check if pil version is < 8.2 as the reference image would be different
1 parent 37eb37a commit 9b19f0f

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed
0 Bytes
Loading

test/test_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
import unittest
88
from io import BytesIO
99
import torchvision.transforms.functional as F
10-
from PIL import Image
10+
from PIL import Image, __version__ as PILLOW_VERSION
11+
12+
13+
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.'))
1114

1215
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
1316
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
@@ -120,8 +123,11 @@ def test_draw_boxes(self):
120123
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
121124
res.save(path)
122125

123-
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
124-
self.assertTrue(torch.equal(result, expected))
126+
if PILLOW_VERSION >= (8, 2):
127+
# The reference image is only valid for new PIL versions
128+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
129+
self.assertTrue(torch.equal(result, expected))
130+
125131
# Check if modification is not in place
126132
self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item())
127133
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())

0 commit comments

Comments
 (0)