@@ -142,6 +142,8 @@ class FFTImage(ImageParameterization):
142
142
Parameterize an image using inverse real 2D FFT
143
143
"""
144
144
145
+ __constants__ = ["size" , "_supports_is_scripting" ]
146
+
145
147
def __init__ (
146
148
self ,
147
149
size : Tuple [int , int ] = None ,
@@ -201,6 +203,9 @@ def __init__(
201
203
self .register_buffer ("spectrum_scale" , spectrum_scale )
202
204
self .fourier_coeffs = nn .Parameter (fourier_coeffs )
203
205
206
+ # Check & store whether or not we can use torch.jit.is_scripting()
207
+ self ._supports_is_scripting = torch .__version__ >= "1.6.0"
208
+
204
209
def rfft2d_freqs (self , height : int , width : int ) -> torch .Tensor :
205
210
"""
206
211
Computes 2D spectrum frequencies.
@@ -218,6 +223,12 @@ def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor:
218
223
fx = self .torch_fftfreq (width )[: width // 2 + 1 ]
219
224
return torch .sqrt ((fx * fx ) + (fy * fy ))
220
225
226
+ @torch .jit .export
227
+ def torch_irfftn (self , x : torch .Tensor ) -> torch .Tensor :
228
+ if x .dtype != torch .complex64 :
229
+ x = torch .view_as_complex (x )
230
+ return torch .fft .irfftn (x , s = self .size ) # type: ignore
231
+
221
232
def get_fft_funcs (self ) -> Tuple [Callable , Callable , Callable ]:
222
233
"""
223
234
Support older versions of PyTorch. This function ensures that the same FFT
@@ -230,26 +241,24 @@ def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]:
230
241
"""
231
242
232
243
if TORCH_VERSION >= "1.7.0" :
233
- import torch .fft
244
+ if TORCH_VERSION < "1.8.0" :
245
+ global torch
246
+ import torch .fft
234
247
235
248
def torch_rfft (x : torch .Tensor ) -> torch .Tensor :
236
249
return torch .view_as_real (torch .fft .rfftn (x , s = self .size ))
237
250
238
- def torch_irfft (x : torch .Tensor ) -> torch .Tensor :
239
- if type (x ) is not torch .complex64 :
240
- x = torch .view_as_complex (x )
241
- return torch .fft .irfftn (x , s = self .size ) # type: ignore
251
+ torch_irfftn = self .torch_irfftn
242
252
243
253
def torch_fftfreq (v : int , d : float = 1.0 ) -> torch .Tensor :
244
254
return torch .fft .fftfreq (v , d )
245
255
246
256
else :
247
- import torch
248
257
249
258
def torch_rfft (x : torch .Tensor ) -> torch .Tensor :
250
259
return torch .rfft (x , signal_ndim = 2 )
251
260
252
- def torch_irfft (x : torch .Tensor ) -> torch .Tensor :
261
+ def torch_irfftn (x : torch .Tensor ) -> torch .Tensor :
253
262
return torch .irfft (x , signal_ndim = 2 )[
254
263
:, :, : self .size [0 ], : self .size [1 ]
255
264
]
@@ -262,7 +271,7 @@ def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor:
262
271
results [s :] = torch .arange (- (v // 2 ), 0 )
263
272
return results * (1.0 / (v * d ))
264
273
265
- return torch_rfft , torch_irfft , torch_fftfreq
274
+ return torch_rfft , torch_irfftn , torch_fftfreq
266
275
267
276
def forward (self ) -> torch .Tensor :
268
277
"""
@@ -272,6 +281,9 @@ def forward(self) -> torch.Tensor:
272
281
273
282
scaled_spectrum = self .fourier_coeffs * self .spectrum_scale
274
283
output = self .torch_irfft (scaled_spectrum )
284
+ if self ._supports_is_scripting :
285
+ if torch .jit .is_scripting ():
286
+ return output
275
287
return output .refine_names ("B" , "C" , "H" , "W" )
276
288
277
289
@@ -280,6 +292,8 @@ class PixelImage(ImageParameterization):
280
292
Parameterize a simple pixel image tensor that requires no additional transforms.
281
293
"""
282
294
295
+ __constants__ = ["_supports_is_scripting" ]
296
+
283
297
def __init__ (
284
298
self ,
285
299
size : Tuple [int , int ] = None ,
@@ -311,7 +325,13 @@ def __init__(
311
325
init = init .unsqueeze (0 )
312
326
self .image = nn .Parameter (init )
313
327
328
+ # Check & store whether or not we can use torch.jit.is_scripting()
329
+ self ._supports_is_scripting = torch .__version__ >= "1.6.0"
330
+
314
331
def forward (self ) -> torch .Tensor :
332
+ if self ._supports_is_scripting :
333
+ if torch .jit .is_scripting ():
334
+ return self .image
315
335
return self .image .refine_names ("B" , "C" , "H" , "W" )
316
336
317
337
@@ -789,7 +809,7 @@ def __init__(
789
809
nn.Parameter tensor, or stacking init images.
790
810
Default: 1
791
811
parameterization (ImageParameterization, optional): An image
792
- parameterization class.
812
+ parameterization class, or instance of an image parameterization class .
793
813
Default: FFTImage
794
814
squash_func (Callable[[torch.Tensor], torch.Tensor]], optional): The squash
795
815
function to use after color recorrelation. A funtion or lambda function.
@@ -801,8 +821,14 @@ def __init__(
801
821
Default: True
802
822
"""
803
823
super ().__init__ ()
824
+ if not isinstance (parameterization , ImageParameterization ):
825
+ # Verify uninitialized class is correct type
826
+ assert issubclass (parameterization , ImageParameterization )
827
+ else :
828
+ assert isinstance (parameterization , ImageParameterization )
829
+
804
830
self .decorrelate = decorrelation_module
805
- if init is not None :
831
+ if init is not None and not isinstance ( parameterization , ImageParameterization ) :
806
832
assert init .dim () == 3 or init .dim () == 4
807
833
if decorrelate_init and self .decorrelate is not None :
808
834
init = (
@@ -811,27 +837,42 @@ def __init__(
811
837
else init .refine_names ("C" , "H" , "W" )
812
838
)
813
839
init = self .decorrelate (init , inverse = True ).rename (None )
840
+
814
841
if squash_func is None :
842
+ squash_func = self ._clamp_image
815
843
816
- def squash_func (x : torch .Tensor ) -> torch .Tensor :
817
- return x .clamp (0 , 1 )
844
+ self .squash_func = torch .sigmoid if squash_func is None else squash_func
845
+ if not isinstance (parameterization , ImageParameterization ):
846
+ parameterization = parameterization (
847
+ size = size , channels = channels , batch = batch , init = init
848
+ )
849
+ self .parameterization = parameterization
818
850
819
- else :
820
- if squash_func is None :
851
+ @torch .jit .export
852
+ def _clamp_image (self , x : torch .Tensor ) -> torch .Tensor :
853
+ """JIT supported squash function."""
854
+ return x .clamp (0 , 1 )
855
+
856
+ @torch .jit .ignore
857
+ def _to_image_tensor (self , x : torch .Tensor ) -> torch .Tensor :
858
+ """
859
+ Wrap ImageTensor in torch.jit.ignore for JIT support.
821
860
822
- squash_func = torch . sigmoid
861
+ Args:
823
862
824
- self .squash_func = squash_func
825
- self .parameterization = parameterization (
826
- size = size , channels = channels , batch = batch , init = init
827
- )
863
+ x (torch.tensor): An input tensor.
864
+
865
+ Returns:
866
+ x (ImageTensor): An instance of ImageTensor with the input tensor.
867
+ """
868
+ return ImageTensor (x )
828
869
829
870
def forward (self ) -> torch .Tensor :
830
871
image = self .parameterization ()
831
872
if self .decorrelate is not None :
832
873
image = self .decorrelate (image )
833
874
image = image .rename (None ) # TODO: the world is not yet ready
834
- return ImageTensor (self .squash_func (image ))
875
+ return self . _to_image_tensor (self .squash_func (image ))
835
876
836
877
837
878
__all__ = [
0 commit comments