Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
696c15a
Fill color support for tensor affine transforms
voldemortX Oct 27, 2020
229c140
Merge branch 'master' into issue2887
voldemortX Oct 27, 2020
adae0f6
PEP fix
voldemortX Oct 27, 2020
b2721e8
Merge branch 'issue2887' of github.com:voldemortX/vision into issue2887
voldemortX Oct 27, 2020
1c4e48a
Docstring changes and float support
voldemortX Oct 28, 2020
62abb37
Docstring update for transforms and float type cast
voldemortX Oct 28, 2020
a585dbd
Cast only for Tensor
voldemortX Oct 28, 2020
d616210
Temporary patch for lack of Union type support, plus an extra unit test
voldemortX Oct 31, 2020
417f6ea
More plausible bilinear filling for tensors
voldemortX Nov 3, 2020
50d311d
Keep things simple & New docstrings
voldemortX Nov 5, 2020
6b0eb53
Merge branch 'master' into issue2887
voldemortX Nov 30, 2020
5589c14
Fix lint and other issues after merge
voldemortX Nov 30, 2020
731a5a9
make it in one line
voldemortX Nov 30, 2020
4389f80
Merge branch 'master' into issue2887
vfdev-5 Nov 30, 2020
2ea1003
Docstring and some code modifications
voldemortX Nov 30, 2020
4c59964
Merge branch 'issue2887' of github.com:voldemortX/vision into issue2887
voldemortX Nov 30, 2020
9e7cb7a
More tests and corresponding changes for transforms and docstring ch…
voldemortX Dec 1, 2020
16e9b97
Simplify test configs
voldemortX Dec 1, 2020
96c70bc
Update test_functional_tensor.py
vfdev-5 Dec 1, 2020
9d9fd08
Update test_functional_tensor.py
vfdev-5 Dec 2, 2020
87560cb
Merge branch 'master' into issue2887
vfdev-5 Dec 2, 2020
bc7e9fe
Move assertions
voldemortX Dec 2, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 62 additions & 59 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,24 +539,24 @@ def _test_affine_translations(self, tensor, pil_img, scripted_affine):
def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
# 4) Test rotation + translation + scale + shear
test_configs = [
(45, [5, 6], 1.0, [0.0, 0.0]),
(33, (5, -4), 1.0, [0.0, 0.0]),
(45, [-5, 4], 1.2, [0.0, 0.0]),
(33, (-4, -8), 2.0, [0.0, 0.0]),
(85, (10, -10), 0.7, [0.0, 0.0]),
(0, [0, 0], 1.0, [35.0, ]),
(-25, [0, 0], 1.2, [0.0, 15.0]),
(-45, [-10, 0], 0.7, [2.0, 5.0]),
(-45, [-10, -10], 1.2, [4.0, 5.0]),
(-90, [0, 0], 1.0, [0.0, 0.0]),
(45, [5, 6], 1.0, [0.0, 0.0], None),
(33, (5, -4), 1.0, [0.0, 0.0], 0),
(45, [-5, 4], 1.2, [0.0, 0.0], 7),
(33, (-4, -8), 2.0, [0.0, 0.0], 255),
(85, (10, -10), 0.7, [0.0, 0.0], 0),
(0, [0, 0], 1.0, [35.0, ], 0),
(-25, [0, 0], 1.2, [0.0, 15.0], 0),
(-45, [-10, 0], 0.7, [2.0, 5.0], 0),
(-45, [-10, -10], 1.2, [4.0, 5.0], 0),
(-90, [0, 0], 1.0, [0.0, 0.0], 0),
]
for r in [0, ]:
for a, t, s, sh in test_configs:
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
for a, t, s, sh, f in test_configs:
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r, fillcolor=f)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r).cpu()
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r, fillcolor=f).cpu()

if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
Expand All @@ -569,7 +569,7 @@ def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
ratio_diff_pixels,
tol,
msg="{}: {}\n{} vs \n{}".format(
(r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
(r, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
)

Expand Down Expand Up @@ -612,35 +612,36 @@ def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
for a in range(-180, 180, 17):
for e in [True, False]:
for c in centers:

out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu()

if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)

self.assertEqual(
out_tensor.shape,
out_pil_tensor.shape,
msg="{}: {} vs {}".format(
(img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
for f in [None, 7, 0, 255]:

out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c, fill=f)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c, fill=f).cpu()

if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)

self.assertEqual(
out_tensor.shape,
out_pil_tensor.shape,
msg="{}: {} vs {}".format(
(img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
)
)
)
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels
self.assertLess(
ratio_diff_pixels,
0.03,
msg="{}: {}\n{} vs \n{}".format(
(img_size, r, dt, a, e, c),
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels
self.assertLess(
ratio_diff_pixels,
out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7]
0.03,
msg="{}: {}\n{} vs \n{}".format(
(img_size, r, dt, a, e, c, f),
ratio_diff_pixels,
out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7]
)
)
)

