#!/usr/bin/env python3

import math

- from typing import Any, Callable, cast, Tuple, Union
+ from typing import Any, Callable, cast, List, Tuple, Union

import torch

from captum._utils.common import (
from captum.attr._utils.common import _format_input_baseline
from captum.log import log_usage
from torch import dtype, Tensor
+ from torch.futures import Future


class FeatureAblation(PerturbationAttribution):
@@ -62,6 +63,7 @@ def __init__(self, forward_func: Callable) -> None:
        # input grow as expected. Once it turns to True, we will assume the model's
        # behavior stays consistent and no longer check again
        self._is_output_shape_valid = False
+         self.use_futures = False

    @log_usage()
    def attribute(
@@ -286,9 +288,19 @@ def attribute(

        # Computes initial evaluation with all features, which is compared
        # to each ablated result.
-         initial_eval = self._strict_run_forward(
+         initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
            self.forward_func, inputs, target, additional_forward_args
        )
+         if self.use_futures:
+             assert isinstance(initial_eval, torch.Future), (
+                 "when use_futures is True, initial_eval should have "
+                 f"Future type rather than {type(initial_eval)}"
+             )
+
+             initial_eval.wait()
+             initial_eval = initial_eval.value()
+
+         initial_eval = self._parse_forward_out(initial_eval)

        if show_progress:
            attr_progress.update()
@@ -301,7 +313,7 @@ def attribute(
        flattened_initial_eval = initial_eval.reshape(1, -1)

        # Initialize attribution totals and counts
-         attrib_type = cast(dtype, flattened_initial_eval.dtype)
+         attrib_type = flattened_initial_eval.dtype

        total_attrib = [
            # attribute w.r.t each output element
@@ -313,6 +325,7 @@ def attribute(
            for input in inputs
        ]

+         weights: List[Tensor] = []
        # Weights are used in cases where ablations may be overlapping.
        if self.use_weights:
            weights = [
@@ -321,6 +334,7 @@ def attribute(
                ).float()
                for input in inputs
            ]
+         all_futures = []

        # Iterate through each feature tensor for ablation
        for i in range(len(inputs)):
@@ -348,7 +362,7 @@ def attribute(
                # agg mode: (*initial_eval.shape)
                # non-agg mode:
                # (feature_perturbed * batch_size, *initial_eval.shape[1:])
-                 modified_eval = self._strict_run_forward(
+                 modified_eval = _run_forward(
                    self.forward_func,
                    current_inputs,
                    current_target,
@@ -358,61 +372,62 @@ def attribute(
                if show_progress:
                    attr_progress.update()

-                 # if perturbations_per_eval > 1, the output shape must grow with
-                 # input and not be aggregated
-                 if perturbations_per_eval > 1 and not self._is_output_shape_valid:
-                     current_batch_size = current_inputs[0].shape[0]
-
-                     # number of perturbation, which is not the same as
-                     # perturbations_per_eval when not enough features to perturb
-                     n_perturb = current_batch_size / num_examples
-
-                     current_output_shape = modified_eval.shape
-
-                     # use initial_eval as the forward of perturbations_per_eval = 1
-                     initial_output_shape = initial_eval.shape
-
-                     assert (
-                         # check if the output is not a scalar
-                         current_output_shape
-                         and initial_output_shape
-                         # check if the output grow in same ratio, i.e., not agg
-                         and current_output_shape[0]
-                         == n_perturb * initial_output_shape[0]
-                     ), (
-                         "When perturbations_per_eval > 1, forward_func's output "
-                         "should be a tensor whose 1st dim grow with the input "
-                         f"batch size: when input batch size is {num_examples}, "
-                         f"the output shape is {initial_output_shape}; "
-                         f"when input batch size is {current_batch_size}, "
-                         f"the output shape is {current_output_shape}"
+                 if self.use_futures:
+                     assert isinstance(modified_eval, torch.Future), (
+                         "when use_futures is True, modified_eval should have "
+                         f"Future type rather than {type(modified_eval)}"
+                     )
+                     parsed_out_future = modified_eval.then(
+                         lambda x: self._parse_forward_out(x.value())
                    )

-                     self._is_output_shape_valid = True
-
-                 # reshape the leading dim for n_feature_perturbed
-                 # flatten each feature's eval outputs into 1D of (n_outputs)
-                 modified_eval = modified_eval.reshape(-1, n_outputs)
-                 # eval_diff in shape (n_feature_perturbed, n_outputs)
-                 eval_diff = flattened_initial_eval - modified_eval
-
-                 # append the shape of one input example
-                 # to make it broadcastable to mask
-                 eval_diff = eval_diff.reshape(
-                     eval_diff.shape + (inputs[i].dim() - 1) * (1,)
-                 )
-                 eval_diff = eval_diff.to(total_attrib[i].device)
+                     all_futures.append(
+                         parsed_out_future.then(
+                             lambda modified_eval_future, current_inputs=current_inputs, current_mask=current_mask, i=i: self.process_ablated_out(  # type: ignore # noqa: E501 line too long
+                                 modified_eval_future.value(),
+                                 current_inputs,
+                                 current_mask,
+                                 perturbations_per_eval,
+                                 num_examples,
+                                 initial_eval,
+                                 flattened_initial_eval,
+                                 inputs,
+                                 n_outputs,
+                                 total_attrib,
+                                 weights,
+                                 i,
+                                 attrib_type,
+                             )
+                         )
+                     )
+                     continue

-                 if self.use_weights:
-                     weights[i] += current_mask.float().sum(dim=0)
+                 modified_eval = self._parse_forward_out(modified_eval)

-                 total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(
-                     dim=0
+                 self.process_ablated_out(
+                     modified_eval,
+                     current_inputs,
+                     current_mask,
+                     perturbations_per_eval,
+                     num_examples,
+                     initial_eval,
+                     flattened_initial_eval,
+                     inputs,
+                     n_outputs,
+                     total_attrib,
+                     weights,
+                     i,
+                     attrib_type,
                )

        if show_progress:
            attr_progress.close()

+         if len(all_futures) > 0:
+             # torch.futures.Future.wait_all takes list of torch.futures.Future
+             # but will cast it to torch._C.Future internally.
+             torch.futures.wait_all(cast(List[Future], all_futures))
+
        # Divide total attributions by counts and return formatted attributions
        if self.use_weights:
            attrib = tuple(
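For readers unfamiliar with `torch.futures`, the pattern the new code relies on — chaining a callback onto each forward call with `Future.then` and blocking on all of them with `torch.futures.wait_all` — can be sketched in isolation. This is an illustrative sketch only, not code from this PR; `async_forward` and `parse_out` are made-up stand-ins for the model forward and `_parse_forward_out`.

```python
import torch
from torch.futures import Future


def parse_out(x: torch.Tensor) -> torch.Tensor:
    # stand-in for _parse_forward_out: pass the tensor through unchanged
    return x


def async_forward() -> Future:
    # pretend the forward pass completes asynchronously by returning an
    # already-completed Future (a real caller might return one from RPC)
    fut: Future = torch.futures.Future()
    fut.set_result(torch.ones(4))
    return fut


all_futures = []
for i in range(3):
    modified_eval = async_forward()
    # .then() hands the *completed* future to the callback; .value() unwraps it
    parsed = modified_eval.then(lambda fut: parse_out(fut.value()))
    # bind the loop variable as a default argument so each callback keeps its
    # own i, mirroring the "lambda ..., i=i:" trick in the diff above
    all_futures.append(
        parsed.then(lambda fut, i=i: print(f"perturbation {i}: sum={fut.value().sum().item()}"))
    )

# block until every chained callback has run
torch.futures.wait_all(all_futures)
```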
@@ -593,13 +608,12 @@ def _get_feature_counts(self, inputs, feature_mask, **kwargs):
            for inp, mask in zip(inputs, feature_mask)
        )

-     def _strict_run_forward(self, *args, **kwargs) -> Tensor:
+     def _parse_forward_out(self, forward_output) -> Tensor:
        """
        A temp wrapper for global _run_forward util to force forward output
        type assertion & conversion.
        Remove after the strict logic is supported by all attr classes
        """
-         forward_output = _run_forward(*args, **kwargs)
        if isinstance(forward_output, Tensor):
            return forward_output
@@ -612,4 +626,67 @@ def _strict_run_forward(self, *args, **kwargs) -> Tensor:
        # using python built-in type as torch dtype
        # int -> torch.int64, float -> torch.float64
        # ref: https://github.com/pytorch/pytorch/pull/21215
-         return torch.tensor(forward_output, dtype=output_type)
+         return torch.tensor(forward_output, dtype=cast(dtype, output_type))
+
+     def process_ablated_out(
+         self,
+         modified_eval,
+         current_inputs,
+         current_mask,
+         perturbations_per_eval,
+         num_examples,
+         initial_eval,
+         flattened_initial_eval,
+         inputs,
+         n_outputs,
+         total_attrib,
+         weights,
+         i,
+         attrib_type,
+     ):
+         # if perturbations_per_eval > 1, the output shape must grow with
+         # input and not be aggregated
+         if perturbations_per_eval > 1 and not self._is_output_shape_valid:
+             current_batch_size = current_inputs[0].shape[0]
+
+             # number of perturbation, which is not the same as
+             # perturbations_per_eval when not enough features to perturb
+             n_perturb = current_batch_size / num_examples
+
+             current_output_shape = modified_eval.shape
+
+             # use initial_eval as the forward of perturbations_per_eval = 1
+             initial_output_shape = initial_eval.shape
+
+             assert (
+                 # check if the output is not a scalar
+                 current_output_shape
+                 and initial_output_shape
+                 # check if the output grow in same ratio, i.e., not agg
+                 and current_output_shape[0] == n_perturb * initial_output_shape[0]
+             ), (
+                 "When perturbations_per_eval > 1, forward_func's output "
+                 "should be a tensor whose 1st dim grow with the input "
+                 f"batch size: when input batch size is {num_examples}, "
+                 f"the output shape is {initial_output_shape}; "
+                 f"when input batch size is {current_batch_size}, "
+                 f"the output shape is {current_output_shape}"
+             )
+
+             self._is_output_shape_valid = True
+
+         # reshape the leading dim for n_feature_perturbed
+         # flatten each feature's eval outputs into 1D of (n_outputs)
+         modified_eval = modified_eval.reshape(-1, n_outputs)
+         # eval_diff in shape (n_feature_perturbed, n_outputs)
+         eval_diff = flattened_initial_eval - modified_eval
+
+         # append the shape of one input example
+         # to make it broadcastable to mask
+         eval_diff = eval_diff.reshape(eval_diff.shape + (inputs[i].dim() - 1) * (1,))
+         eval_diff = eval_diff.to(total_attrib[i].device)
+
+         if self.use_weights:
+             weights[i] += current_mask.float().sum(dim=0)
+
+         total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(dim=0)
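Taken together, a caller opts into the asynchronous path by flipping the new `use_futures` flag and supplying a `forward_func` that returns a `torch.futures.Future`. The snippet below is a hypothetical usage sketch under those assumptions; the toy model and the way the Future is produced are not part of this diff.

```python
import torch
from captum.attr import FeatureAblation

model = torch.nn.Linear(3, 1)


def async_forward(inp: torch.Tensor) -> torch.futures.Future:
    # wrap an ordinary forward pass in an already-completed Future;
    # a real deployment might instead return a Future from an RPC call
    fut: torch.futures.Future = torch.futures.Future()
    fut.set_result(model(inp).squeeze(-1))  # one scalar output per example
    return fut


fa = FeatureAblation(async_forward)
fa.use_futures = True  # new flag added by this diff; defaults to False

inp = torch.rand(2, 3)
# each perturbed forward now returns a Future; attribute() chains
# process_ablated_out onto it and calls wait_all before aggregating
attr = fa.attribute(inp, perturbations_per_eval=1)
print(attr.shape)  # torch.Size([2, 3]): one attribution per input feature
```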