Skip to content

Optim-wip: Add new StackImage parameterization & JIT support for SharedImage #833

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 17, 2022
192 changes: 178 additions & 14 deletions captum/optim/_param/image/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,38 @@ def forward(self) -> torch.Tensor:
return torch.stack(A).refine_names("B", "C", "H", "W")


class SimpleTensorParameterization(ImageParameterization):
    """
    Parameterize a simple tensor with or without it requiring grad.
    Compared to PixelImage, this parameterization has no specific shape requirements
    and does not wrap inputs in nn.Parameter.

    This parameterization can for example be combined with StackImage for batch
    dimensions that both require and don't require gradients.

    This parameterization can also be combined with nn.ModuleList as a workaround
    for TorchScript / JIT not supporting nn.ParameterList. SharedImage uses this
    module internally for this purpose.
    """

    def __init__(self, tensor: Optional[torch.Tensor] = None) -> None:
        """
        Args:

            tensor (torch.Tensor): The tensor to return every time this module is
                called. Despite the ``None`` default, a ``torch.Tensor`` instance
                is required.
        """
        super().__init__()
        assert isinstance(tensor, torch.Tensor)
        # Stored as a plain attribute rather than an nn.Parameter, so whether it
        # requires grad is determined entirely by the tensor that was passed in.
        self.tensor = tensor

    def forward(self) -> torch.Tensor:
        """
        Returns:
            tensor (torch.Tensor): The tensor stored during initialization.
        """
        return self.tensor


class SharedImage(ImageParameterization):
"""
Share some image parameters across the batch to increase spatial alignment,
Expand All @@ -429,6 +461,8 @@ class SharedImage(ImageParameterization):
https://distill.pub/2018/differentiable-parameterizations/
"""

__constants__ = ["offset"]

def __init__(
self,
shapes: Union[Tuple[Tuple[int]], Tuple[int]] = None,
Expand All @@ -454,8 +488,11 @@ def __init__(
assert len(shape) >= 2 and len(shape) <= 4
shape = ([1] * (4 - len(shape))) + list(shape)
batch, channels, height, width = shape
A.append(torch.nn.Parameter(torch.randn([batch, channels, height, width])))
self.shared_init = torch.nn.ParameterList(A)
shape_param = torch.nn.Parameter(
torch.randn([batch, channels, height, width])
)
A.append(SimpleTensorParameterization(shape_param))
self.shared_init = torch.nn.ModuleList(A)
self.parameterization = parameterization
self.offset = self._get_offset(offset, len(A)) if offset is not None else None

Expand Down Expand Up @@ -484,6 +521,7 @@ def _get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]]
assert all([all([type(o) is int for o in v]) for v in offset])
return offset

@torch.jit.ignore
def _apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Apply list of offsets to list of tensors.
Expand Down Expand Up @@ -517,6 +555,63 @@ def _apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
A.append(x)
return A

def _interpolate_bilinear(
self,
x: torch.Tensor,
size: Tuple[int, int],
) -> torch.Tensor:
"""
Perform interpolation without any warnings.

Args:

x (torch.Tensor): The NCHW tensor to resize.
size (tuple of int): The desired output size to resize the input
to, with a format of: [height, width].

Returns:
x (torch.Tensor): A resized NCHW tensor.
"""
assert x.dim() == 4
assert len(size) == 2

