Skip to content

Commit b8e4e60

Browse files
authored
Merge branch 'main' into update-gaussian-blur
2 parents eb8fba3 + bdc5556 commit b8e4e60

File tree

9 files changed

+134
-52
lines changed

9 files changed

+134
-52
lines changed

test/prototype_transforms_kernel_infos.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ def reference_inputs_resize_image_tensor():
232232
make_image_loaders(extra_dims=[()]),
233233
[
234234
F.InterpolationMode.NEAREST,
235+
F.InterpolationMode.NEAREST_EXACT,
235236
F.InterpolationMode.BILINEAR,
236237
F.InterpolationMode.BICUBIC,
237238
],
@@ -881,6 +882,7 @@ def reference_inputs_resized_crop_image_tensor():
881882
make_image_loaders(extra_dims=[()]),
882883
[
883884
F.InterpolationMode.NEAREST,
885+
F.InterpolationMode.NEAREST_EXACT,
884886
F.InterpolationMode.BILINEAR,
885887
F.InterpolationMode.BICUBIC,
886888
],

test/test_functional_tensor.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@
2525
)
2626
from torchvision.transforms import InterpolationMode
2727

28-
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
28+
NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC = (
29+
InterpolationMode.NEAREST,
30+
InterpolationMode.NEAREST_EXACT,
31+
InterpolationMode.BILINEAR,
32+
InterpolationMode.BICUBIC,
33+
)
2934

3035

3136
@pytest.mark.parametrize("device", cpu_and_gpu())
@@ -506,7 +511,7 @@ def test_perspective_interpolation_warning():
506511
],
507512
)
508513
@pytest.mark.parametrize("max_size", [None, 34, 40, 1000])
509-
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST])
514+
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST, NEAREST_EXACT])
510515
def test_resize(device, dt, size, max_size, interpolation):
511516

512517
if dt == torch.float16 and device == "cpu":
@@ -966,7 +971,7 @@ def test_pad(device, dt, pad, config):
966971

967972

968973
@pytest.mark.parametrize("device", cpu_and_gpu())
969-
@pytest.mark.parametrize("mode", [NEAREST, BILINEAR, BICUBIC])
974+
@pytest.mark.parametrize("mode", [NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC])
970975
def test_resized_crop(device, mode):
971976
# test values of F.resized_crop in several cases:
972977
# 1) resize to the same size, crop to the same size => should be identity

test/test_prototype_transforms.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,3 +1789,41 @@ def test__transform(self, mocker):
17891789
mock_resize.assert_called_with(
17901790
inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
17911791
)
1792+
1793+
1794+
@pytest.mark.parametrize(
1795+
("dtype", "expected_dtypes"),
1796+
[
1797+
(
1798+
torch.float64,
1799+
{torch.Tensor: torch.float64, features.Image: torch.float64, features.BoundingBox: torch.float64},
1800+
),
1801+
(
1802+
{torch.Tensor: torch.int32, features.Image: torch.float32, features.BoundingBox: torch.float64},
1803+
{torch.Tensor: torch.int32, features.Image: torch.float32, features.BoundingBox: torch.float64},
1804+
),
1805+
],
1806+
)
1807+
def test_to_dtype(dtype, expected_dtypes):
1808+
sample = dict(
1809+
plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"),
1810+
image=make_image(dtype=torch.uint8),
1811+
bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY, dtype=torch.float32),
1812+
str="str",
1813+
int=0,
1814+
)
1815+
1816+
transform = transforms.ToDtype(dtype)
1817+
transformed_sample = transform(sample)
1818+
1819+
for key, value in sample.items():
1820+
value_type = type(value)
1821+
transformed_value = transformed_sample[key]
1822+
1823+
# make sure the transformation retains the type
1824+
assert isinstance(transformed_value, value_type)
1825+
1826+
if isinstance(value, torch.Tensor):
1827+
assert transformed_value.dtype is expected_dtypes[value_type]
1828+
else:
1829+
assert transformed_value is value

test/test_prototype_transforms_functional.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,3 +1037,14 @@ def test_to_image_pil(inpt, mode):
10371037
assert isinstance(output, PIL.Image.Image)
10381038

10391039
assert np.asarray(inpt).sum() == np.asarray(output).sum()
1040+
1041+
1042+
def test_equalize_image_tensor_edge_cases():
1043+
inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
1044+
output = F.equalize_image_tensor(inpt)
1045+
torch.testing.assert_close(inpt, output)
1046+
1047+
inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
1048+
inpt[..., 100:, 100:] = 1
1049+
output = F.equalize_image_tensor(inpt)
1050+
assert output.unique().tolist() == [0, 255]

