Skip to content

Commit e733e7e

Browse files
committed
Add Occlusion to Insights
1 parent 1b5a6dc commit e733e7e

File tree

3 files changed

+124
-62
lines changed

3 files changed

+124
-62
lines changed

captum/insights/api.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class FilterConfig(NamedTuple):
8989
arg: config.value # type: ignore
9090
for arg, config in ATTRIBUTION_METHOD_CONFIG[
9191
IntegratedGradients.get_name()
92-
].items()
92+
].params.items()
9393
}
9494
prediction: str = "all"
9595
classes: List[str] = []
@@ -221,6 +221,12 @@ def _calculate_attribution(
221221
attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[self._config.attribution_method]
222222
attribution_method = attribution_cls(net)
223223
args = self._config.attribution_arguments
224+
param_config = ATTRIBUTION_METHOD_CONFIG[self._config.attribution_method]
225+
if param_config.post_process:
226+
for k, v in args.items():
227+
if k in param_config.post_process:
228+
args[k] = param_config.post_process[k](v)
229+
224230
# TODO support multiple baselines
225231
baseline = baselines[0] if baselines and len(baselines) > 0 else None
226232
label = (
@@ -542,6 +548,8 @@ def get_insights_config(self):
542548
return {
543549
"classes": self.classes,
544550
"methods": list(ATTRIBUTION_NAMES_TO_METHODS.keys()),
545-
"method_arguments": namedtuple_to_dict(ATTRIBUTION_METHOD_CONFIG),
551+
"method_arguments": namedtuple_to_dict(
552+
{k: v.params for (k, v) in ATTRIBUTION_METHOD_CONFIG.items()}
553+
),
546554
"selected_method": self._config.attribution_method,
547555
}

captum/insights/config.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python3
2-
from typing import Dict, List, NamedTuple, Optional, Tuple
2+
from typing import Dict, List, NamedTuple, Optional, Tuple, Callable, Any, Union
33

44
from captum.attr import (
55
Deconvolution,
@@ -9,10 +9,14 @@
99
InputXGradient,
1010
IntegratedGradients,
1111
Saliency,
12+
Occlusion,
1213
)
1314
from captum.attr._utils.approximation_methods import SUPPORTED_METHODS
1415

1516

17+
Config = Union[NumberConfig, StrEnumConfig, StrConfig]
18+
19+
1620
class NumberConfig(NamedTuple):
1721
value: int = 1
1822
limit: Tuple[Optional[int], Optional[int]] = (None, None)
@@ -25,6 +29,11 @@ class StrEnumConfig(NamedTuple):
2529
type: str = "enum"
2630

2731

32+
class StrConfig(NamedTuple):
33+
value: str
34+
type: str = "string"
35+
36+
2837
SUPPORTED_ATTRIBUTION_METHODS = [
2938
Deconvolution,
3039
DeepLift,
@@ -33,20 +42,47 @@ class StrEnumConfig(NamedTuple):
3342
IntegratedGradients,
3443
Saliency,
3544
FeatureAblation,
45+
Occlusion,
3646
]
3747

48+
49+
class ConfigParameters(NamedTuple):
50+
params: Dict[str, Config]
51+
help_info: Optional[str] = None # TODO fill out help for each method
52+
post_process: Optional[Dict[str, Callable[[Any], Any]]] = None
53+
54+
3855
ATTRIBUTION_NAMES_TO_METHODS = {
3956
# mypy bug - treating it as a type instead of a class
4057
cls.get_name(): cls # type: ignore
4158
for cls in SUPPORTED_ATTRIBUTION_METHODS
4259
}
4360

44-
ATTRIBUTION_METHOD_CONFIG: Dict[str, Dict[str, tuple]] = {
45-
IntegratedGradients.get_name(): {
46-
"n_steps": NumberConfig(value=25, limit=(2, None)),
47-
"method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"),
48-
},
49-
FeatureAblation.get_name(): {
50-
"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
51-
},
61+
62+
def _str_to_tuple(s):
63+
if isinstance(s, tuple):
64+
return s
65+
return tuple([int(i) for i in s.split()])
66+
67+
68+
ATTRIBUTION_METHOD_CONFIG: Dict[str, ConfigParameters] = {
69+
IntegratedGradients.get_name(): ConfigParameters(
70+
params={
71+
"n_steps": NumberConfig(value=25, limit=(2, None)),
72+
"method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"),
73+
}
74+
),
75+
FeatureAblation.get_name(): ConfigParameters(
76+
params={"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100))}
77+
),
78+
Occlusion.get_name(): ConfigParameters(
79+
params={
80+
"sliding_window_shapes": StrConfig(value=""),
81+
"strides": StrConfig(value=""),
82+
},
83+
post_process={
84+
"sliding_window_shapes": _str_to_tuple,
85+
"strides": _str_to_tuple,
86+
},
87+
),
5288
}

captum/insights/frontend/src/App.js

Lines changed: 69 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import "./App.css";
88
const ConfigType = Object.freeze({
99
Number: "number",
1010
Enum: "enum",
11+
String: "string",
1112
});
1213

1314
const Plot = createPlotlyComponent(Plotly);
@@ -153,62 +154,71 @@ class FilterContainer extends React.Component {
153154
}
154155
}
155156

156-
class ClassFilter extends React.Component {
157-
render() {
158-
return (
159-
<ReactTags
160-
tags={this.props.classes}
161-
autofocus={false}
162-
suggestions={this.props.suggestedClasses}
163-
handleDelete={this.props.handleClassDelete}
164-
handleAddition={this.props.handleClassAdd}
165-
minQueryLength={0}
166-
placeholder="add new class..."
157+
function ClassFilter(props) {
158+
return (
159+
<ReactTags
160+
tags={props.classes}
161+
autofocus={false}
162+
suggestions={props.suggestedClasses}
163+
handleDelete={props.handleClassDelete}
164+
handleAddition={props.handleClassAdd}
165+
minQueryLength={0}
166+
placeholder="add new class..."
167+
/>
168+
);
169+
}
170+
171+
function NumberArgument(props) {
172+
var min = props.limit[0];
173+
var max = props.limit[1];
174+
return (
175+
<div>
176+
{props.name}:
177+
<input
178+
className={cx([styles.input, styles["input--narrow"]])}
179+
name={props.name}
180+
type="number"
181+
value={props.value}
182+
min={min}
183+
max={max}
184+
onChange={props.handleInputChange}
167185
/>
168-
);
169-
}
186+
</div>
187+
);
170188
}
171189

172-
class NumberArgument extends React.Component {
173-
render() {
174-
var min = this.props.limit[0];
175-
var max = this.props.limit[1];
176-
return (
177-
<div>
178-
{this.props.name + ": "}
179-
<input
180-
className={cx([styles.input, styles["input--narrow"]])}
181-
name={this.props.name}
182-
type="number"
183-
value={this.props.value}
184-
min={min}
185-
max={max}
186-
onChange={this.props.handleInputChange}
187-
/>
188-
</div>
189-
);
190-
}
190+
function EnumArgument(props) {
191+
const options = props.limit.map((item, key) => (
192+
<option value={item}>{item}</option>
193+
));
194+
return (
195+
<div>
196+
{props.name}:
197+
<select
198+
className={styles.select}
199+
name={props.name}
200+
value={props.value}
201+
onChange={props.handleInputChange}
202+
>
203+
{options}
204+
</select>
205+
</div>
206+
);
191207
}
192208

193-
class EnumArgument extends React.Component {
194-
render() {
195-
const options = this.props.limit.map((item, key) => (
196-
<option value={item}>{item}</option>
197-
));
198-
return (
199-
<div>
200-
{this.props.name + ": "}
201-
<select
202-
className={styles.select}
203-
name={this.props.name}
204-
value={this.props.value}
205-
onChange={this.props.handleInputChange}
206-
>
207-
{options}
208-
</select>
209-
</div>
210-
);
211-
}
209+
function StringArgument(props) {
210+
return (
211+
<div>
212+
{props.name}:
213+
<input
214+
className={cx([styles.input, styles["input--narrow"]])}
215+
name={props.name}
216+
type="text"
217+
value={props.value}
218+
onChange={props.handleInputChange}
219+
/>
220+
</div>
221+
);
212222
}
213223

214224
class Filter extends React.Component {
@@ -232,6 +242,14 @@ class Filter extends React.Component {
232242
handleInputChange={this.props.handleArgumentChange}
233243
/>
234244
);
245+
case ConfigType.String:
246+
return (
247+
<StringArgument
248+
name={name}
249+
value={config.value}
250+
handleInputChange={this.props.handleArgumentChange}
251+
/>
252+
);
235253
}
236254
};
237255

0 commit comments

Comments
 (0)