@@ -286,21 +286,22 @@ def attribute(
286
286
if show_progress :
287
287
attr_progress .update ()
288
288
289
+ # number of elements in the output of forward_func
290
+ n_outputs = initial_eval .numel () if isinstance (initial_eval , Tensor ) else 1
291
+
292
+ # flattent eval outputs into 1D (n_outputs)
293
+ # add the leading dim for n_feature_perturbed
294
+ if isinstance (initial_eval , Tensor ):
295
+ initial_eval = initial_eval .reshape (1 , - 1 )
296
+
289
297
agg_output_mode = FeatureAblation ._find_output_mode (
290
298
perturbations_per_eval , feature_mask
291
299
)
292
300
293
- # get as a 2D tensor (if it is not a scalar)
294
- if isinstance (initial_eval , torch .Tensor ):
295
- initial_eval = initial_eval .reshape (1 , - 1 )
296
- num_outputs = initial_eval .shape [1 ]
297
- else :
298
- num_outputs = 1
299
-
300
301
if not agg_output_mode :
301
302
assert (
302
- isinstance (initial_eval , torch . Tensor )
303
- and num_outputs == num_examples
303
+ isinstance (initial_eval , Tensor )
304
+ and n_outputs == num_examples
304
305
), (
305
306
"expected output of `forward_func` to have "
306
307
+ "`batch_size` elements for perturbations_per_eval > 1 "
@@ -316,8 +317,9 @@ def attribute(
316
317
)
317
318
318
319
total_attrib = [
320
+ # attribute w.r.t each output element
319
321
torch .zeros (
320
- (num_outputs ,) + input .shape [1 :],
322
+ (n_outputs , * input .shape [1 :]) ,
321
323
dtype = attrib_type ,
322
324
device = input .device ,
323
325
)
@@ -328,7 +330,7 @@ def attribute(
328
330
if self .use_weights :
329
331
weights = [
330
332
torch .zeros (
331
- (num_outputs ,) + input .shape [1 :], device = input .device
333
+ (n_outputs , * input .shape [1 :]) , device = input .device
332
334
).float ()
333
335
for input in inputs
334
336
]
@@ -354,8 +356,11 @@ def attribute(
354
356
perturbations_per_eval ,
355
357
** kwargs ,
356
358
):
357
- # modified_eval dimensions: 1D tensor with length
358
- # equal to #num_examples * #features in batch
359
+ # modified_eval has (n_feature_perturbed * n_outputs) elements
360
+ # shape:
361
+ # agg mode: (*initial_eval.shape)
362
+ # non-agg mode:
363
+ # (feature_perturbed * batch_size, *initial_eval.shape[1:])
359
364
modified_eval = _run_forward (
360
365
self .forward_func ,
361
366
current_inputs ,
@@ -366,25 +371,34 @@ def attribute(
366
371
if show_progress :
367
372
attr_progress .update ()
368
373
369
- # (contains 1 more dimension than inputs). This adds extra
370
- # dimensions of 1 to make the tensor broadcastable with the inputs
371
- # tensor.
372
374
if not isinstance (modified_eval , torch .Tensor ):
373
375
eval_diff = initial_eval - modified_eval
374
376
else :
375
377
if not agg_output_mode :
378
+ # current_batch_size is not n_examples
379
+ # it may get expanded by n_feature_perturbed
380
+ current_batch_size = current_inputs [0 ].shape [0 ]
376
381
assert (
377
- modified_eval .numel () == current_inputs [ 0 ]. shape [ 0 ]
382
+ modified_eval .numel () == current_batch_size
378
383
), """expected output of forward_func to grow with
379
384
batch_size. If this is not the case for your model
380
385
please set perturbations_per_eval = 1"""
381
386
382
- eval_diff = (
383
- initial_eval - modified_eval .reshape ((- 1 , num_outputs ))
384
- ).reshape ((- 1 , num_outputs ) + (len (inputs [i ].shape ) - 1 ) * (1 ,))
387
+ # reshape the leading dim for n_feature_perturbed
388
+ # flatten each feature's eval outputs into 1D of (n_outputs)
389
+ modified_eval = modified_eval .reshape (- 1 , n_outputs )
390
+ # eval_diff in shape (n_feature_perturbed, n_outputs)
391
+ eval_diff = initial_eval - modified_eval
392
+
393
+ # append the shape of one input example
394
+ # to make it broadcastable to mask
395
+ eval_diff = eval_diff .reshape (
396
+ eval_diff .shape + (inputs [i ].dim () - 1 ) * (1 ,)
397
+ )
385
398
eval_diff = eval_diff .to (total_attrib [i ].device )
386
399
if self .use_weights :
387
400
weights [i ] += current_mask .float ().sum (dim = 0 )
401
+
388
402
total_attrib [i ] += (eval_diff * current_mask .to (attrib_type )).sum (
389
403
dim = 0
390
404
)
0 commit comments