test/test_transforms_tensor.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
from torchvision.transforms import functional as F, InterpolationMode
2121
from torchvision.transforms.autoaugment import _apply_op
2222

23-
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
23+
NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC = (
24+
InterpolationMode.NEAREST,
25+
InterpolationMode.NEAREST_EXACT,
26+
InterpolationMode.BILINEAR,
27+
InterpolationMode.BICUBIC,
28+
)
2429

2530

2631
def _test_transform_vs_scripted(transform, s_transform, tensor, msg=None):
@@ -378,7 +383,7 @@ def test_resize_int(self, size):
378383
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64])
379384
@pytest.mark.parametrize("size", [[32], [32, 32], (32, 32), [34, 35]])
380385
@pytest.mark.parametrize("max_size", [None, 35, 1000])
381-
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST])
386+
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST, NEAREST_EXACT])
382387
def test_resize_scripted(self, dt, size, max_size, interpolation, device):
383388
tensor, _ = _create_data(height=34, width=36, device=device)
384389
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
@@ -402,12 +407,12 @@ def test_resize_save_load(self, tmpdir):
402407
@pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]])
403408
@pytest.mark.parametrize("ratio", [(0.75, 1.333), [0.75, 1.333]])
404409
@pytest.mark.parametrize("size", [(32,), [44], [32], [32, 32], (32, 32), [44, 55]])
405-
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC])
410+
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC, NEAREST_EXACT])
406411
@pytest.mark.parametrize("antialias", [None, True, False])
407412
def test_resized_crop(self, scale, ratio, size, interpolation, antialias, device):
408413

409-
if antialias and interpolation == NEAREST:
410-
pytest.skip("Can not resize if interpolation mode is NEAREST and antialias=True")
414+
if antialias and interpolation in {NEAREST, NEAREST_EXACT}:
415+
pytest.skip(f"Can not resize if interpolation mode is {interpolation} and antialias=True")
411416

412417
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
413418
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

torchvision/prototype/transforms/_misc.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
from collections import defaultdict
23
from typing import Any, Callable, Dict, Sequence, Type, Union
34

45
import PIL.Image
@@ -144,14 +145,22 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
144145
return F.gaussian_blur(inpt, self.kernel_size, **params)
145146

146147

147-
# TODO: Enhance as described at https://github.com/pytorch/vision/issues/6697
148-
class ToDtype(Lambda):
149-
def __init__(self, dtype: torch.dtype, *types: Type) -> None:
148+
class ToDtype(Transform):
149+
_transformed_types = (torch.Tensor,)
150+
151+
def _default_dtype(self, dtype: torch.dtype) -> torch.dtype:
152+
return dtype
153+
154+
def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None:
155+
super().__init__()
156+
if not isinstance(dtype, dict):
157+
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
158+
# If it were possible, we could replace this with `defaultdict(lambda: dtype)`
159+
dtype = defaultdict(functools.partial(self._default_dtype, dtype))
150160
self.dtype = dtype
151-
super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types or (torch.Tensor,))
152161

153-
def extra_repr(self) -> str:
154-
return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"])
162+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
163+
return inpt.to(self.dtype[type(inpt)])
155164

156165

157166
class RemoveSmallBoundingBoxes(Transform):

torchvision/prototype/transforms/functional/_color.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -183,28 +183,37 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
183183
return autocontrast_image_pil(inpt)
184184

185185

186-
def _scale_channel(img_chan: torch.Tensor) -> torch.Tensor:
187-
# TODO: we should expect bincount to always be faster than histc, but this
188-
# isn't always the case. Once
189-
# https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
190-
# block and only use bincount.
191-
if img_chan.is_cuda:
192-
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
193-
else:
194-
hist = torch.bincount(img_chan.view(-1), minlength=256)
195-
196-
nonzero_hist = hist[hist != 0]
197-
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
198-
if step == 0:
199-
return img_chan
200-
201-
lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
202-
# Doing inplace clamp and converting lut to uint8 improves perfs
203-
lut.clamp_(0, 255)
204-
lut = lut.to(torch.uint8)
205-
lut = torch.nn.functional.pad(lut[:-1], [1, 0])
206-
207-
return lut[img_chan.to(torch.int64)]
186+
def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
187+
# input img shape should be [N, H, W]
188+
shape = img.shape
189+
# Compute image histogram:
190+
flat_img = img.flatten(start_dim=1).to(torch.long) # -> [N, H * W]
191+
hist = flat_img.new_zeros(shape[0], 256)
192+
hist.scatter_add_(dim=1, index=flat_img, src=flat_img.new_ones(1).expand_as(flat_img))
193+
194+
# Compute image cdf
195+
chist = hist.cumsum_(dim=1)
196+
# Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255
197+
# Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax()
198+
idx = chist.argmax(dim=1).sub_(1)
199+
# If histogram is degenerate (hist of zero image), index is -1
200+
neg_idx_mask = idx < 0
201+
idx.clamp_(min=0)
202+
step = chist.gather(dim=1, index=idx.unsqueeze(1))
203+
step[neg_idx_mask] = 0
204+
step.div_(255, rounding_mode="floor")
205+
206+
# Compute batched Look-up-table:
207+
# Necessary to avoid an integer division by zero, which raises
208+
clamped_step = step.clamp(min=1)
209+
chist.add_(torch.div(step, 2, rounding_mode="floor")).div_(clamped_step, rounding_mode="floor").clamp_(0, 255)
210+
lut = chist.to(torch.uint8) # [N, 256]
211+
212+
# Pad lut with zeros
213+
zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
214+
lut = torch.cat([zeros, lut[:, :-1]], dim=1)
215+
216+
return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).view_as(img))
208217

