
Commit 5f878af

aobo-y authored and facebook-github-bot committed
refactor feature ablation (#1047)
Summary:
- refactor feature ablation to make the logic about multi-output attribution more readable
- remove some outdated, misleading comments
- add more comments to explain the steps and tensor shapes

Pull Request resolved: #1047
Reviewed By: vivekmig
Differential Revision: D40388808
Pulled By: aobo-y
fbshipit-source-id: fb105a302fe3eeef302a1c50e39847e783bb48e2
1 parent cde34a2 commit 5f878af

File tree

1 file changed: +33 −22 lines changed


captum/attr/_core/feature_ablation.py

Lines changed: 33 additions & 22 deletions
@@ -286,22 +286,20 @@ def attribute(
         if show_progress:
             attr_progress.update()
 
+        # number of elements in the output of forward_func
+        n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
+
+        # flatten eval outputs into 1D (n_outputs)
+        # add the leading dim for n_feature_perturbed
+        if isinstance(initial_eval, Tensor):
+            initial_eval = initial_eval.reshape(1, -1)
+
         agg_output_mode = FeatureAblation._find_output_mode(
             perturbations_per_eval, feature_mask
         )
 
-        # get as a 2D tensor (if it is not a scalar)
-        if isinstance(initial_eval, torch.Tensor):
-            initial_eval = initial_eval.reshape(1, -1)
-            num_outputs = initial_eval.shape[1]
-        else:
-            num_outputs = 1
-
         if not agg_output_mode:
-            assert (
-                isinstance(initial_eval, torch.Tensor)
-                and num_outputs == num_examples
-            ), (
+            assert isinstance(initial_eval, Tensor) and n_outputs == num_examples, (
                 "expected output of `forward_func` to have "
                 + "`batch_size` elements for perturbations_per_eval > 1 "
                 + "and all feature_mask.shape[0] > 1"
@@ -316,8 +314,9 @@ def attribute(
             )
 
         total_attrib = [
+            # attribute w.r.t each output element
             torch.zeros(
-                (num_outputs,) + input.shape[1:],
+                (n_outputs,) + input.shape[1:],
                 dtype=attrib_type,
                 device=input.device,
             )
@@ -328,7 +327,7 @@ def attribute(
         if self.use_weights:
             weights = [
                 torch.zeros(
-                    (num_outputs,) + input.shape[1:], device=input.device
+                    (n_outputs,) + input.shape[1:], device=input.device
                 ).float()
                 for input in inputs
             ]
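
For context on the `(n_outputs,) + input.shape[1:]` buffers in the two hunks above, a small hedged sketch with made-up sizes (the variable names here are illustrative, not the commit's own): each output element gets its own attribution slice shaped like one input example:

import torch

# hypothetical input: batch of 3 examples, each of shape (2, 4)
inp = torch.rand(3, 2, 4)
n_outputs = 6  # as computed from initial_eval in the first hunk

# one zero-initialized attribution buffer per output element
total_attrib = torch.zeros((n_outputs,) + inp.shape[1:])
print(total_attrib.shape)  # torch.Size([6, 2, 4])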
@@ -354,8 +353,11 @@ def attribute(
                     perturbations_per_eval,
                     **kwargs,
                 ):
-                    # modified_eval dimensions: 1D tensor with length
-                    # equal to #num_examples * #features in batch
+                    # modified_eval has (n_feature_perturbed * n_outputs) elements
+                    # shape:
+                    #   agg mode: (*initial_eval.shape)
+                    #   non-agg mode:
+                    #     (feature_perturbed * batch_size, *initial_eval.shape[1:])
                     modified_eval = _run_forward(
                         self.forward_func,
                         current_inputs,
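
To make the new shape comment concrete, a tiny hedged example with made-up numbers (not taken from the commit):

# non-agg mode: forward_func returns one element per example, so
# n_outputs == batch_size, and each call perturbs several features at once
batch_size = 3
n_outputs = batch_size
n_feature_perturbed = 2

# the expanded batch has n_feature_perturbed * batch_size rows, so
# modified_eval is expected to flatten to this many elements
expected_numel = n_feature_perturbed * n_outputs  # 6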
@@ -366,25 +368,34 @@ def attribute(
                     if show_progress:
                         attr_progress.update()
 
-                    # (contains 1 more dimension than inputs). This adds extra
-                    # dimensions of 1 to make the tensor broadcastable with the inputs
-                    # tensor.
                     if not isinstance(modified_eval, torch.Tensor):
                         eval_diff = initial_eval - modified_eval
                     else:
                         if not agg_output_mode:
+                            # current_batch_size is not n_examples
+                            # it may get expanded by n_feature_perturbed
+                            current_batch_size = current_inputs[0].shape[0]
                             assert (
-                                modified_eval.numel() == current_inputs[0].shape[0]
+                                modified_eval.numel() == current_batch_size
                             ), """expected output of forward_func to grow with
                             batch_size. If this is not the case for your model
                             please set perturbations_per_eval = 1"""
 
-                        eval_diff = (
-                            initial_eval - modified_eval.reshape((-1, num_outputs))
-                        ).reshape((-1, num_outputs) + (len(inputs[i].shape) - 1) * (1,))
+                        # reshape the leading dim for n_feature_perturbed
+                        # flatten each feature's eval outputs into 1D of (n_outputs)
+                        modified_eval = modified_eval.reshape(-1, n_outputs)
+                        # eval_diff in shape (n_feature_perturbed, n_outputs)
+                        eval_diff = initial_eval - modified_eval
+
+                        # append the shape of one input example
+                        # to make it broadcastable to mask
+                        eval_diff = eval_diff.reshape(
+                            eval_diff.shape + (inputs[i].dim() - 1) * (1,)
+                        )
                         eval_diff = eval_diff.to(total_attrib[i].device)
                     if self.use_weights:
                         weights[i] += current_mask.float().sum(dim=0)
+
                     total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(
                         dim=0
                     )
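
The reshape at the end of this hunk is what lets `eval_diff` broadcast against the ablation mask. A minimal standalone sketch with hypothetical sizes and a hypothetical mask shape (not the commit's own code path):

import torch

n_feature_perturbed, n_outputs = 2, 6
inp = torch.rand(6, 2, 4)  # hypothetical input: 6 examples of shape (2, 4)

# per-feature, per-output differences, already flattened to (n_outputs,)
eval_diff = torch.rand(n_feature_perturbed, n_outputs)

# append one singleton dim per non-batch input dim to allow broadcasting
eval_diff = eval_diff.reshape(eval_diff.shape + (inp.dim() - 1) * (1,))
print(eval_diff.shape)  # torch.Size([2, 6, 1, 1])

# hypothetical ablation mask, broadcastable against eval_diff
current_mask = torch.randint(0, 2, (n_feature_perturbed, 1, 2, 4))
contrib = (eval_diff * current_mask).sum(dim=0)
print(contrib.shape)  # torch.Size([6, 2, 4]) == (n_outputs, *inp.shape[1:])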
