Skip to content

Commit 7a7ab7e

Browse files
authored
[prototype] Speed up adjust_sharpness_image_tensor (#6930)
* Speed up `adjust_sharpness_image_tensor` * Add a comment
1 parent bf58902 commit 7a7ab7e

File tree

1 file changed

+27
-1
lines changed
  • torchvision/prototype/transforms/functional

1 file changed

+27
-1
lines changed

torchvision/prototype/transforms/functional/_color.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from torch.nn.functional import conv2d
23
from torchvision.prototype import features
34
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
45

@@ -111,6 +112,8 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
111112
if image.numel() == 0 or height <= 2 or width <= 2:
112113
return image
113114

115+
bound = _FT._max_value(image.dtype)
116+
fp = image.is_floating_point()
114117
shape = image.shape
115118

116119
if image.ndim > 4:
@@ -119,7 +122,30 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
119122
else:
120123
needs_unsquash = False
121124

122-
output = _blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)
125+
# The following is a normalized 3x3 kernel with 1s in the edges and a 5 in the middle.
126+
kernel_dtype = image.dtype if fp else torch.float32
127+
a, b = 1.0 / 13.0, 5.0 / 13.0
128+
kernel = torch.tensor([[a, a, a], [a, b, a], [a, a, a]], dtype=kernel_dtype, device=image.device)
129+
kernel = kernel.expand(num_channels, 1, 3, 3)
130+
131+
# We copy and cast at the same time to avoid modifications on the original data
132+
output = image.to(dtype=kernel_dtype, copy=True)
133+
blurred_degenerate = conv2d(output, kernel, groups=num_channels)
134+
if not fp:
135+
# it is better to round before cast
136+
blurred_degenerate = blurred_degenerate.round_()
137+
138+
# Create a view on the underlying output while pointing at the same data. We do this to avoid indexing twice.
139+
view = output[..., 1:-1, 1:-1]
140+
141+
# We speed up blending by minimizing FLOPs and operating in-place. The two blend formulations are mathematically equivalent:
142+
# x+(1-r)*(y-x) = x + (1-r)*y - (1-r)*x = x*r + y*(1-r)
143+
view.add_(blurred_degenerate.sub_(view), alpha=(1.0 - sharpness_factor))
144+
145+
# The actual data of output has been modified by the above. We only need to clamp and cast now.
146+
output = output.clamp_(0, bound)
147+
if not fp:
148+
output = output.to(image.dtype)
123149

124150
if needs_unsquash:
125151
output = output.reshape(shape)

0 commit comments

Comments
 (0)