Skip to content

Commit 3d442ec

Browse files
authored
Merge branch 'optim-wip-jit-transforms-support-rotate' into optim-wip-jit-images-support
2 parents 9b7ce93 + 5d3f609 commit 3d442ec

File tree

7 files changed

+2176
-116
lines changed

7 files changed

+2176
-116
lines changed

captum/optim/_param/image/images.py

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ class FFTImage(ImageParameterization):
142142
Parameterize an image using inverse real 2D FFT
143143
"""
144144

145+
__constants__ = ["size", "_supports_is_scripting"]
146+
145147
def __init__(
146148
self,
147149
size: Tuple[int, int] = None,
@@ -201,6 +203,9 @@ def __init__(
201203
self.register_buffer("spectrum_scale", spectrum_scale)
202204
self.fourier_coeffs = nn.Parameter(fourier_coeffs)
203205

206+
# Check & store whether or not we can use torch.jit.is_scripting()
207+
self._supports_is_scripting = torch.__version__ >= "1.6.0"
208+
204209
def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor:
205210
"""
206211
Computes 2D spectrum frequencies.
@@ -218,6 +223,12 @@ def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor:
218223
fx = self.torch_fftfreq(width)[: width // 2 + 1]
219224
return torch.sqrt((fx * fx) + (fy * fy))
220225

226+
@torch.jit.export
227+
def torch_irfftn(self, x: torch.Tensor) -> torch.Tensor:
228+
if x.dtype != torch.complex64:
229+
x = torch.view_as_complex(x)
230+
return torch.fft.irfftn(x, s=self.size) # type: ignore
231+
221232
def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]:
222233
"""
223234
Support older versions of PyTorch. This function ensures that the same FFT
@@ -230,26 +241,24 @@ def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]:
230241
"""
231242

232243
if TORCH_VERSION >= "1.7.0":
233-
import torch.fft
244+
if TORCH_VERSION < "1.8.0":
245+
global torch
246+
import torch.fft
234247

235248
def torch_rfft(x: torch.Tensor) -> torch.Tensor:
236249
return torch.view_as_real(torch.fft.rfftn(x, s=self.size))
237250

238-
def torch_irfft(x: torch.Tensor) -> torch.Tensor:
239-
if type(x) is not torch.complex64:
240-
x = torch.view_as_complex(x)
241-
return torch.fft.irfftn(x, s=self.size) # type: ignore
251+
torch_irfftn = self.torch_irfftn
242252

243253
def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor:
244254
return torch.fft.fftfreq(v, d)
245255

246256
else:
247-
import torch
248257

249258
def torch_rfft(x: torch.Tensor) -> torch.Tensor:
250259
return torch.rfft(x, signal_ndim=2)
251260