x = F.interpolate(
x,
size=size,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
return x

def _interpolate_trilinear(
self,
x: torch.Tensor,
size: Tuple[int, int, int],
) -> torch.Tensor:
"""
Perform interpolation without any warnings.

Args:

x (torch.Tensor): The NCHW tensor to resize.
size (tuple of int): The desired output size to resize the input
to, with a format of: [channels, height, width].

Returns:
x (torch.Tensor): A resized NCHW tensor.
"""
x = x.unsqueeze(0)
assert x.dim() == 5
x = F.interpolate(
x,
size=size,
mode="trilinear",
align_corners=False,
recompute_scale_factor=False,
)
return x.squeeze(0)

def _interpolate_tensor(
self, x: torch.Tensor, batch: int, channels: int, height: int, width: int
) -> torch.Tensor:
Expand All @@ -537,29 +632,26 @@ def _interpolate_tensor(
"""

if x.size(1) == channels:
mode = "bilinear"
size = (height, width)
x = self._interpolate_bilinear(x, size=size)
else:
mode = "trilinear"
x = x.unsqueeze(0)
size = (channels, height, width)
x = F.interpolate(x, size=size, mode=mode)
x = x.squeeze(0) if len(size) == 3 else x
x = self._interpolate_trilinear(x, size=size)
if x.size(0) != batch:
x = x.permute(1, 0, 2, 3)
x = F.interpolate(
x.unsqueeze(0),
size=(batch, x.size(2), x.size(3)),
mode="trilinear",
).squeeze(0)
x = self._interpolate_trilinear(x, size=(batch, x.size(2), x.size(3)))
x = x.permute(1, 0, 2, 3)
return x

def forward(self) -> torch.Tensor:
"""
Returns:
output (torch.Tensor): An NCHW image parameterization output.
"""
image = self.parameterization()
x = [
self._interpolate_tensor(
shared_tensor,
shared_tensor(),
image.size(0),
image.size(1),
image.size(2),
Expand All @@ -569,7 +661,78 @@ def forward(self) -> torch.Tensor:
]
if self.offset is not None:
x = self._apply_offset(x)
return (image + sum(x)).refine_names("B", "C", "H", "W")
output = image + torch.cat(x, 0).sum(0, keepdim=True)

if torch.jit.is_scripting():
return output
return output.refine_names("B", "C", "H", "W")


class StackImage(ImageParameterization):
    """
    Stack multiple NCHW image parameterizations along their batch dimensions.
    """

    __constants__ = ["dim", "output_device"]

    def __init__(
        self,
        parameterizations: List[Union[ImageParameterization, torch.Tensor]],
        dim: int = 0,
        output_device: Optional[torch.device] = None,
    ) -> None:
        """
        Args:

            parameterizations (list of ImageParameterization and torch.Tensor): A list
                of image parameterizations to stack across their batch dimensions.
            dim (int, optional): Optionally specify the dim to concatenate
                parameterization outputs on. Default is set to the batch dimension.
                Default: 0
            output_device (torch.device, optional): If the parameterizations are on
                different devices, then their outputs will be moved to the device
                specified by this variable. Default is set to None with the expectation
                that all parameterizations are on the same device.
                Default: None
        """
        super().__init__()
        # Validate the container type before calling len() on it, so that a
        # non-sequence input fails with a clear assertion instead of a TypeError.
        assert isinstance(parameterizations, (list, tuple))
        assert len(parameterizations) > 0
        assert all(
            [
                isinstance(param, (ImageParameterization, torch.Tensor))
                for param in parameterizations
            ]
        )
        # Wrap raw tensors so every entry is a module; nn.ModuleList (rather than
        # nn.ParameterList) keeps the module TorchScript / JIT compatible.
        parameterizations = [
            SimpleTensorParameterization(p) if isinstance(p, torch.Tensor) else p
            for p in parameterizations
        ]
        self.parameterizations = torch.nn.ModuleList(parameterizations)
        self.dim = dim
        self.output_device = output_device

    def forward(self) -> torch.Tensor:
        """
        Returns:
            image (torch.Tensor): A set of NCHW image parameterization outputs stacked
                along the batch dimension.
        """
        P = []
        for image_param in self.parameterizations:
            img = image_param()
            if self.output_device is not None:
                # Passing dtype=img.dtype means only the device changes here.
                # NOTE(review): the explicit dtype looks redundant — presumably
                # kept for TorchScript overload resolution; confirm before
                # simplifying to img.to(self.output_device).
                img = img.to(self.output_device, dtype=img.dtype)
            P.append(img)

        assert P[0].dim() == 4
        assert all([im.shape == P[0].shape for im in P])
        assert all([im.device == P[0].device for im in P])

        image = torch.cat(P, dim=self.dim)
        if torch.jit.is_scripting():
            # Named tensors are not supported by TorchScript, so skip the
            # refine_names call when scripted.
            return image
        return image.refine_names("B", "C", "H", "W")


class NaturalImage(ImageParameterization):
Expand Down Expand Up @@ -683,5 +846,6 @@ def forward(self) -> torch.Tensor:
"PixelImage",
"LaplacianImage",
"SharedImage",
"StackImage",
"NaturalImage",
]
Loading