From c2711148f598f2328366eec58dbaf53582b207df Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Nov 2022 18:43:57 +0000 Subject: [PATCH 1/4] Refactor gaussian_blur --- .../prototype/transforms/functional/_misc.py | 57 +++++++++++-------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 738e369962d..222498e312b 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -89,7 +89,7 @@ def gaussian_blur_image_tensor( # TODO: consider deprecating integers from sigma on the future if isinstance(kernel_size, int): kernel_size = [kernel_size, kernel_size] - if len(kernel_size) != 2: + elif len(kernel_size) != 2: raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}") for ksize in kernel_size: if ksize % 2 == 0 or ksize < 0: @@ -97,15 +97,19 @@ def gaussian_blur_image_tensor( if sigma is None: sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] - - if sigma is not None and not isinstance(sigma, (int, float, list, tuple)): - raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}") - if isinstance(sigma, (int, float)): - sigma = [float(sigma), float(sigma)] - if isinstance(sigma, (list, tuple)) and len(sigma) == 1: - sigma = [sigma[0], sigma[0]] - if len(sigma) != 2: - raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}") + else: + if isinstance(sigma, (int, float)): + s = float(sigma) + sigma = [s, s] + elif isinstance(sigma, (list, tuple)): + length = len(sigma) + if length == 1: + s = float(sigma[0]) + sigma = [s, s] + elif length != 2: + raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}") + else: + raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}") for s in sigma: if s <= 0.0: raise ValueError(f"sigma should have positive values. Got {sigma}") @@ -113,31 +117,34 @@ def gaussian_blur_image_tensor( if image.numel() == 0: return image + dtype = image.dtype shape = image.shape - - if image.ndim > 4: + ndim = len(shape) + if ndim > 4: image = image.reshape((-1,) + shape[-3:]) - needs_unsquash = True - else: - needs_unsquash = False + elif ndim == 3: + image = image.unsqueeze(dim=0) + elif ndim < 3: + raise ValueError(f"Expected tensor to be a tensor image of size (..., C, H, W). Got {image.shape}.") - dtype = image.dtype if torch.is_floating_point(image) else torch.float32 - kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=image.device) - kernel = kernel.expand(image.shape[-3], 1, kernel.shape[0], kernel.shape[1]) - image, need_cast, need_squeeze, out_dtype = _FT._cast_squeeze_in(image, [kernel.dtype]) + fp = torch.is_floating_point(image) + kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device) + kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1]) + + output = image if fp else image.to(torch.float32) # padding = (left, right, top, bottom) padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] - output = torch_pad(image, padding, mode="reflect") + output = torch_pad(output, padding, mode="reflect") output = conv2d(output, kernel, groups=output.shape[-3]) - output = _FT._cast_squeeze_out(output, need_cast, need_squeeze, out_dtype) - - if needs_unsquash: - output = output.reshape(shape) + if output.dtype != dtype: + if not fp: + output.round_() + output = output.to(dtype) - return output + return output.reshape(shape) @torch.jit.unused From 48f40dc0033b748c05dfcdf0e864b7208951035c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Nov 2022 18:50:50 +0000 Subject: [PATCH 2/4] Add conditional reshape --- torchvision/prototype/transforms/functional/_misc.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 222498e312b..7b5d8825b74 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -144,7 +144,12 @@ def gaussian_blur_image_tensor( output.round_() output = output.to(dtype) - return output.reshape(shape) + if ndim > 4: + output = output.reshape(shape) + elif ndim == 3: + output = output.squeeze(dim=0) + + return output @torch.jit.unused From 826167d3877dcfa3435a82309f0be5f7c00132d7 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 2 Nov 2022 12:34:24 +0000 Subject: [PATCH 3/4] Further refactoring --- .../prototype/transforms/functional/_misc.py | 43 ++++++++----------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 7b5d8825b74..04b8959190f 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -68,9 +68,9 @@ def normalize( def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor: - lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma) + lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma) x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device) - kernel1d = torch.softmax(-x.pow_(2), dim=0) + kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0) return kernel1d @@ -98,16 +98,16 @@ def gaussian_blur_image_tensor( if sigma is None: sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] else: - if isinstance(sigma, (int, float)): - s = float(sigma) - sigma = [s, s] - elif isinstance(sigma, (list, tuple)): + if isinstance(sigma, (list, tuple)): length = len(sigma) if length == 1: s = float(sigma[0]) sigma = [s, s] elif length != 2: - raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}") + raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}") + elif isinstance(sigma, (int, float)): + s = float(sigma) + sigma = [s, s] else: raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}") for s in sigma: @@ -119,35 +119,30 @@ def gaussian_blur_image_tensor( dtype = image.dtype shape = image.shape - ndim = len(shape) - if ndim > 4: - image = image.reshape((-1,) + shape[-3:]) - elif ndim == 3: + ndim = image.ndim + if ndim == 3: image = image.unsqueeze(dim=0) - elif ndim < 3: - raise ValueError(f"Expected tensor to be a tensor image of size (..., C, H, W). Got {image.shape}.") - + elif ndim > 4: + image = image.reshape((-1,) + shape[-3:]) fp = torch.is_floating_point(image) kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device) kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1]) - output = image if fp else image.to(torch.float32) + output = image if fp else image.to(dtype=torch.float32) # padding = (left, right, top, bottom) padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] output = torch_pad(output, padding, mode="reflect") - output = conv2d(output, kernel, groups=output.shape[-3]) - - if output.dtype != dtype: - if not fp: - output.round_() - output = output.to(dtype) + output = conv2d(output, kernel, groups=shape[-3]) - if ndim > 4: - output = output.reshape(shape) - elif ndim == 3: + if ndim == 3: output = output.squeeze(dim=0) + elif ndim > 4: + output = output.reshape(shape) + + if not fp: + output = output.round_().to(dtype=dtype) return output From cdcd9c750b9ef1608dbfbf9aa6a6b995038521bf Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 2 Nov 2022 13:32:10 +0000 Subject: [PATCH 4/4] Remove unused import. --- torchvision/prototype/transforms/functional/_misc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 04b8959190f..d8bfc7cae1b 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -5,7 +5,6 @@ import torch from torch.nn.functional import conv2d, pad as torch_pad from torchvision.prototype import features -from torchvision.transforms import functional_tensor as _FT from torchvision.transforms.functional import pil_to_tensor, to_pil_image