diff --git a/captum/optim/_core/loss.py b/captum/optim/_core/loss.py
index 66bb4c40c2..0cc6033fd4 100644
--- a/captum/optim/_core/loss.py
+++ b/captum/optim/_core/loss.py
@@ -5,7 +5,11 @@
 import torch
 import torch.nn as nn
 
-from captum.optim._utils.image.common import _dot_cossim, get_neuron_pos
+from captum.optim._utils.image.common import (
+    _create_new_vector,
+    _dot_cossim,
+    get_neuron_pos,
+)
 from captum.optim._utils.typing import ModuleOutputMapping
 
@@ -837,6 +841,221 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         return activations
 
 
+@loss_wrapper
+class L2Mean(BaseLoss):
+    """
+    Simple L2Loss penalty where the mean is used instead of the square root of the
+    sum.
+
+    Used for CLIP models in https://distill.pub/2021/multimodal-neurons/ as per the
+    supplementary code:
+    https://github.com/openai/CLIP-featurevis/blob/master/example_facets.py
+    """
+
+    def __init__(
+        self,
+        target: torch.nn.Module,
+        channel_index: Optional[int] = None,
+        constant: float = 0.5,
+        batch_index: Optional[Union[int, List[int]]] = None,
+    ) -> None:
+        """
+        Args:
+
+            target (nn.Module): A target layer, transform, or image parameterization
+                instance.
+            channel_index (int, optional): Optionally only target a specific channel.
+                If set to ``None``, all channels will be used.
+                Default: ``None``
+            constant (float, optional): Constant value to deduct from the activations.
+                Default: ``0.5``
+            batch_index (int or list of int, optional): The index or index range of
+                activations to optimize if optimizing a batch of activations. If set
+                to ``None``, defaults to all activations in the batch. Index ranges
+                should be in the format of: [start, end].
+                Default: ``None``
+        """
+        BaseLoss.__init__(self, target, batch_index)
+        self.constant = constant
+        self.channel_index = channel_index
+
+    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
+        activations = targets_to_values[self.target][
+            self.batch_index[0] : self.batch_index[1]
+        ]
+        if self.channel_index is not None:
+            activations = activations[:, self.channel_index : self.channel_index + 1]
+        return ((activations - self.constant) ** 2).mean()
+
+
+@loss_wrapper
+class VectorLoss(BaseLoss):
+    """
+    This objective is useful for optimizing towards channel directions. This can be
+    helpful for visualizing models like OpenAI's CLIP.
+
+    This loss objective is similar to the Direction objective, except it computes the
+    matrix product of the activations and the vector, rather than the cosine
+    similarity. In addition to optimizing towards channel directions, this objective
+    can also perform a similar role to the ChannelActivation objective by using
+    one-hot 1D vectors.
+
+    See here for more details:
+    https://distill.pub/2021/multimodal-neurons/
+    https://github.com/openai/CLIP-featurevis/blob/master/example_facets.py
+    """
+
+    def __init__(
+        self,
+        target: torch.nn.Module,
+        vec: torch.Tensor,
+        activation_fn: Optional[Callable] = torch.nn.functional.relu,
+        move_channel_dim_to_final_dim: bool = True,
+        batch_index: Optional[Union[int, List[int]]] = None,
+    ) -> None:
+        """
+        Args:
+
+            target (nn.Module): A target layer instance.
+            vec (torch.Tensor): A 1D channel vector with the same size as the
+                channel / feature dimension of the target layer instance.
+            activation_fn (callable, optional): An optional activation function to
+                apply to the activations before computing the matrix product. If set
+                to ``None``, then no activation function will be used.
+                Default: ``torch.nn.functional.relu``
+            move_channel_dim_to_final_dim (bool, optional): Whether or not to move the
+                channel dimension to the last dimension before computing the matrix
+                product. Set to ``False`` if using the channels last format.
+                Default: ``True``
+            batch_index (int or list of int, optional): The index or index range of
+                activations to optimize if optimizing a batch of activations. If set
+                to ``None``, defaults to all activations in the batch. Index ranges
+                should be in the format of: [start, end].
+                Default: ``None``
+        """
+        BaseLoss.__init__(self, target, batch_index)
+        assert vec.dim() == 1
+        self.vec = vec
+        self.activation_fn = activation_fn
+        self.move_channel_dim_to_final_dim = move_channel_dim_to_final_dim
+
+    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
+        activations = targets_to_values[self.target]
+        activations = activations[self.batch_index[0] : self.batch_index[1]]
+        return _create_new_vector(
+            activations,
+            vec=self.vec,
+            activation_fn=self.activation_fn,
+            move_channel_dim_to_final_dim=self.move_channel_dim_to_final_dim,
+        ).mean()
+
+
+@loss_wrapper
+class FacetLoss(BaseLoss):
+    """
+    The Facet loss objective used for Faceted Feature Visualization as described in:
+    https://distill.pub/2021/multimodal-neurons/#faceted-feature-visualization
+    https://github.com/openai/CLIP-featurevis/blob/master/example_facets.py
+
+    The FacetLoss objective allows us to steer feature visualization towards a
+    particular theme / concept. This is done by using the weights from linear probes
+    trained on the lower layers of a model to discriminate between a certain theme or
+    concept and generic natural images.
+    """
+
+    def __init__(
+        self,
+        vec: torch.Tensor,
+        ultimate_target: torch.nn.Module,
+        layer_target: Union[torch.nn.Module, List[torch.nn.Module]],
+        facet_weights: torch.Tensor,
+        strength: Optional[Union[float, List[float]]] = None,
+        batch_index: Optional[Union[int, List[int]]] = None,
+    ) -> None:
+        """
+        Args:
+
+            vec (torch.Tensor): A 1D channel vector with the same size as the
+                channel / feature dimension of ``ultimate_target``.
+            ultimate_target (nn.Module): The main target layer that we are
+                visualizing from. This is normally the penultimate layer of the
+                model.
+            layer_target (nn.Module or list of nn.Module): A layer that we have
+                facet_weights for. This target layer should be below the
+                ``ultimate_target`` layer in the model.
+            facet_weights (torch.Tensor): Weighting that steers the objective
+                towards a particular theme or concept. These weight values should
+                come from linear probes trained on ``layer_target``.
+            strength (float, list of float, optional): A single float or list of floats
+                to use for batch dimension weighting. If using a single value, then it
+                will be applied to all batch dimensions equally. Otherwise, a list of
+                two floats in the format of: [start, end] should be given, and
+                :func:`torch.linspace` will calculate the step values in between. If
+                set to ``None``, no weighting is used.
+                Default: ``None``
+            batch_index (int or list of int, optional): The index or index range of
+                activations to optimize if optimizing a batch of activations. If set
+                to ``None``, defaults to all activations in the batch. Index ranges
+                should be in the format of: [start, end].
+                Default: ``None``
+        """
+        BaseLoss.__init__(self, [ultimate_target, layer_target], batch_index)
+        self.ultimate_target = ultimate_target
+        self.layer_target = layer_target
+        assert vec.dim() == 1
+        self.vec = vec
+        if isinstance(strength, (tuple, list)):
+            assert len(strength) == 2
+        self.strength = strength
+        assert facet_weights.dim() == 4 or facet_weights.dim() == 2
+        self.facet_weights = facet_weights
+
+    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
+        activations_ultimate = targets_to_values[self.ultimate_target]
+        activations_ultimate = activations_ultimate[
+            self.batch_index[0] : self.batch_index[1]
+        ]
+        new_vec = _create_new_vector(activations_ultimate, self.vec)
+        target_activations = targets_to_values[self.layer_target]
+
+        layer_grad = torch.autograd.grad(
+            outputs=new_vec,
+            inputs=target_activations,
+            grad_outputs=torch.ones_like(new_vec),
+            retain_graph=True,
+        )[0].detach()[self.batch_index[0] : self.batch_index[1]]
+        layer = target_activations[self.batch_index[0] : self.batch_index[1]]
+
+        flat_attr = layer * torch.nn.functional.relu(layer_grad)
+        if self.facet_weights.dim() == 2 and flat_attr.dim() == 4:
+            flat_attr = torch.sum(flat_attr, dim=(2, 3))
+
+        if self.strength:
+            if isinstance(self.strength, (tuple, list)):
+                strength_t = torch.linspace(
+                    self.strength[0],
+                    self.strength[1],
+                    steps=flat_attr.shape[0],
+                    device=flat_attr.device,
+                ).reshape(flat_attr.shape[0], *[1] * (flat_attr.dim() - 1))
+            else:
+                strength_t = self.strength
+            flat_attr = strength_t * flat_attr
+
+        if (
+            self.facet_weights.dim() == 4
+            and layer.dim() == 4
+            and self.facet_weights.shape[2:] != layer.shape[2:]
+        ):
+            facet_weights = torch.nn.functional.interpolate(
+                self.facet_weights, size=layer.shape[2:]
+            )
+        else:
+            facet_weights = self.facet_weights
+
+        return torch.sum(flat_attr * facet_weights)
+
+
 def sum_loss_list(
     loss_list: List,
     to_scalar_fn: Callable[[torch.Tensor], torch.Tensor] = torch.mean,
@@ -908,6 +1127,9 @@ def default_loss_summarize(loss_value: torch.Tensor) -> torch.Tensor:
     "AngledNeuronDirection",
     "TensorDirection",
     "ActivationWeights",
+    "L2Mean",
+    "VectorLoss",
+    "FacetLoss",
     "sum_loss_list",
     "default_loss_summarize",
 ]
diff --git a/captum/optim/_utils/image/common.py b/captum/optim/_utils/image/common.py
index f1cdc5f477..9e7553b251 100644
--- a/captum/optim/_utils/image/common.py
+++ b/captum/optim/_utils/image/common.py
@@ -1,5 +1,5 @@
 import math
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -363,3 +363,56 @@ def hex2base10(x: str) -> float:
         * ((1 - (-x - 0.5) * 2) * color_list[1] + (-x - 0.5) * 2 * color_list[0])
     ).permute(2, 0, 1)
     return color_tensor
+
+
+def _create_new_vector(
+    x: torch.Tensor,
+    vec: torch.Tensor,
+    activation_fn: Optional[
+        Callable[[torch.Tensor], torch.Tensor]
+    ] = torch.nn.functional.relu,
+    move_channel_dim_to_final_dim: bool = True,
+) -> torch.Tensor:
+    """
+    Create a vector using a given set of activations and another vector.
+    This function is intended for use in CLIP related loss objectives.
+
+    https://distill.pub/2021/multimodal-neurons/
+    https://github.com/openai/CLIP-featurevis/blob/master/example_facets.py
+
+    The einsum equation: "ijkl,j->ikl", used by the paper's associated code, is the
+    same thing as: "[..., C] @ vec", where vec has a shape of 'C'.
+
+    Args:
+
+        x (torch.Tensor): A set of 2d or 4d activations.
+        vec (torch.Tensor): A 1D direction vector to use, with a compatible shape for
+            computing the matrix product of the activations. See torch.matmul for
+            more details on compatible shapes:
+            https://pytorch.org/docs/stable/generated/torch.matmul.html
+            By default, ``vec`` is expected to share the same size as the channel or
+            feature dimension of the activations.
+        activation_fn (Callable, optional): An optional activation function to
+            apply to the activations before computing the matrix product. If set
+            to ``None``, then no activation function will be used.
+            Default: ``torch.nn.functional.relu``
+        move_channel_dim_to_final_dim (bool, optional): Whether or not to move the
+            channel dimension to the last dimension before computing the matrix
+            product.
+            Default: ``True``
+
+    Returns:
+        x (torch.Tensor): A vector created from the input activations and the
+            given vector.
+    """
+    assert x.device == vec.device
+    assert x.dim() > 1 and vec.dim() == 1
+    if activation_fn:
+        x = activation_fn(x)
+    if x.dim() > 2:
+        if move_channel_dim_to_final_dim:
+            permute_vals = [0] + list(range(x.dim()))[2:] + [1]
+            x = x.permute(*permute_vals)
+        mean_vals = list(range(1, x.dim() - 1))
+        return torch.mean(x @ vec, mean_vals)
+    else:
+        return (x @ vec)[:, None]
diff --git a/tests/optim/core/test_loss.py b/tests/optim/core/test_loss.py
index 49c35ed9d4..ee8e34a033 100644
--- a/tests/optim/core/test_loss.py
+++ b/tests/optim/core/test_loss.py
@@ -197,6 +197,225 @@ def test_activation_weights_1(self) -> None:
         )
 
+
+class TestL2Mean(BaseTest):
+    def test_l2mean_init(self) -> None:
+        model = torch.nn.Identity()
+        loss = opt_loss.L2Mean(model)
+        self.assertEqual(loss.constant, 0.5)
+        self.assertIsNone(loss.channel_index)
+
+    def test_l2mean_constant(self) -> None:
+        model = BasicModel_ConvNet_Optim()
+        constant = 0.5
+        loss = opt_loss.L2Mean(model.layer, constant=constant)
+        output = get_loss_value(model, loss)
+
+        expected = (CHANNEL_ACTIVATION_0_LOSS - constant) ** 2
+        self.assertAlmostEqual(output, expected, places=6)
+
+    def test_l2mean_channel_index(self) -> None:
+        model = BasicModel_ConvNet_Optim()
+        constant = 0.0
+        loss = opt_loss.L2Mean(model.layer, channel_index=0, constant=constant)
+        output = get_loss_value(model, loss)
+
+        expected = (CHANNEL_ACTIVATION_0_LOSS - constant) ** 2
+        self.assertAlmostEqual(output, expected, places=6)
+
+    def test_l2mean_batch_index(self) -> None:
+        raise unittest.SkipTest("Remove after PR merged")
+        model = torch.nn.Identity()
+        batch_index = 1
+        loss = opt_loss.L2Mean(model, batch_index=batch_index)
+
+        model_input = torch.arange(0, 5 * 4 * 5 * 5).view(5, 4, 5, 5).float()
+        output = get_loss_value(model, loss, model_input)
+        self.assertEqual(output.item(), 23034.25)
+
+
+class TestVectorLoss(BaseTest):
+    def test_vectorloss_init(self) -> None:
+        model = torch.nn.Identity()
+        vec = torch.tensor([0, 1]).float()
+        loss = opt_loss.VectorLoss(model, vec=vec)
+        assertTensorAlmostEqual(self, loss.vec, vec, delta=0.0)
+        self.assertTrue(loss.move_channel_dim_to_final_dim)
+        self.assertEqual(loss.activation_fn, torch.nn.functional.relu)
+
+    def test_vectorloss_single_channel(self) -> None:
+        model = BasicModel_ConvNet_Optim()
+        vec = torch.tensor([0, 1]).float()
+        loss = opt_loss.VectorLoss(model.layer, vec=vec)
+        output = get_loss_value(model, loss, input_shape=[1, 3, 6, 6])
+        self.assertAlmostEqual(output, CHANNEL_ACTIVATION_1_LOSS, places=6)
+
+    def test_vectorloss_multiple_channels(self) -> None:
+        model = BasicModel_ConvNet_Optim()
+        vec = torch.tensor([1, 1]).float()
+        loss = opt_loss.VectorLoss(model.layer, vec=vec)
+        output = get_loss_value(model, loss, input_shape=[1, 3, 6, 6])
+        self.assertAlmostEqual(output, CHANNEL_ACTIVATION_1_LOSS * 2, places=6)
+
+    def test_vectorloss_batch_index(self) -> None:
+        raise unittest.SkipTest("Remove after PR merged")
+        model = torch.nn.Identity()
+        batch_index = 1
+        vec = torch.tensor([0, 1, 0, 0]).float()
+        loss = opt_loss.VectorLoss(model, vec=vec, batch_index=batch_index)
+
+        model_input = torch.arange(0, 5 * 4 * 5 * 5).view(5, 4, 5, 5).float()
+        output = get_loss_value(model, loss, model_input)
+        self.assertEqual(output.item(), 137.0)
+
+
+class TestFacetLoss(BaseTest):
+    def test_facetloss_init(self) -> None:
+        model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
+        vec = torch.tensor([0, 1, 0]).float()
+        facet_weights = torch.ones([1, 2, 1, 1]) * 1.5
+        loss = opt_loss.FacetLoss(
+            ultimate_target=model[1],
+            layer_target=model[0],
+            vec=vec,
+            facet_weights=facet_weights,
+        )
+        assertTensorAlmostEqual(self, loss.vec, vec, delta=0.0)
+        assertTensorAlmostEqual(self, loss.facet_weights, facet_weights, delta=0.0)
+
+    def test_facetloss_single_channel(self) -> None:
+        layer = torch.nn.Conv2d(2, 3, 1, bias=True)
+        layer.weight.data.fill_(0.1)  # type: ignore
+        layer.bias.data.fill_(1)  # type: ignore
+        model = torch.nn.Sequential(BasicModel_ConvNet_Optim(), layer)
+
+        vec = torch.tensor([0, 1, 0]).float()
+        facet_weights = torch.ones([1, 2, 6, 6]) * 1.5
+        loss = opt_loss.FacetLoss(
+            ultimate_target=model[1],
+            layer_target=model[0].layer,
+            vec=vec,
+            facet_weights=facet_weights,
+        )
+        output = get_loss_value(model, loss, input_shape=[1, 3, 6, 6])
+        expected = (CHANNEL_ACTIVATION_0_LOSS * 2) * 1.5
+        self.assertAlmostEqual(output, expected / 10.0, places=6)
+
+    def test_facetloss_multi_channel(self) -> None:
+        layer = torch.nn.Conv2d(2, 3, 1, bias=True)
+        layer.weight.data.fill_(0.1)  # type: ignore
+        layer.bias.data.fill_(1)  # type: ignore
+
+        model = torch.nn.Sequential(BasicModel_ConvNet_Optim(), layer)
+
+        vec = torch.tensor([1, 1, 1]).float()
+        facet_weights = torch.ones([1, 2, 6, 6]) * 2.0
+        loss = opt_loss.FacetLoss(
+            ultimate_target=model[1],
+            layer_target=model[0].layer,
+            vec=vec,
+            facet_weights=facet_weights,
+        )
+        output = get_loss_value(model, loss, input_shape=[1, 3, 6, 6])
+        self.assertAlmostEqual(output, 1.560000, places=6)
+
+    def test_facetloss_strength(self) -> None:
+        layer = torch.nn.Conv2d(2, 3, 1, bias=True)
+        layer.weight.data.fill_(0.1)  # type: ignore
+        layer.bias.data.fill_(1)  # type: ignore
+        model = torch.nn.Sequential(BasicModel_ConvNet_Optim(), layer)
+
+        vec = torch.tensor([0, 1, 0]).float()
+        facet_weights = torch.ones([1, 2, 6, 6]) * 1.5
+        strength = 0.5
+        loss = opt_loss.FacetLoss(
+            ultimate_target=model[1],
+            layer_target=model[0].layer,
+            vec=vec,
+            facet_weights=facet_weights,
+            strength=strength,
+        )
+        self.assertEqual(loss.strength, strength)
+        output = get_loss_value(model, loss, input_shape=[1, 3, 6, 6])
+        self.assertAlmostEqual(output, 0.1950000, places=6)
+
+    def test_facetloss_strength_batch(self) -> None:
+        layer = torch.nn.Conv2d(2, 3, 1, bias=True)
+        layer.weight.data.fill_(0.1)  # type: ignore
+        layer.bias.data.fill_(1)  # type: ignore
+        model = torch.nn.Sequential(BasicModel_ConvNet_Optim(), layer)
+
+        vec = torch.tensor([0, 1, 0]).float()
+        facet_weights = torch.ones([1, 2, 6, 6]) * 1.5
+        strength = [0.1, 5.05]
+        loss = opt_loss.FacetLoss(
+            ultimate_target=model[1],
+            layer_target=model[0].layer,
+            vec=vec,
+            facet_weights=facet_weights,
+            strength=strength,
+        )
+        self.assertEqual(loss.strength, strength)
+        output = get_loss_value(model, loss, input_shape=[4, 3, 6, 6])
+        self.assertAlmostEqual(output, 4.017000198364258, places=6)
+
+    def test_facetloss_2d_weights(self) -> None:
+        layer = torch.nn.Conv2d(2, 3, 1, bias=True)
+        layer.weight.data.fill_(0.1)  # type: ignore
+        layer.bias.data.fill_(1)  # type: ignore
+        model = torch.nn.Sequential(BasicModel_ConvNet_Optim(), layer)
+
+        vec = torch.tensor([0, 1, 0]).float()
+        facet_weights = torch.ones([1, 2]) * 1.5
+        loss = opt_loss.FacetLoss(
+            ultimate_target=model[1],
+            layer_target=model[0].layer,
+            vec=vec,
+            facet_weights=facet_weights,
+        )
+        output = get_loss_value(model, loss, input_shape=[1, 3, 6, 6])
+        expected = (CHANNEL_ACTIVATION_0_LOSS * 2) * 1.5
+        self.assertAlmostEqual(output, expected / 10.0, places=6)
+
+    def test_facetloss_batch_index(self) -> None:
+        raise unittest.SkipTest("Remove after PR merged")
+        batch_index = 1
+        layer = torch.nn.Conv2d(2, 3, 1, bias=True)
+        layer.weight.data.fill_(0.1)  # type: ignore
+        layer.bias.data.fill_(1)  # type: ignore
+        model = torch.nn.Sequential(BasicModel_ConvNet_Optim(), layer)
+
+        vec = torch.tensor([0, 1, 0]).float()
+        facet_weights = torch.ones([1, 2, 5, 5]) * 1.5
+        loss = opt_loss.FacetLoss(
+            ultimate_target=model[1],
+            layer_target=model[0].layer,
+            vec=vec,
+            facet_weights=facet_weights,
+            batch_index=batch_index,
+        )
+        model_input = torch.arange(0, 5 * 3 * 5 * 5).view(5, 3, 5, 5).float()
+        output = get_loss_value(model, loss, model_input)
+        self.assertAlmostEqual(output.item(), 10.38000202178955, places=5)
+
+    def test_facetloss_resize_4d(self) -> None:
+        layer = torch.nn.Conv2d(2, 3, 1, bias=True)
+        layer.weight.data.fill_(0.1)  # type: ignore
+        layer.bias.data.fill_(1)  # type: ignore
+
+        model = torch.nn.Sequential(BasicModel_ConvNet_Optim(), layer)
+
+        vec = torch.tensor([1, 1, 1]).float()
+        facet_weights = torch.ones([1, 2, 12, 12]) * 2.0
+        loss = opt_loss.FacetLoss(
+            ultimate_target=model[1],
+            layer_target=model[0].layer,
+            vec=vec,
+            facet_weights=facet_weights,
+        )
+        output = get_loss_value(model, loss, input_shape=[1, 3, 6, 6])
+        self.assertAlmostEqual(output, 1.560000, places=6)
+
+
 class TestCompositeLoss(BaseTest):
     def test_negative(self) -> None:
         model = BasicModel_ConvNet_Optim()
diff --git a/tests/optim/utils/image/test_common.py b/tests/optim/utils/image/test_common.py
index ef484c7135..09e1a7355c 100644
--- a/tests/optim/utils/image/test_common.py
+++ b/tests/optim/utils/image/test_common.py
@@ -516,3 +516,50 @@ def test_make_grid_image_single_tensor_pad_value_jit_module(self) -> None:
         )
         self.assertEqual(list(expected_output.shape), [1, 1, 7, 7])
         assertTensorAlmostEqual(self, test_output, expected_output, 0)
+
+
+class TestCreateNewVector(BaseTest):
+    def test_create_new_vector_one_hot(self) -> None:
+        x = torch.arange(0, 1 * 3 * 5 * 5).view(1, 3, 5, 5).float()
+        vec = torch.tensor([0, 1, 0]).float()
+        out = common._create_new_vector(x, vec)
+        self.assertEqual(out.item(), 37.0)
+
+    def test_create_new_vector_one_hot_batch(self) -> None:
+        x = torch.arange(0, 4 * 3 * 5 * 5).view(4, 3, 5, 5).float()
+        vec = torch.tensor([0, 1, 0]).float()
+        out = common._create_new_vector(x, vec)
+        self.assertEqual(out.tolist(), [37.0, 112.0, 187.0, 262.0])
+
+    def test_create_new_vector(self) -> None:
+        x = torch.arange(0, 1 * 3 * 5 * 5).view(1, 3, 5, 5).float()
+        vec = torch.tensor([1, 1, 1]).float()
+        out = common._create_new_vector(x, vec)
+        self.assertEqual(out.item(), 111.0)
+
+    def test_create_new_vector_activation_fn(self) -> None:
+        x = torch.arange(0, 1 * 3 * 5 * 5).view(1, 3, 5, 5).float()
+        x = x - x.mean()
+        vec = torch.tensor([1, 0, 1]).float()
+        out = common._create_new_vector(x, vec, activation_fn=torch.nn.functional.relu)
+        self.assertEqual(out.item(), 25.0)
+
+    def test_create_new_vector_no_activation_fn(self) -> None:
+        x = torch.arange(0, 1 * 3 * 5 * 5).view(1, 3, 5, 5).float()
+        x = x - x.mean()
+        vec = torch.tensor([1, 1, 1]).float()
+        out = common._create_new_vector(x, vec, activation_fn=None)
+        self.assertEqual(out.item(), 0.0)
+
+    def test_create_new_vector_channels_last(self) -> None:
+        x = torch.arange(0, 4 * 5 * 5 * 3).view(4, 5, 5, 3).float()
+        vec = torch.tensor([0, 1, 0]).float()
+        out = common._create_new_vector(x, vec, move_channel_dim_to_final_dim=False)
+        self.assertEqual(out.tolist(), [37.0, 112.0, 187.0, 262.0])
+
+    def test_create_new_vector_dim_2(self) -> None:
+        x = torch.arange(0, 1 * 3).view(1, 3).float()
+        vec = torch.tensor([0, 1, 0]).float()
+        out = common._create_new_vector(x, vec)
+        self.assertEqual(list(out.shape), [1, 1])
+        self.assertEqual(out.item(), 1.0)
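
Note for reviewers: below is a minimal, self-contained sketch of the math these objectives share, written in plain PyTorch rather than against the captum API. The tensor shapes and values are illustrative assumptions, not taken from the tests above. It checks that the matrix-product form used by _create_new_vector matches the einsum form from the paper's supplementary code, and shows the reductions that L2Mean and VectorLoss apply on top of it.

import torch
import torch.nn.functional as F

# Illustrative NCHW activations: batch of 2, 3 channels, 4x4 spatial.
x = torch.arange(0, 2 * 3 * 4 * 4).view(2, 3, 4, 4).float()
vec = torch.tensor([0.0, 1.0, 0.0])  # one-hot channel direction

# _create_new_vector math for 4D inputs: optional ReLU, move the channel
# dim to the final dim, matrix product with vec, then mean over H and W.
y = F.relu(x).permute(0, 2, 3, 1)      # NCHW -> NHWC
new_vec = (y @ vec).mean(dim=(1, 2))   # shape: (N,)

# Equivalent einsum form from the CLIP-featurevis code ("ijkl,j->ikl").
ein = torch.einsum("ijkl,j->ikl", F.relu(x), vec).mean(dim=(1, 2))
assert torch.allclose(new_vec, ein)

# VectorLoss reduces this to a scalar with .mean(); L2Mean is simply
# ((activations - constant) ** 2).mean(). FacetLoss builds on new_vec by
# weighting layer activations with relu(d new_vec / d layer) before
# summing against the linear-probe facet_weights.
vector_loss = new_vec.mean()
l2_mean = ((x - 0.5) ** 2).mean()
print(vector_loss.item(), l2_mean.item())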