Skip to content

Commit c2062db

Browse files
edward-iofacebook-github-bot
authored andcommitted
Add Occlusion to Insights (#369)
Summary: ![image](https://user-images.githubusercontent.com/53842584/81026858-6efefb00-8e30-11ea-970d-5c6907fe3e7b.png) Pull Request resolved: #369 Reviewed By: vivekmig, J0Nreynolds Differential Revision: D21394665 Pulled By: edward-io fbshipit-source-id: 4f6848928fa271b99ee8a376b6232985fc739b2c
1 parent 77aa93a commit c2062db

File tree

4 files changed

+137
-74
lines changed

4 files changed

+137
-74
lines changed

captum/insights/api.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class FilterConfig(NamedTuple):
8989
arg: config.value # type: ignore
9090
for arg, config in ATTRIBUTION_METHOD_CONFIG[
9191
IntegratedGradients.get_name()
92-
].items()
92+
].params.items()
9393
}
9494
prediction: str = "all"
9595
classes: List[str] = []
@@ -221,6 +221,12 @@ def _calculate_attribution(
221221
attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[self._config.attribution_method]
222222
attribution_method = attribution_cls(net)
223223
args = self._config.attribution_arguments
224+
param_config = ATTRIBUTION_METHOD_CONFIG[self._config.attribution_method]
225+
if param_config.post_process:
226+
for k, v in args.items():
227+
if k in param_config.post_process:
228+
args[k] = param_config.post_process[k](v)
229+
224230
# TODO support multiple baselines
225231
baseline = baselines[0] if baselines and len(baselines) > 0 else None
226232
label = (
@@ -329,7 +335,9 @@ def _serve_colab(self, blocking=False, debug=False, port=None):
329335
def _get_labels_from_scores(
330336
self, scores: Tensor, indices: Tensor
331337
) -> List[OutputScore]:
332-
pred_scores = []
338+
pred_scores: List[OutputScore] = []
339+
if indices.nelement() < 2:
340+
return pred_scores
333341
for i in range(len(indices)):
334342
score = scores[i]
335343
pred_scores.append(
@@ -542,6 +550,8 @@ def get_insights_config(self):
542550
return {
543551
"classes": self.classes,
544552
"methods": list(ATTRIBUTION_NAMES_TO_METHODS.keys()),
545-
"method_arguments": namedtuple_to_dict(ATTRIBUTION_METHOD_CONFIG),
553+
"method_arguments": namedtuple_to_dict(
554+
{k: v.params for (k, v) in ATTRIBUTION_METHOD_CONFIG.items()}
555+
),
546556
"selected_method": self._config.attribution_method,
547557
}

captum/insights/config.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python3
2-
from typing import Dict, List, NamedTuple, Optional, Tuple
2+
from typing import Dict, List, NamedTuple, Optional, Tuple, Callable, Any, Union
33

44
from captum.attr import (
55
Deconvolution,
@@ -9,6 +9,7 @@
99
InputXGradient,
1010
IntegratedGradients,
1111
Saliency,
12+
Occlusion,
1213
)
1314
from captum.attr._utils.approximation_methods import SUPPORTED_METHODS
1415

@@ -25,6 +26,13 @@ class StrEnumConfig(NamedTuple):
2526
type: str = "enum"
2627

2728

29+
class StrConfig(NamedTuple):
30+
value: str
31+
type: str = "string"
32+
33+
34+
Config = Union[NumberConfig, StrEnumConfig, StrConfig]
35+
2836
SUPPORTED_ATTRIBUTION_METHODS = [
2937
Deconvolution,
3038
DeepLift,
@@ -33,20 +41,50 @@ class StrEnumConfig(NamedTuple):
3341
IntegratedGradients,
3442
Saliency,
3543
FeatureAblation,
44+
Occlusion,
3645
]
3746

47+
48+
class ConfigParameters(NamedTuple):
49+
params: Dict[str, Config]
50+
help_info: Optional[str] = None # TODO fill out help for each method
51+
post_process: Optional[Dict[str, Callable[[Any], Any]]] = None
52+
53+
3854
ATTRIBUTION_NAMES_TO_METHODS = {
3955
# mypy bug - treating it as a type instead of a class
4056
cls.get_name(): cls # type: ignore
4157
for cls in SUPPORTED_ATTRIBUTION_METHODS
4258
}
4359

44-
ATTRIBUTION_METHOD_CONFIG: Dict[str, Dict[str, tuple]] = {
45-
IntegratedGradients.get_name(): {
46-
"n_steps": NumberConfig(value=25, limit=(2, None)),
47-
"method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"),
48-
},
49-
FeatureAblation.get_name(): {
50-
"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
51-
},
60+
61+
def _str_to_tuple(s):
62+
if isinstance(s, tuple):
63+
return s
64+
return tuple([int(i) for i in s.split()])
65+
66+
67+
ATTRIBUTION_METHOD_CONFIG: Dict[str, ConfigParameters] = {
68+
IntegratedGradients.get_name(): ConfigParameters(
69+
params={
70+
"n_steps": NumberConfig(value=25, limit=(2, None)),
71+
"method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"),
72+
},
73+
post_process={"n_steps": int},
74+
),
75+
FeatureAblation.get_name(): ConfigParameters(
76+
params={"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100))},
77+
),
78+
Occlusion.get_name(): ConfigParameters(
79+
params={
80+
"sliding_window_shapes": StrConfig(value=""),
81+
"strides": StrConfig(value=""),
82+
"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
83+
},
84+
post_process={
85+
"sliding_window_shapes": _str_to_tuple,
86+
"strides": _str_to_tuple,
87+
"perturbations_per_eval": int,
88+
},
89+
),
5290
}

captum/insights/frontend/src/App.js

Lines changed: 69 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import "./App.css";
88
const ConfigType = Object.freeze({
99
Number: "number",
1010
Enum: "enum",
11+
String: "string",
1112
});
1213

1314
const Plot = createPlotlyComponent(Plotly);
@@ -153,62 +154,71 @@ class FilterContainer extends React.Component {
153154
}
154155
}
155156

156-
class ClassFilter extends React.Component {
157-
render() {
158-
return (
159-
<ReactTags
160-
tags={this.props.classes}
161-
autofocus={false}
162-
suggestions={this.props.suggestedClasses}
163-
handleDelete={this.props.handleClassDelete}
164-
handleAddition={this.props.handleClassAdd}
165-
minQueryLength={0}
166-
placeholder="add new class..."
157+
function ClassFilter(props) {
158+
return (
159+
<ReactTags
160+
tags={props.classes}
161+
autofocus={false}
162+
suggestions={props.suggestedClasses}
163+
handleDelete={props.handleClassDelete}
164+
handleAddition={props.handleClassAdd}
165+
minQueryLength={0}
166+
placeholder="add new class..."
167+
/>
168+
);
169+
}
170+
171+
function NumberArgument(props) {
172+
var min = props.limit[0];
173+
var max = props.limit[1];
174+
return (
175+
<div>
176+
{props.name}:
177+
<input
178+
className={cx([styles.input, styles["input--narrow"]])}
179+
name={props.name}
180+
type="number"
181+
value={props.value}
182+
min={min}
183+
max={max}
184+
onChange={props.handleInputChange}
167185
/>
168-
);
169-
}
186+
</div>
187+
);
170188
}
171189

172-
class NumberArgument extends React.Component {
173-
render() {
174-
var min = this.props.limit[0];
175-
var max = this.props.limit[1];
176-
return (
177-
<div>
178-
{this.props.name + ": "}
179-
<input
180-
className={cx([styles.input, styles["input--narrow"]])}
181-
name={this.props.name}
182-
type="number"
183-
value={this.props.value}
184-
min={min}
185-
max={max}
186-
onChange={this.props.handleInputChange}
187-
/>
188-
</div>
189-
);
190-
}
190+
function EnumArgument(props) {
191+
const options = props.limit.map((item, key) => (
192+
<option value={item}>{item}</option>
193+
));
194+
return (
195+
<div>
196+
{props.name}:
197+
<select
198+
className={styles.select}
199+
name={props.name}
200+
value={props.value}
201+
onChange={props.handleInputChange}
202+
>
203+
{options}
204+
</select>
205+
</div>
206+
);
191207
}
192208

193-
class EnumArgument extends React.Component {
194-
render() {
195-
const options = this.props.limit.map((item, key) => (
196-
<option value={item}>{item}</option>
197-
));
198-
return (
199-
<div>
200-
{this.props.name + ": "}
201-
<select
202-
className={styles.select}
203-
name={this.props.name}
204-
value={this.props.value}
205-
onChange={this.props.handleInputChange}
206-
>
207-
{options}
208-
</select>
209-
</div>
210-
);
211-
}
209+
function StringArgument(props) {
210+
return (
211+
<div>
212+
{props.name}:
213+
<input
214+
className={cx([styles.input, styles["input--narrow"]])}
215+
name={props.name}
216+
type="text"
217+
value={props.value}
218+
onChange={props.handleInputChange}
219+
/>
220+
</div>
221+
);
212222
}
213223

214224
class Filter extends React.Component {
@@ -232,6 +242,14 @@ class Filter extends React.Component {
232242
handleInputChange={this.props.handleArgumentChange}
233243
/>
234244
);
245+
case ConfigType.String:
246+
return (
247+
<StringArgument
248+
name={name}
249+
value={config.value}
250+
handleInputChange={this.props.handleArgumentChange}
251+
/>
252+
);
235253
}
236254
};
237255

captum/insights/frontend/widget/src/Widget.js

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ class Widget extends React.Component {
1212
config: {
1313
classes: [],
1414
methods: [],
15-
method_arguments: {}
15+
method_arguments: {},
1616
},
1717
loading: false,
18-
callback: null
18+
callback: null,
1919
};
2020
this.backbone = this.props.backbone;
2121
}
@@ -47,14 +47,11 @@ class Widget extends React.Component {
4747

4848
_fetchInit = () => {
4949
this.setState({
50-
config: this.backbone.model.get("insights_config")
50+
config: this.backbone.model.get("insights_config"),
5151
});
5252
};
5353

54-
fetchData = filterConfig => {
55-
filterConfig.approximation_steps = parseInt(
56-
filterConfig.approximation_steps
57-
);
54+
fetchData = (filterConfig) => {
5855
this.setState({ loading: true }, () => {
5956
this.backbone.model.save({ config: filterConfig, output: [] });
6057
});
@@ -64,7 +61,7 @@ class Widget extends React.Component {
6461
this.setState({ callback: callback }, () => {
6562
this.backbone.model.save({
6663
label_details: { labelIndex, instance },
67-
attribution: {}
64+
attribution: {},
6865
});
6966
});
7067
};
@@ -90,16 +87,16 @@ var CaptumInsightsModel = widgets.DOMWidgetModel.extend({
9087
_model_module: "jupyter-captum-insights",
9188
_view_module: "jupyter-captum-insights",
9289
_model_module_version: "0.1.0",
93-
_view_module_version: "0.1.0"
94-
})
90+
_view_module_version: "0.1.0",
91+
}),
9592
});
9693

9794
var CaptumInsightsView = widgets.DOMWidgetView.extend({
9895
initialize() {
9996
const $app = document.createElement("div");
10097
ReactDOM.render(<Widget backbone={this} />, $app);
10198
this.el.append($app);
102-
}
99+
},
103100
});
104101

105102
export { Widget as default, CaptumInsightsModel, CaptumInsightsView };

0 commit comments

Comments
 (0)