Skip to content

Commit ebd39a0

Browse files
miguelmartin75, vivekmig
authored and committed
add input wrapper for layer methods (#534)
Summary: Pull Request resolved: #534 Introduces a utility class called `ModelInputWrapper` that wraps a model so that its inputs can be treated as separate layers. It does so by passing each input fed to `forward` through an `Identity` operation. This way attribution works whether `attribute_to_layer_input` is True or False. Adds two tests: - Test whether `_forward_layer_eval` retrieves the appropriate input values - Compare regular IG with layer IG on the wrapped input layers. Updated tutorial and documentation. Reviewed By: NarineK Differential Revision: D25110896 fbshipit-source-id: bb8dd4947ae88e183af94c09cf906f9687fbe8ff
1 parent 3d9a649 commit ebd39a0

File tree

6 files changed

+386
-200
lines changed

6 files changed

+386
-200
lines changed

captum/attr/_core/layer/layer_integrated_gradients.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ def flatten_tuple(tup):
362362

363363
if self.device_ids is None:
364364
self.device_ids = getattr(self.forward_func, "device_ids", None)
365+
365366
inputs_layer = _forward_layer_eval(
366367
self.forward_func,
367368
inps,
@@ -398,7 +399,7 @@ def gradient_func(
398399
target_ind: TargetType = None,
399400
additional_forward_args: Any = None,
400401
) -> Tuple[Tensor, ...]:
401-
if self.device_ids is None:
402+
if self.device_ids is None or len(self.device_ids) == 0:
402403
scattered_inputs = (inputs,)
403404
else:
404405
# scatter method does not have a precise enough return type in its
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#!/usr/bin/env python3
2+
3+
import inspect
4+
from typing import Any
5+
6+
import torch.nn as nn
7+
8+
9+
class InputIdentity(nn.Module):
    def __init__(self, input_name: str) -> None:
        r"""
        The identity operation

        Args:
            input_name (str)
                The name of the input this layer is associated with. Stored
                on the module for debugging purposes only.
        """
        super().__init__()
        self.input_name = input_name

    def forward(self, x):
        # Hand back the very same object so that the layer's input and output
        # are interchangeable.
        return x


class ModelInputWrapper(nn.Module):
    def __init__(self, module_to_wrap: nn.Module) -> None:
        r"""
        This is a convenience class. This wraps a model via first feeding the
        model's inputs to separate layers (one for each input) and then feeding
        the (unmodified) inputs to the underlying model (`module_to_wrap`). Each
        input is fed through an `InputIdentity` layer/module. This class does
        not change how you feed inputs to your model, so feel free to use your
        model as you normally would.

        To access a wrapped input layer, simply access it via the `input_maps`
        ModuleDict, e.g. to get the corresponding module for input "x", simply
        provide/write `my_wrapped_module.input_maps["x"]`

        This is done such that one can use layer attribution methods on inputs,
        which allows you to mix layers with inputs when using these
        attribution methods. This is especially useful for multimodal models
        which take discrete features (mapped to embeddings, such as text) and
        regular continuous feature vectors as input.

        Notes:
            - Since inputs are mapped with the identity, attributing to the
              input/feature can be done with either the input or output of the
              layer, e.g. attributing to an input/feature doesn't depend on
              whether attribute_to_layer_input is True or False for
              LayerIntegratedGradients.
            - Identity layers are created for the named positional parameters
              and the keyword-only parameters of `module_to_wrap.forward`.
              Values passed through `*args` overflow or under a name that has
              no identity layer are forwarded to the model untouched.
            - Please refer to the multimodal tutorial or unit tests
              (test/attr/test_layer_wrapper.py) for an example.

        Args:
            module_to_wrap (nn.Module):
                The model/module you want to wrap
        """
        super().__init__()
        self.module = module_to_wrap

        forward_spec = inspect.getfullargspec(module_to_wrap.forward)
        # ignore self (getfullargspec reports it even for bound methods)
        self.arg_name_list = forward_spec.args[1:]
        # Keyword-only parameters can never be matched positionally, so they
        # are kept out of arg_name_list but still get an identity layer.
        wrapped_names = self.arg_name_list + forward_spec.kwonlyargs
        self.input_maps = nn.ModuleDict(
            {arg_name: InputIdentity(arg_name) for arg_name in wrapped_names}
        )

    def forward(self, *args, **kwargs) -> Any:
        r"""
        Route every recognised input through its identity layer and then call
        the wrapped module with the same positional/keyword layout.

        Returns:
            Whatever the wrapped module returns for these inputs.
        """
        mapped_args = list(args)
        # zip stops at the shorter sequence, so positional values beyond the
        # named parameters are left as they are.
        for idx, (arg_name, arg) in enumerate(zip(self.arg_name_list, args)):
            mapped_args[idx] = self.input_maps[arg_name](arg)

        # Keyword values without an identity layer are passed through
        # unchanged, mirroring the treatment of extra positional values,
        # instead of failing with a KeyError on the ModuleDict lookup.
        mapped_kwargs = {
            arg_name: (
                self.input_maps[arg_name](value)
                if arg_name in self.input_maps
                else value
            )
            for arg_name, value in kwargs.items()
        }

        return self.module(*mapped_args, **mapped_kwargs)

tests/attr/helpers/test_config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from captum.attr._core.occlusion import Occlusion
3737
from captum.attr._core.saliency import Saliency
3838
from captum.attr._core.shapley_value import ShapleyValueSampling
39+
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper
3940
from tests.helpers.basic import set_all_random_seeds
4041
from tests.helpers.basic_models import (
4142
BasicModel_ConvNet,
@@ -1160,4 +1161,19 @@
11601161
"target": 0,
11611162
},
11621163
},
1164+
{
1165+
"name": "basic_layer_ig_multi_layer_multi_output_with_input_wrapper",
1166+
"algorithms": [LayerIntegratedGradients],
1167+
"model": ModelInputWrapper(BasicModel_MultiLayer_TrueMultiInput()),
1168+
"layer": ["module.m1", "module.m234"],
1169+
"attribute_args": {
1170+
"inputs": (
1171+
torch.randn(5, 3),
1172+
torch.randn(5, 3),
1173+
torch.randn(5, 3),
1174+
torch.randn(5, 3),
1175+
),
1176+
"target": 0,
1177+
},
1178+
},
11631179
]
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
#!/usr/bin/env python3
2+
3+
import functools
4+
import inspect
5+
from typing import Callable, Dict, Tuple
6+
7+
import torch
8+
9+
from captum._utils.gradient import _forward_layer_eval
10+
from captum.attr import (
11+
DeepLift,
12+
DeepLiftShap,
13+
FeatureAblation,
14+
GradientShap,
15+
InputXGradient,
16+
IntegratedGradients,
17+
LayerDeepLift,
18+
LayerDeepLiftShap,
19+
LayerFeatureAblation,
20+
LayerGradientShap,
21+
LayerGradientXActivation,
22+
LayerIntegratedGradients,
23+
)
24+
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper
25+
from tests.helpers.basic import BaseTest, assertTensorTuplesAlmostEqual
26+
from tests.helpers.basic_models import (
27+
BasicModel,
28+
BasicModel_MultiLayer_TrueMultiInput,
29+
MixedKwargsAndArgsModule,
30+
)
31+
32+
# Table consumed by InputLayerMeta to generate test methods. Each entry pairs
# a layer attribution method with the input-level attribution method it is
# compared against; the third element lists the `multi_layer` values for
# which a test is generated.
layer_methods_to_test_with_equiv = [
    # layer_method, equiv_method, whether or not to use multiple layers
    (LayerIntegratedGradients, IntegratedGradients, [True, False]),
    (LayerGradientXActivation, InputXGradient, [True, False]),
    (LayerFeatureAblation, FeatureAblation, [False]),
    (LayerDeepLift, DeepLift, [False]),
    (LayerDeepLiftShap, DeepLiftShap, [False]),
    (LayerGradientShap, GradientShap, [False]),
    # TODO: add other algorithms here
]
42+
43+
44+
class InputLayerMeta(type):
    """Metaclass that adds one generated test method per table entry.

    For every ``(layer_method, equiv_method, multi_layers)`` row of
    ``layer_methods_to_test_with_equiv`` and every value in ``multi_layers``,
    a method named ``test_<layer>_<equiv>_<multi_layer>`` is injected into the
    class being created. Each generated method calls
    ``self.layer_method_with_input_layer_patches`` with that row's values.
    """

    def __new__(cls, name: str, bases: Tuple, attrs: Dict):
        def make_test(layer_method, equiv_method, multi_layer):
            # Calling a factory freezes the current loop values. A lambda
            # written directly in the loop body would look the variables up
            # when the test runs, so every generated test would silently use
            # the last row of the table.
            def generated_test(self):
                return self.layer_method_with_input_layer_patches(
                    layer_method, equiv_method, multi_layer
                )

            return generated_test

        for (
            layer_method,
            equiv_method,
            multi_layers,
        ) in layer_methods_to_test_with_equiv:
            for multi_layer in multi_layers:
                test_name = (
                    f"test_{layer_method.__name__}"
                    + f"_{equiv_method.__name__}_{multi_layer}"
                )
                attrs[test_name] = make_test(layer_method, equiv_method, multi_layer)

        return super().__new__(cls, name, bases, attrs)
