1
1
import torch
2
+ from torch .nn .functional import conv2d
2
3
from torchvision .prototype import features
3
4
from torchvision .transforms import functional_pil as _FP , functional_tensor as _FT
4
5
@@ -111,6 +112,8 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
111
112
if image .numel () == 0 or height <= 2 or width <= 2 :
112
113
return image
113
114
115
+ bound = _FT ._max_value (image .dtype )
116
+ fp = image .is_floating_point ()
114
117
shape = image .shape
115
118
116
119
if image .ndim > 4 :
@@ -119,7 +122,30 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
119
122
else :
120
123
needs_unsquash = False
121
124
122
- output = _blend (image , _FT ._blurred_degenerate_image (image ), sharpness_factor )
125
+ # The following is a normalized 3x3 kernel with 1s in the edges and a 5 in the middle.
126
+ kernel_dtype = image .dtype if fp else torch .float32
127
+ a , b = 1.0 / 13.0 , 5.0 / 13.0
128
+ kernel = torch .tensor ([[a , a , a ], [a , b , a ], [a , a , a ]], dtype = kernel_dtype , device = image .device )
129
+ kernel = kernel .expand (num_channels , 1 , 3 , 3 )
130
+
131
+ # We copy and cast at the same time to avoid modifications on the original data
132
+ output = image .to (dtype = kernel_dtype , copy = True )
133
+ blurred_degenerate = conv2d (output , kernel , groups = num_channels )
134
+ if not fp :
135
+ # it is better to round before cast
136
+ blurred_degenerate = blurred_degenerate .round_ ()
137
+
138
+ # Create a view on the underlying output while pointing at the same data. We do this to avoid indexing twice.
139
+ view = output [..., 1 :- 1 , 1 :- 1 ]
140
+
141
+ # We speed up blending by minimizing flops and doing in-place. The 2 blend options are mathematically equivalent:
142
+ # x+(1-r)*(y-x) = x + (1-r)*y - (1-r)*x = x*r + y*(1-r)
143
+ view .add_ (blurred_degenerate .sub_ (view ), alpha = (1.0 - sharpness_factor ))
144
+
145
+ # The actual data of ouput have been modified by the above. We only need to clamp and cast now.
146
+ output = output .clamp_ (0 , bound )
147
+ if not fp :
148
+ output = output .to (image .dtype )
123
149
124
150
if needs_unsquash :
125
151
output = output .reshape (shape )
0 commit comments