def test_rotate(self):
# Tests on square image
Expand Down Expand Up @@ -678,30 +679,32 @@ def test_rotate(self):

def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
dt = tensor.dtype
for r in [0, ]:
for spoints, epoints in test_configs:
out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for f in [None, 0, 7, 255]:
for r in [0, ]:
for spoints, epoints in test_configs:
out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r,
fill=f)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

for fn in [F.perspective, scripted_transform]:
out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()
for fn in [F.perspective, scripted_transform]:
out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r, fill=f).cpu()

if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)

num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 5% of different pixels
self.assertLess(
ratio_diff_pixels,
0.05,
msg="{}: {}\n{} vs \n{}".format(
(r, dt, spoints, epoints),
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 5% of different pixels
self.assertLess(
ratio_diff_pixels,
out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7]
0.05,
msg="{}: {}\n{} vs \n{}".format(
(f, r, dt, spoints, epoints),
ratio_diff_pixels,
out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7]
)
)
)

def test_perspective(self):

Expand Down
29 changes: 23 additions & 6 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,8 +872,8 @@ def _assert_grid_transform_inputs(
if coeffs is not None and len(coeffs) != 8:
raise ValueError("Argument coeffs should have 8 float values")

if fillcolor is not None:
warnings.warn("Argument fill/fillcolor is not supported for Tensor input. Fill value is zero")
if fillcolor is not None and not isinstance(fillcolor, int):
warnings.warn("Argument fill/fillcolor should be an integer")

if resample not in _interpolation_modes:
raise ValueError("Resampling mode '{}' is unsupported with Tensor input".format(resample))
Expand Down Expand Up @@ -905,15 +905,32 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp
return img


def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[int]) -> Tensor:
    """Resample ``img`` on ``grid`` and paint out-of-bounds pixels with ``fill``.

    Args:
        img: Input image tensor. It is passed through ``_cast_squeeze_in`` first and
            is indexed as a 4-D ``(N, C, H, W)`` batch afterwards.
        grid: Sampling grid handed to ``grid_sample``; it is expanded along dim 0
            when the image batch holds more than one image.
        mode: Interpolation mode forwarded to ``grid_sample``.
        fill: Value written to pixels sampled from outside the input image.
            ``None`` is treated as ``0``.

    Returns:
        The resampled image, passed back through ``_cast_squeeze_out``.
    """
    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, grid.dtype)

    if img.shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])

    if fill is None:
        fill = 0

    # Append a dummy all-ones channel for customized fill colors: after sampling with
    # zero padding it marks which output pixels came from inside the input image.
    # This should be faster than calling grid_sample() twice.
    need_mask = False
    if fill != 0:
        need_mask = True
        dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
        img = torch.cat((img, dummy), dim=1)

    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if need_mask:
        # Keep the channel axis (``-1:`` rather than ``-1``) so the mask is (N, 1, H, W)
        # and expands over channels; an (N, H, W) mask would line its batch axis up
        # with the image's channel axis and break for N > 1.
        # Thresholding at 0.5 also covers interpolated border values between 0 and 1.
        mask = img[:, -1:, :, :] < 0.5
        img = img[:, :-1, :, :]
        img[mask.expand_as(img)] = fill

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img

Expand Down Expand Up @@ -974,7 +991,7 @@ def affine(
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
mode = _interpolation_modes[resample]
return _apply_grid_transform(img, grid, mode)
return _apply_grid_transform(img, grid, mode, fill=fillcolor)


def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
Expand Down Expand Up @@ -1045,7 +1062,7 @@ def rotate(
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
mode = _interpolation_modes[resample]

return _apply_grid_transform(img, grid, mode)
return _apply_grid_transform(img, grid, mode, fill=fill)


def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device):
Expand Down Expand Up @@ -1123,7 +1140,7 @@ def perspective(
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
mode = _interpolation_modes[interpolation]

return _apply_grid_transform(img, grid, mode)
return _apply_grid_transform(img, grid, mode, fill=fill)


def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
Expand Down