Skip to content

Commit f226306

Browse files
committed
refactor feature ablation
1 parent 6c9045d commit f226306

File tree

1 file changed

+34
-20
lines changed

1 file changed

+34
-20
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -286,21 +286,22 @@ def attribute(
286286
if show_progress:
287287
attr_progress.update()
288288

289+
# number of elements in the output of forward_func
290+
n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
291+
292+
# flattent eval outputs into 1D (n_outputs)
293+
# add the leading dim for n_feature_perturbed
294+
if isinstance(initial_eval, Tensor):
295+
initial_eval = initial_eval.reshape(1, -1)
296+
289297
agg_output_mode = FeatureAblation._find_output_mode(
290298
perturbations_per_eval, feature_mask
291299
)
292300

293-
# get as a 2D tensor (if it is not a scalar)
294-
if isinstance(initial_eval, torch.Tensor):
295-
initial_eval = initial_eval.reshape(1, -1)
296-
num_outputs = initial_eval.shape[1]
297-
else:
298-
num_outputs = 1
299-
300301
if not agg_output_mode:
301302
assert (
302-
isinstance(initial_eval, torch.Tensor)
303-
and num_outputs == num_examples
303+
isinstance(initial_eval, Tensor)
304+
and n_outputs == num_examples
304305
), (
305306
"expected output of `forward_func` to have "
306307
+ "`batch_size` elements for perturbations_per_eval > 1 "
@@ -316,8 +317,9 @@ def attribute(
316317
)
317318

318319
total_attrib = [
320+
# attribute w.r.t each output element
319321
torch.zeros(
320-
(num_outputs,) + input.shape[1:],
322+
(n_outputs, *input.shape[1:]),
321323
dtype=attrib_type,
322324
device=input.device,
323325
)
@@ -328,7 +330,7 @@ def attribute(
328330
if self.use_weights:
329331
weights = [
330332
torch.zeros(
331-
(num_outputs,) + input.shape[1:], device=input.device
333+
(n_outputs, *input.shape[1:]), device=input.device
332334
).float()
333335
for input in inputs
334336
]
@@ -354,8 +356,11 @@ def attribute(
354356
perturbations_per_eval,
355357
**kwargs,
356358
):
357-
# modified_eval dimensions: 1D tensor with length
358-
# equal to #num_examples * #features in batch
359+
# modified_eval has (n_feature_perturbed * n_outputs) elements
360+
# shape:
361+
# agg mode: (*initial_eval.shape)
362+
# non-agg mode:
363+
# (feature_perturbed * batch_size, *initial_eval.shape[1:])
359364
modified_eval = _run_forward(
360365
self.forward_func,
361366
current_inputs,
@@ -366,25 +371,34 @@ def attribute(
366371
if show_progress:
367372
attr_progress.update()
368373

369-
# (contains 1 more dimension than inputs). This adds extra
370-
# dimensions of 1 to make the tensor broadcastable with the inputs
371-
# tensor.
372374
if not isinstance(modified_eval, torch.Tensor):
373375
eval_diff = initial_eval - modified_eval
374376
else:
375377
if not agg_output_mode:
378+
# current_batch_size is not n_examples
379+
# it may get expanded by n_feature_perturbed
380+
current_batch_size = current_inputs[0].shape[0]
376381
assert (
377-
modified_eval.numel() == current_inputs[0].shape[0]
382+
modified_eval.numel() == current_batch_size
378383
), """expected output of forward_func to grow with
379384
batch_size. If this is not the case for your model
380385
please set perturbations_per_eval = 1"""
381386

382-
eval_diff = (
383-
initial_eval - modified_eval.reshape((-1, num_outputs))
384-
).reshape((-1, num_outputs) + (len(inputs[i].shape) - 1) * (1,))
387+
# reshape the leading dim for n_feature_perturbed
388+
# flatten each feature's eval outputs into 1D of (n_outputs)
389+
modified_eval = modified_eval.reshape(-1, n_outputs)
390+
# eval_diff in shape (n_feature_perturbed, n_outputs)
391+
eval_diff = initial_eval - modified_eval
392+
393+
# append the shape of one input example
394+
# to make it broadcastable to mask
395+
eval_diff = eval_diff.reshape(
396+
eval_diff.shape + (inputs[i].dim() - 1) * (1,)
397+
)
385398
eval_diff = eval_diff.to(total_attrib[i].device)
386399
if self.use_weights:
387400
weights[i] += current_mask.float().sum(dim=0)
401+
388402
total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(
389403
dim=0
390404
)

0 commit comments

Comments
 (0)