@@ -286,22 +286,20 @@ def attribute(
286
286
if show_progress :
287
287
attr_progress .update ()
288
288
289
+ # number of elements in the output of forward_func
290
+ n_outputs = initial_eval .numel () if isinstance (initial_eval , Tensor ) else 1
291
+
292
+ # flatten eval outputs into 1D (n_outputs)
293
+ # add the leading dim for n_feature_perturbed
294
+ if isinstance (initial_eval , Tensor ):
295
+ initial_eval = initial_eval .reshape (1 , - 1 )
296
+
289
297
agg_output_mode = FeatureAblation ._find_output_mode (
290
298
perturbations_per_eval , feature_mask
291
299
)
292
300
293
- # get as a 2D tensor (if it is not a scalar)
294
- if isinstance (initial_eval , torch .Tensor ):
295
- initial_eval = initial_eval .reshape (1 , - 1 )
296
- num_outputs = initial_eval .shape [1 ]
297
- else :
298
- num_outputs = 1
299
-
300
301
if not agg_output_mode :
301
- assert (
302
- isinstance (initial_eval , torch .Tensor )
303
- and num_outputs == num_examples
304
- ), (
302
+ assert isinstance (initial_eval , Tensor ) and n_outputs == num_examples , (
305
303
"expected output of `forward_func` to have "
306
304
+ "`batch_size` elements for perturbations_per_eval > 1 "
307
305
+ "and all feature_mask.shape[0] > 1"
@@ -316,8 +314,9 @@ def attribute(
316
314
)
317
315
318
316
total_attrib = [
317
+ # attribute w.r.t each output element
319
318
torch .zeros (
320
- (num_outputs ,) + input .shape [1 :],
319
+ (n_outputs ,) + input .shape [1 :],
321
320
dtype = attrib_type ,
322
321
device = input .device ,
323
322
)
@@ -328,7 +327,7 @@ def attribute(
328
327
if self .use_weights :
329
328
weights = [
330
329
torch .zeros (
331
- (num_outputs ,) + input .shape [1 :], device = input .device
330
+ (n_outputs ,) + input .shape [1 :], device = input .device
332
331
).float ()
333
332
for input in inputs
334
333
]
@@ -354,8 +353,11 @@ def attribute(
354
353
perturbations_per_eval ,
355
354
** kwargs ,
356
355
):
357
- # modified_eval dimensions: 1D tensor with length
358
- # equal to #num_examples * #features in batch
356
+ # modified_eval has (n_feature_perturbed * n_outputs) elements
357
+ # shape:
358
+ # agg mode: (*initial_eval.shape)
359
+ # non-agg mode:
360
+ # (feature_perturbed * batch_size, *initial_eval.shape[1:])
359
361
modified_eval = _run_forward (
360
362
self .forward_func ,
361
363
current_inputs ,
@@ -366,25 +368,34 @@ def attribute(
366
368
if show_progress :
367
369
attr_progress .update ()
368
370
369
- # (contains 1 more dimension than inputs). This adds extra
370
- # dimensions of 1 to make the tensor broadcastable with the inputs
371
- # tensor.
372
371
if not isinstance (modified_eval , torch .Tensor ):
373
372
eval_diff = initial_eval - modified_eval
374
373
else :
375
374
if not agg_output_mode :
375
+ # current_batch_size is not n_examples
376
+ # it may get expanded by n_feature_perturbed
377
+ current_batch_size = current_inputs [0 ].shape [0 ]
376
378
assert (
377
- modified_eval .numel () == current_inputs [ 0 ]. shape [ 0 ]
379
+ modified_eval .numel () == current_batch_size
378
380
), """expected output of forward_func to grow with
379
381
batch_size. If this is not the case for your model
380
382
please set perturbations_per_eval = 1"""
381
383
382
- eval_diff = (
383
- initial_eval - modified_eval .reshape ((- 1 , num_outputs ))
384
- ).reshape ((- 1 , num_outputs ) + (len (inputs [i ].shape ) - 1 ) * (1 ,))
384
+ # reshape the leading dim for n_feature_perturbed
385
+ # flatten each feature's eval outputs into 1D of (n_outputs)
386
+ modified_eval = modified_eval .reshape (- 1 , n_outputs )
387
+ # eval_diff in shape (n_feature_perturbed, n_outputs)
388
+ eval_diff = initial_eval - modified_eval
389
+
390
+ # append the shape of one input example
391
+ # to make it broadcastable to mask
392
+ eval_diff = eval_diff .reshape (
393
+ eval_diff .shape + (inputs [i ].dim () - 1 ) * (1 ,)
394
+ )
385
395
eval_diff = eval_diff .to (total_attrib [i ].device )
386
396
if self .use_weights :
387
397
weights [i ] += current_mask .float ().sum (dim = 0 )
398
+
388
399
total_attrib [i ] += (eval_diff * current_mask .to (attrib_type )).sum (
389
400
dim = 0
390
401
)
0 commit comments