63+
64+
65+
class TestInputLayerWrapper(BaseTest, metaclass=InputLayerMeta):
    # Test case for ModelInputWrapper. Besides the test written out below,
    # InputLayerMeta injects further test_* methods that delegate to
    # layer_method_with_input_layer_patches.

    def test_forward_layer_eval_on_mixed_args_kwargs_module(self) -> None:
        # Check the wrapper first with only the required input and then with
        # the optional second input as well.
        first = torch.randn(10, 5)
        second = torch.randn(10, 5)

        module = MixedKwargsAndArgsModule()

        for named_inputs in ({"x": first}, {"x": first, "y": second}):
            self.forward_eval_layer_with_inputs_helper(module, named_inputs)

    def layer_method_with_input_layer_patches(
        self,
        layer_method_class: Callable,
        equiv_method_class: Callable,
        multi_layer: bool,
    ) -> None:
        """Compare a layer attribution method on the wrapper's input layers
        against the equivalent input-level attribution method.

        Args:
            layer_method_class: layer attribution class, constructed with the
                wrapped model and the identity layer(s) of its inputs.
            equiv_method_class: input attribution class, constructed with the
                wrapped model only.
            multi_layer: when True a four-input model is used and all four
                identity layers are passed as a list; otherwise a
                single-input model and a single layer are used.
        """
        if multi_layer:
            base_model = BasicModel_MultiLayer_TrueMultiInput()
            input_names = ["x1", "x2", "x3", "x4"]
        else:
            base_model = BasicModel()
            input_names = ["input"]
        wrapped = ModelInputWrapper(base_model)

        input_layers = [wrapped.input_maps[name] for name in input_names]
        layer_attr = layer_method_class(
            wrapped, layer=input_layers if multi_layer else input_layers[0]
        )
        input_attr = equiv_method_class(wrapped)

        inputs = tuple(torch.rand(5, 3) for _ in input_names)
        baseline = tuple(torch.zeros(5, 3) for _ in input_names)

        # Baselines are supplied only to methods whose (undecorated)
        # `attribute` signature declares a `baselines` parameter.
        accepted = inspect.getfullargspec(input_attr.attribute.__wrapped__).args
        call_args = [inputs, baseline] if "baselines" in accepted else [inputs]

        from_layer_output = layer_attr.attribute(*call_args, target=0)
        from_layer_input = layer_attr.attribute(
            *call_args, target=0, attribute_to_layer_input=True
        )

        expected = input_attr.attribute(*call_args, target=0)

        # Normalise single results to 1-tuples for the tuple comparison helper.
        if not isinstance(from_layer_output, tuple):
            from_layer_output = (from_layer_output,)
            from_layer_input = (from_layer_input,)

        if not isinstance(expected, tuple):
            expected = (expected,)

        assertTensorTuplesAlmostEqual(self, from_layer_output, from_layer_input)
        assertTensorTuplesAlmostEqual(self, from_layer_output, expected)

    def forward_eval_layer_with_inputs_helper(self, model, inputs_to_test):
        """Check that `_forward_layer_eval` on the wrapper's input layers
        returns the tensors that were fed in.

        Args:
            model: module to wrap with `ModelInputWrapper`.
            inputs_to_test: dict mapping input name to input tensor, in the
                order of the wrapped model's forward parameters.
        """
        # Hard-coded for simplicity, indexed by (number of inputs - 1).
        # Per input: 0 = pass positionally, 1 = pass as keyword.
        # => no 0s after the first 1 (left to right).
        #
        # Used to exercise both the *args and **kwargs paths of the wrapper.
        passing_patterns = [
            [[0], [1]],
            [
                [0, 0],
                [0, 1],
                [1, 1],
            ],
        ]

        wrapped = ModelInputWrapper(model)

        def forward_func(*args, args_or_kwargs=None):
            # Split the positional inputs into args/kwargs following the
            # pattern to test the wrapper's *args and **kwargs handling.
            positional = []
            keyword = {}
            for as_keyword, name, value in zip(
                args_or_kwargs, inputs_to_test.keys(), args
            ):
                if as_keyword:
                    keyword[name] = value
                else:
                    positional.append(value)
            return wrapped(*positional, **keyword)

        for pattern in passing_patterns[len(inputs_to_test) - 1]:
            with self.subTest(args_or_kwargs=pattern):

                def evaluate(**extra):
                    # Rebuild the partial, inputs and layer list on every call.
                    return _forward_layer_eval(
                        functools.partial(forward_func, args_or_kwargs=pattern),
                        inputs=tuple(inputs_to_test.values()),
                        layer=[
                            wrapped.input_maps[name]
                            for name in inputs_to_test.keys()
                        ],
                        **extra,
                    )

                at_layer_output = evaluate()
                at_layer_input = evaluate(attribute_to_layer_input=True)

                for out_eval, in_eval, original in zip(
                    at_layer_output, at_layer_input, inputs_to_test.values()
                ):
                    self.assertTrue((out_eval[0] == in_eval[0]).all())
                    self.assertTrue((out_eval[0] == original).all())

tests/helpers/basic_models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,16 @@
1515
"""
1616

1717

18+
class MixedKwargsAndArgsModule(nn.Module):
    """Parameter-free module with one required and one optional input."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y=None):
        # Without `y` the input is returned as is; otherwise the sum is returned.
        return x if y is None else x + y
26+
27+
1828
class BasicModel(nn.Module):
1929
def __init__(self):
2030
super().__init__()

0 commit comments

Comments
 (0)