252-
def torch_irfft(x: torch.Tensor) -> torch.Tensor:
261+
def torch_irfftn(x: torch.Tensor) -> torch.Tensor:
253262
return torch.irfft(x, signal_ndim=2)[
254263
:, :, : self.size[0], : self.size[1]
255264
]
@@ -262,7 +271,7 @@ def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor:
262271
results[s:] = torch.arange(-(v // 2), 0)
263272
return results * (1.0 / (v * d))
264273

265-
return torch_rfft, torch_irfft, torch_fftfreq
274+
return torch_rfft, torch_irfftn, torch_fftfreq
266275

267276
def forward(self) -> torch.Tensor:
268277
"""
@@ -272,6 +281,9 @@ def forward(self) -> torch.Tensor:
272281

273282
scaled_spectrum = self.fourier_coeffs * self.spectrum_scale
274283
output = self.torch_irfft(scaled_spectrum)
284+
if self._supports_is_scripting:
285+
if torch.jit.is_scripting():
286+
return output
275287
return output.refine_names("B", "C", "H", "W")
276288

277289

@@ -280,6 +292,8 @@ class PixelImage(ImageParameterization):
280292
Parameterize a simple pixel image tensor that requires no additional transforms.
281293
"""
282294

295+
__constants__ = ["_supports_is_scripting"]
296+
283297
def __init__(
284298
self,
285299
size: Tuple[int, int] = None,
@@ -311,7 +325,13 @@ def __init__(
311325
init = init.unsqueeze(0)
312326
self.image = nn.Parameter(init)
313327

328+
# Check & store whether or not we can use torch.jit.is_scripting()
329+
self._supports_is_scripting = torch.__version__ >= "1.6.0"
330+
314331
def forward(self) -> torch.Tensor:
332+
if self._supports_is_scripting:
333+
if torch.jit.is_scripting():
334+
return self.image
315335
return self.image.refine_names("B", "C", "H", "W")
316336

317337

@@ -789,7 +809,7 @@ def __init__(
789809
nn.Parameter tensor, or stacking init images.
790810
Default: 1
791811
parameterization (ImageParameterization, optional): An image
792-
parameterization class.
812+
parameterization class, or an instance of an image parameterization class.
793813
Default: FFTImage
794814
squash_func (Callable[[torch.Tensor], torch.Tensor]], optional): The squash
795815
function to use after color recorrelation. A function or lambda function.
@@ -801,8 +821,14 @@ def __init__(
801821
Default: True
802822
"""
803823
super().__init__()
824+
if not isinstance(parameterization, ImageParameterization):
825+
# Verify uninitialized class is correct type
826+
assert issubclass(parameterization, ImageParameterization)
827+
else:
828+
assert isinstance(parameterization, ImageParameterization)
829+
804830
self.decorrelate = decorrelation_module
805-
if init is not None:
831+
if init is not None and not isinstance(parameterization, ImageParameterization):
806832
assert init.dim() == 3 or init.dim() == 4
807833
if decorrelate_init and self.decorrelate is not None:
808834
init = (
@@ -811,27 +837,42 @@ def __init__(
811837
else init.refine_names("C", "H", "W")
812838
)
813839
init = self.decorrelate(init, inverse=True).rename(None)
840+
814841
if squash_func is None:
842+
squash_func = self._clamp_image
815843

816-
def squash_func(x: torch.Tensor) -> torch.Tensor:
817-
return x.clamp(0, 1)
844+
self.squash_func = torch.sigmoid if squash_func is None else squash_func
845+
if not isinstance(parameterization, ImageParameterization):
846+
parameterization = parameterization(
847+
size=size, channels=channels, batch=batch, init=init
848+
)
849+
self.parameterization = parameterization
818850

819-
else:
820-
if squash_func is None:
851+
@torch.jit.export
852+
def _clamp_image(self, x: torch.Tensor) -> torch.Tensor:
853+
"""JIT supported squash function."""
854+
return x.clamp(0, 1)
855+
856+
@torch.jit.ignore
857+
def _to_image_tensor(self, x: torch.Tensor) -> torch.Tensor:
858+
"""
859+
Wrap ImageTensor in torch.jit.ignore for JIT support.
821860
822-
squash_func = torch.sigmoid
861+
Args:
823862
824-
self.squash_func = squash_func
825-
self.parameterization = parameterization(
826-
size=size, channels=channels, batch=batch, init=init
827-
)
863+
x (torch.Tensor): An input tensor.
864+
865+
Returns:
866+
x (ImageTensor): An instance of ImageTensor wrapping the input tensor.
867+
"""
868+
return ImageTensor(x)
828869

829870
def forward(self) -> torch.Tensor:
830871
image = self.parameterization()
831872
if self.decorrelate is not None:
832873
image = self.decorrelate(image)
833874
image = image.rename(None) # TODO: the world is not yet ready
834-
return ImageTensor(self.squash_func(image))
875+
return self._to_image_tensor(self.squash_func(image))
835876

836877

837878
__all__ = [

0 commit comments

Comments
 (0)