209218

210219
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
@@ -217,10 +226,8 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
217226

218227
if image.numel() == 0:
219228
return image
220-
elif image.ndim == 2:
221-
return _scale_channel(image)
222-
else:
223-
return torch.stack([_scale_channel(x) for x in image.view(-1, height, width)]).view(image.shape)
229+
230+
return _equalize_image_tensor_vec(image.view(-1, height, width)).view(image.shape)
224231

225232

226233
equalize_image_pil = _FP.equalize

torchvision/transforms/functional.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020

2121
class InterpolationMode(Enum):
2222
"""Interpolation modes
23-
Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``.
23+
Available interpolation methods are ``nearest``, ``nearest-exact``, ``bilinear``, ``bicubic``, ``box``, ``hamming``,
24+
and ``lanczos``.
2425
"""
2526

2627
NEAREST = "nearest"
28+
NEAREST_EXACT = "nearest-exact"
2729
BILINEAR = "bilinear"
2830
BICUBIC = "bicubic"
2931
# For PIL compatibility
@@ -50,6 +52,7 @@ def _interpolation_modes_from_int(i: int) -> InterpolationMode:
5052
InterpolationMode.NEAREST: 0,
5153
InterpolationMode.BILINEAR: 2,
5254
InterpolationMode.BICUBIC: 3,
55+
InterpolationMode.NEAREST_EXACT: 0,
5356
InterpolationMode.BOX: 4,
5457
InterpolationMode.HAMMING: 5,
5558
InterpolationMode.LANCZOS: 1,
@@ -416,7 +419,8 @@ def resize(
416419
interpolation (InterpolationMode): Desired interpolation enum defined by
417420
:class:`torchvision.transforms.InterpolationMode`.
418421
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
419-
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
422+
``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
423+
supported.
420424
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
421425
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
422426
max_size (int, optional): The maximum allowed for the longer edge of
@@ -617,7 +621,8 @@ def resized_crop(
617621
interpolation (InterpolationMode): Desired interpolation enum defined by
618622
:class:`torchvision.transforms.InterpolationMode`.
619623
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
620-
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
624+
``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
625+
supported.
621626
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
622627
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
623628
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias

torchvision/transforms/transforms.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,8 @@ class Resize(torch.nn.Module):
296296
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
297297
interpolation (InterpolationMode): Desired interpolation enum defined by
298298
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
299-
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
300-
``InterpolationMode.BICUBIC`` are supported.
299+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
300+
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
301301
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
302302
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
303303
max_size (int, optional): The maximum allowed for the longer edge of
@@ -865,8 +865,8 @@ class RandomResizedCrop(torch.nn.Module):
865865
resizing.
866866
interpolation (InterpolationMode): Desired interpolation enum defined by
867867
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
868-
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
869-
``InterpolationMode.BICUBIC`` are supported.
868+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
869+
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
870870
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
871871
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
872872
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
@@ -2133,9 +2133,9 @@ def forward(self, tensor: Tensor) -> Tensor:
21332133
return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)
21342134

21352135
def __repr__(self):
2136-
format_string = self.__class__.__name__ + "(alpha="
2137-
format_string += str(self.alpha) + ")"
2138-
format_string += ", (sigma=" + str(self.sigma) + ")"
2139-
format_string += ", interpolation={self.interpolation}"
2140-
format_string += ", fill={self.fill})"
2136+
format_string = self.__class__.__name__
2137+
format_string += f"(alpha={self.alpha}"
2138+
format_string += f", sigma={self.sigma}"
2139+
format_string += f", interpolation={self.interpolation}"
2140+
format_string += f", fill={self.fill})"
21412141
return format_string

0 commit comments

Comments
 (0)