
Commit 05a6dd4

jjuncho authored and facebook-github-bot committed
Add capability to pass grad_kwargs for grad_cam, internal_influence, layer_conductance, layer_deep_lift, layer_gradient_shap, and neuron_conductance (#1294)
Summary:
Pull Request resolved: #1294

Extension of D57756842, where the torch.autograd.grad arguments can be passed into the following classes:
- LayerGradCam
- InternalInfluence
- LayerConductance
- LayerDeepLift
- LayerGradientShap
- NeuronConductance

Reviewed By: cyrjano

Differential Revision: D58208128

fbshipit-source-id: ec8449e5c51d000c70c4b691858b96a0963728ff
1 parent 03340ec commit 05a6dd4

13 files changed (+195, -12 lines)
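Every entry of the new `grad_kwargs` dict is forwarded verbatim to `torch.autograd.grad`. As a standalone illustration of the kind of argument this unlocks, here is a minimal sketch of `materialize_grads` (available in torch >= 2.1, the same flag the new test at the bottom of this commit exercises); the tensors are invented for the example:

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
unused = torch.tensor([3.0], requires_grad=True)  # plays no role in y
y = (x * x).sum()

# With materialize_grads=True (torch >= 2.1), torch.autograd.grad returns
# zero tensors instead of None for inputs that do not affect the output.
gx, gu = torch.autograd.grad(y, (x, unused), materialize_grads=True)
print(gx)  # tensor([2., 4.])
print(gu)  # tensor([0.])
```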

captum/attr/_core/layer/grad_cam.py

Lines changed: 6 additions & 1 deletion
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -83,6 +83,7 @@ def attribute(
         attribute_to_layer_input: bool = False,
         relu_attributions: bool = False,
         attr_dim_summation: bool = True,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[Tensor, Tuple[Tensor, ...]]:
         r"""
         Args:
@@ -154,6 +155,9 @@ def attribute(
                         sum attributions along dimension 1 (usually channel).
                         The default (True) means to sum along dimension 1.
                         Default: True
+            grad_kwargs (Dict[str, Any], optional): Additional keyword
+                        arguments for torch.autograd.grad.
+                        Default: None

         Returns:
             *Tensor* or *tuple[Tensor, ...]* of **attributions**:
@@ -200,6 +204,7 @@ def attribute(
             additional_forward_args,
             device_ids=self.device_ids,
             attribute_to_layer_input=attribute_to_layer_input,
+            grad_kwargs=grad_kwargs,
         )

         summed_grads = tuple(
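A usage sketch for the new parameter on `LayerGradCam`; the toy `nn.Sequential` model and input shapes here are invented for illustration, and `materialize_grads` requires torch >= 2.1 (see the version gate in the new test at the bottom of this commit):

```python
import torch
import torch.nn as nn
from captum.attr import LayerGradCam

# Toy convolutional model; any module exposing the target layer works.
net = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 8 * 8, 2),
)
inp = torch.randn(1, 3, 8, 8)

lgc = LayerGradCam(net, net[1])
# Every entry of grad_kwargs is passed through to torch.autograd.grad.
attr = lgc.attribute(inp, target=0, grad_kwargs={"materialize_grads": True})
```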

captum/attr/_core/layer/internal_influence.py

Lines changed: 8 additions & 1 deletion
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from captum._utils.common import (
@@ -74,6 +74,7 @@ def attribute(
         method: str = "gausslegendre",
         internal_batch_size: Union[None, int] = None,
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[Tensor, Tuple[Tensor, ...]]:
         r"""
         Args:
@@ -185,6 +186,9 @@ def attribute(
                         attribute to the input or output, is a single tensor.
                         Support for multiple tensors will be added later.
                         Default: False
+            grad_kwargs (Dict[str, Any], optional): Additional keyword
+                        arguments for torch.autograd.grad.
+                        Default: None

         Returns:
             *Tensor* or *tuple[Tensor, ...]* of **attributions**:
@@ -236,6 +240,7 @@ def attribute(
                 n_steps=n_steps,
                 method=method,
                 attribute_to_layer_input=attribute_to_layer_input,
+                grad_kwargs=grad_kwargs,
             )

         return attrs
@@ -250,6 +255,7 @@ def _attribute(
         method: str = "gausslegendre",
         attribute_to_layer_input: bool = False,
         step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[Tensor, Tuple[Tensor, ...]]:
         if step_sizes_and_alphas is None:
             # retrieve step size and scaling factor for specified approximation method
@@ -290,6 +296,7 @@ def _attribute(
             additional_forward_args=input_additional_args,
             device_ids=self.device_ids,
             attribute_to_layer_input=attribute_to_layer_input,
+            grad_kwargs=grad_kwargs,
         )
         # flattening grads so that we can multiply it with step-size
         # calling contiguous to avoid `memory whole` problems
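The same pass-through now works for `InternalInfluence`; a minimal sketch with an invented toy model (`retain_graph` is just one ordinary `torch.autograd.grad` kwarg):

```python
import torch
import torch.nn as nn
from captum.attr import InternalInfluence

net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))  # toy model
inp = torch.randn(2, 3)

ii = InternalInfluence(net, net[1])
# grad_kwargs reaches torch.autograd.grad at every integration step.
attr = ii.attribute(inp, target=0, n_steps=10, grad_kwargs={"retain_graph": True})
```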

captum/attr/_core/layer/layer_conductance.py

Lines changed: 9 additions & 1 deletion
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 import typing
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from captum._utils.common import (
@@ -82,6 +82,7 @@ def attribute(
         *,
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...

     @typing.overload
@@ -96,6 +97,7 @@ def attribute(
         internal_batch_size: Union[None, int] = None,
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[Tensor, Tuple[Tensor, ...]]: ...

     @log_usage()
@@ -112,6 +114,7 @@ def attribute(
         internal_batch_size: Union[None, int] = None,
         return_convergence_delta: bool = False,
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[
         Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
     ]:
@@ -230,6 +233,9 @@ def attribute(
                         attribute to the input or output, is a single tensor.
                         Support for multiple tensors will be added later.
                         Default: False
+            grad_kwargs (Dict[str, Any], optional): Additional keyword
+                        arguments for torch.autograd.grad.
+                        Default: None

         Returns:
             **attributions** or 2-element tuple of **attributions**, **delta**:
@@ -322,6 +328,7 @@ def _attribute(
         method: str = "gausslegendre",
         attribute_to_layer_input: bool = False,
         step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[Tensor, Tuple[Tensor, ...]]:
         num_examples = inputs[0].shape[0]
         if step_sizes_and_alphas is None:
@@ -366,6 +373,7 @@ def _attribute(
             target_ind=expanded_target,
             device_ids=self.device_ids,
             attribute_to_layer_input=attribute_to_layer_input,
+            grad_kwargs=grad_kwargs,
         )

         # Compute differences between consecutive evaluations of layer_eval.
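Correspondingly for `LayerConductance`, including the overload that also returns the convergence delta; model and shapes are again invented for the sketch:

```python
import torch
import torch.nn as nn
from captum.attr import LayerConductance

net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))  # toy model
inp = torch.randn(2, 3)

lc = LayerConductance(net, net[1])
attr, delta = lc.attribute(
    inp,
    target=0,
    return_convergence_delta=True,
    grad_kwargs={"retain_graph": True},  # forwarded to torch.autograd.grad
)
```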

captum/attr/_core/layer/layer_deep_lift.py

Lines changed: 9 additions & 1 deletion
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 import typing
-from typing import Any, Callable, cast, Sequence, Tuple, Union
+from typing import Any, Callable, cast, Dict, Optional, Sequence, Tuple, Union

 import torch
 from captum._utils.common import (
@@ -108,6 +108,7 @@ def attribute(
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
         custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[Tensor, Tuple[Tensor, ...]]: ...

     @typing.overload
@@ -121,6 +122,7 @@ def attribute(
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
         custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...

     @log_usage()
@@ -133,6 +135,7 @@ def attribute(
         return_convergence_delta: bool = False,
         attribute_to_layer_input: bool = False,
         custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[
         Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
     ]:
@@ -248,6 +251,9 @@ def attribute(
                         `custom_attribution_func` returns a tuple of attribution
                         tensors that have the same length as the `inputs`.
                         Default: None
+            grad_kwargs (Dict[str, Any], optional): Additional keyword
+                        arguments for torch.autograd.grad.
+                        Default: None

         Returns:
             **attributions** or 2-element tuple of **attributions**, **delta**:
@@ -274,6 +280,7 @@ def attribute(
                 it is not guaranteed and depends on the specifics of the
                 `custom_attribution_func`.

+
         Examples::

         >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
@@ -326,6 +333,7 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
             inputs,
             attribute_to_layer_input=attribute_to_layer_input,
             output_fn=lambda out: chunk_output_fn(out),
+            grad_kwargs=grad_kwargs,
         )

         attr_inputs = tuple(map(lambda attr: attr[0], attrs))
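A sketch of the same kwarg on `LayerDeepLift`, which additionally takes baselines; toy model as in the earlier sketches, redefined so the block stands alone:

```python
import torch
import torch.nn as nn
from captum.attr import LayerDeepLift

net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))  # toy model
inp = torch.randn(2, 3)
baselines = torch.zeros(2, 3)

ldl = LayerDeepLift(net, net[1])
attr = ldl.attribute(
    inp, baselines=baselines, target=0, grad_kwargs={"retain_graph": True}
)
```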

captum/attr/_core/layer/layer_gradient_shap.py

Lines changed: 6 additions & 1 deletion
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

 import typing
-from typing import Any, Callable, cast, List, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -242,6 +242,7 @@ def attribute(
                         attribute to the input or output, is a single tensor.
                         Support for multiple tensors will be added later.
                         Default: False
+
         Returns:
             **attributions** or 2-element tuple of **attributions**, **delta**:
             - **attributions** (*Tensor* or *tuple[Tensor, ...]*):
@@ -375,6 +376,7 @@ def attribute(
         additional_forward_args: Any = None,
         return_convergence_delta: Literal[False] = False,
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[Tensor, Tuple[Tensor, ...]]: ...

     @typing.overload
@@ -387,6 +389,7 @@ def attribute(
         *,
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...

     @log_usage()
@@ -398,6 +401,7 @@ def attribute(  # type: ignore
         additional_forward_args: Any = None,
         return_convergence_delta: bool = False,
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[
         Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
     ]:
@@ -420,6 +424,7 @@ def attribute(  # type: ignore
             additional_forward_args,
             device_ids=self.device_ids,
             attribute_to_layer_input=attribute_to_layer_input,
+            grad_kwargs=grad_kwargs,
         )

         attr_baselines = _forward_layer_eval(
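For `LayerGradientShap`, the baselines argument is a distribution of reference inputs sampled during attribution; again a toy sketch with invented shapes:

```python
import torch
import torch.nn as nn
from captum.attr import LayerGradientShap

net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))  # toy model
inp = torch.randn(2, 3)
baseline_dist = torch.randn(6, 3)  # baselines are sampled from this set

lgs = LayerGradientShap(net, net[1])
attr = lgs.attribute(
    inp, baseline_dist, target=0, n_samples=4, grad_kwargs={"retain_graph": True}
)
```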

captum/attr/_core/layer/layer_gradient_x_activation.py

Lines changed: 3 additions & 2 deletions
@@ -133,8 +133,9 @@ def attribute(
                         layer input, otherwise it will be computed with respect
                         to layer output.
                         Default: False
-            grad_kwargs: Additional keyword arguments for torch.autograd.grad
-
+            grad_kwargs (Dict[str, Any], optional): Additional keyword
+                        arguments for torch.autograd.grad.
+                        Default: None
         Returns:
             *Tensor* or *tuple[Tensor, ...]* or list of **attributions**:
             - **attributions** (*Tensor*, *tuple[Tensor, ...]*, or *list*):
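This hunk only normalizes the docstring; `grad_kwargs` itself landed on `LayerGradientXActivation` in D57756842. For completeness, a toy usage sketch in the same style as the others:

```python
import torch
import torch.nn as nn
from captum.attr import LayerGradientXActivation

net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))  # toy model
inp = torch.randn(2, 3)

lga = LayerGradientXActivation(net, net[1])
attr = lga.attribute(inp, target=0, grad_kwargs={"retain_graph": True})
```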

captum/attr/_core/neuron/neuron_conductance.py

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 import warnings
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from captum._utils.common import (
@@ -99,6 +99,7 @@ def attribute(
         method: str = "riemann_trapezoid",
         internal_batch_size: Union[None, int] = None,
         attribute_to_neuron_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> TensorOrTupleOfTensorsGeneric:
         r"""
         Args:
@@ -311,6 +312,7 @@ def attribute(
             n_steps=n_steps,
             method=method,
             attribute_to_neuron_input=attribute_to_neuron_input,
+            grad_kwargs=grad_kwargs,
         )
         return _format_output(is_inputs_tuple, attrs)

@@ -325,6 +327,7 @@ def _attribute(
         method: str = "riemann_trapezoid",
         attribute_to_neuron_input: bool = False,
         step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[Tensor, ...]:

         num_examples = inputs[0].shape[0]
@@ -371,6 +374,7 @@ def _attribute(
             gradient_neuron_selector=neuron_selector,
             device_ids=self.device_ids,
             attribute_to_layer_input=attribute_to_neuron_input,
+            grad_kwargs=grad_kwargs,
         )

         mid_grads = _verify_select_neuron(layer_gradients, neuron_selector)
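And for `NeuronConductance`, which attributes input features to a single neuron selected within the layer; toy model and neuron index invented for the sketch:

```python
import torch
import torch.nn as nn
from captum.attr import NeuronConductance

net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))  # toy model
inp = torch.randn(2, 3)

nc = NeuronConductance(net, net[1])
# Attribute input features to neuron 1 of the ReLU layer; grad_kwargs is
# forwarded to torch.autograd.grad inside the integration loop.
attr = nc.attribute(
    inp, neuron_selector=1, target=0, grad_kwargs={"retain_graph": True}
)
```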

tests/attr/layer/test_grad_cam.py

Lines changed: 22 additions & 1 deletion
@@ -1,11 +1,12 @@
 #!/usr/bin/env python3

 import unittest
-from typing import Any, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 from captum._utils.typing import TensorLikeList
 from captum.attr._core.layer.grad_cam import LayerGradCam
+from packaging import version
 from tests.helpers import BaseTest
 from tests.helpers.basic import assertTensorTuplesAlmostEqual
 from tests.helpers.basic_models import (
@@ -119,6 +120,7 @@ def _grad_cam_test_assert(
         attribute_to_layer_input: bool = False,
         relu_attributions: bool = False,
         attr_dim_summation: bool = True,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         layer_gc = LayerGradCam(model, target_layer)
         self.assertFalse(layer_gc.multiplies_by_inputs)
@@ -129,11 +131,30 @@ def _grad_cam_test_assert(
             attribute_to_layer_input=attribute_to_layer_input,
             relu_attributions=relu_attributions,
             attr_dim_summation=attr_dim_summation,
+            grad_kwargs=grad_kwargs,
         )
         assertTensorTuplesAlmostEqual(
             self, attributions, expected_activation, delta=0.01
         )

+    def test_relu_gradcam_with_unused_layer(self) -> None:
+        if version.parse(torch.__version__) < version.parse("2.1.0"):
+            raise unittest.SkipTest(
+                "Skipping unused layed gradient test since it is not supported "
+                "by torch version < 2.1"
+            )
+        net = BasicModel_MultiLayer(multi_input_module=True)
+        inp = torch.tensor([[0.0, 6.0, 0.0]], requires_grad=True)
+        gradcam = LayerGradCam(net, net.relu)
+        attributions = gradcam.attribute(
+            inputs=inp,
+            target=0,
+            grad_kwargs={"materialize_grads": True},
+        )
+        self.assertEqual(len(attributions), 1)
+        self.assertEqual(list(attributions[0].shape), [1])
+        self.assertAlmostEqual(attributions[0].sum(), 0)
+

 if __name__ == "__main__":
     unittest.main()
