
Commit c284661

yucu authored and facebook-github-bot committed
Initial version of async attribution with torch.futures (#1295)
Summary: Pull Request resolved: #1295
Differential Revision: D56764316
1 parent 3f0cd93 commit c284661
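
This commit threads torch.futures through `_run_forward` and `FeatureAblation` (a new `use_futures` flag plus a `process_ablated_out` callback). As a rough, hypothetical sketch of the intended usage — not part of this commit — a forward function could hand back a `torch.futures.Future`, with the ablation loop chaining its post-processing onto it. The toy model, the wrapper name `async_forward`, and toggling `use_futures` directly on the instance are assumptions for illustration only.

    # Hypothetical usage sketch (not from this diff).
    import torch
    from captum.attr import FeatureAblation

    model = torch.nn.Linear(4, 1)

    def async_forward(inputs: torch.Tensor) -> torch.futures.Future:
        # Contrived: complete the Future immediately. A real async backend
        # (e.g. an RPC call) would fulfil it later via set_result().
        fut: torch.futures.Future = torch.futures.Future()
        fut.set_result(model(inputs))
        return fut

    fa = FeatureAblation(async_forward)
    fa.use_futures = True  # new flag introduced by this diff; defaults to False
    attr = fa.attribute(torch.randn(2, 4), target=0)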

File tree

8 files changed: +229 −65 lines changed

captum/_utils/common.py

Lines changed: 7 additions & 1 deletion
@@ -15,6 +15,8 @@
     TupleOrTensorOrBoolGeneric,
 )
 from torch import device, Tensor
+
+# from torch.futures import Future
 from torch.nn import Module


@@ -514,7 +516,9 @@ def _run_forward(
     inputs: Any,
     target: TargetType = None,
     additional_forward_args: Any = None,
-) -> Tensor:
+):  # Annotate return type to Union[Tensor, Future[Tensor]]
+    # after PyTorch 1.6.0 support got dropped, otherwise
+    # it will complain 'pybind11_type' object is not subscriptable
     forward_func_args = signature(forward_func).parameters
     if len(forward_func_args) == 0:
         output = forward_func()

@@ -532,6 +536,8 @@ def _run_forward(
                 else inputs
             )
         )
+    if isinstance(output, torch.futures.Future):
+        return output.then(lambda x: _select_targets(x.value(), target))
     return _select_targets(output, target)
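
With this change, `_run_forward` no longer assumes the forward output is a `Tensor`: when the forward returns a `torch.futures.Future`, target selection is chained onto it with `.then()` instead of being applied eagerly. A minimal, self-contained sketch of that chaining pattern (the names below are illustrative stand-ins, not Captum code):

    import torch

    def select_first_column(out: torch.Tensor) -> torch.Tensor:
        # Stand-in for _select_targets(output, target)
        return out[:, 0]

    fut: torch.futures.Future = torch.futures.Future()

    # .then() registers a callback that receives the *completed* Future;
    # x.value() unwraps the underlying tensor, mirroring
    # output.then(lambda x: _select_targets(x.value(), target)) above.
    selected = fut.then(lambda x: select_first_column(x.value()))

    fut.set_result(torch.arange(6.0).reshape(3, 2))
    print(selected.wait())  # tensor([0., 2., 4.])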

captum/attr/_core/feature_ablation.py

Lines changed: 130 additions & 53 deletions
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

 import math
-from typing import Any, Callable, cast, Tuple, Union
+from typing import Any, Callable, cast, List, Tuple, Union

 import torch
 from captum._utils.common import (

@@ -19,6 +19,7 @@
 from captum.attr._utils.common import _format_input_baseline
 from captum.log import log_usage
 from torch import dtype, Tensor
+from torch.futures import Future


 class FeatureAblation(PerturbationAttribution):

@@ -62,6 +63,7 @@ def __init__(self, forward_func: Callable) -> None:
         # input grow as expected. Once it turns to True, we will assume the model's
         # behavior stays consistent and no longer check again
         self._is_output_shape_valid = False
+        self.use_futures = False

     @log_usage()
     def attribute(
@@ -286,9 +288,19 @@ def attribute(

         # Computes initial evaluation with all features, which is compared
         # to each ablated result.
-        initial_eval = self._strict_run_forward(
+        initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
             self.forward_func, inputs, target, additional_forward_args
         )
+        if self.use_futures:
+            assert isinstance(initial_eval, torch.Future), (
+                "when use_futures is True, initial_eval should have "
+                f"Future type rather than {type(initial_eval)}"
+            )
+
+            initial_eval.wait()
+            initial_eval = initial_eval.value()
+
+        initial_eval = self._parse_forward_out(initial_eval)

         if show_progress:
             attr_progress.update()

@@ -301,7 +313,7 @@ def attribute(
         flattened_initial_eval = initial_eval.reshape(1, -1)

         # Initialize attribution totals and counts
-        attrib_type = cast(dtype, flattened_initial_eval.dtype)
+        attrib_type = flattened_initial_eval.dtype

         total_attrib = [
             # attribute w.r.t each output element

@@ -313,6 +325,7 @@ def attribute(
             for input in inputs
         ]

+        weights: List[Tensor] = []
         # Weights are used in cases where ablations may be overlapping.
         if self.use_weights:
             weights = [

@@ -321,6 +334,7 @@ def attribute(
                 ).float()
                 for input in inputs
             ]
+        all_futures = []

         # Iterate through each feature tensor for ablation
         for i in range(len(inputs)):

@@ -348,7 +362,7 @@ def attribute(
                 # agg mode: (*initial_eval.shape)
                 # non-agg mode:
                 # (feature_perturbed * batch_size, *initial_eval.shape[1:])
-                modified_eval = self._strict_run_forward(
+                modified_eval = _run_forward(
                     self.forward_func,
                     current_inputs,
                     current_target,
@@ -358,61 +372,62 @@ def attribute(
                 if show_progress:
                     attr_progress.update()

-                # if perturbations_per_eval > 1, the output shape must grow with
-                # input and not be aggregated
-                if perturbations_per_eval > 1 and not self._is_output_shape_valid:
-                    current_batch_size = current_inputs[0].shape[0]
-
-                    # number of perturbation, which is not the same as
-                    # perturbations_per_eval when not enough features to perturb
-                    n_perturb = current_batch_size / num_examples
-
-                    current_output_shape = modified_eval.shape
-
-                    # use initial_eval as the forward of perturbations_per_eval = 1
-                    initial_output_shape = initial_eval.shape
-
-                    assert (
-                        # check if the output is not a scalar
-                        current_output_shape
-                        and initial_output_shape
-                        # check if the output grow in same ratio, i.e., not agg
-                        and current_output_shape[0]
-                        == n_perturb * initial_output_shape[0]
-                    ), (
-                        "When perturbations_per_eval > 1, forward_func's output "
-                        "should be a tensor whose 1st dim grow with the input "
-                        f"batch size: when input batch size is {num_examples}, "
-                        f"the output shape is {initial_output_shape}; "
-                        f"when input batch size is {current_batch_size}, "
-                        f"the output shape is {current_output_shape}"
+                if self.use_futures:
+                    assert isinstance(modified_eval, torch.Future), (
+                        "when use_futures is True, modified_eval should have "
+                        f"Future type rather than {type(modified_eval)}"
+                    )
+                    parsed_out_future = modified_eval.then(
+                        lambda x: self._parse_forward_out(x.value())
                     )

-                    self._is_output_shape_valid = True
-
-                # reshape the leading dim for n_feature_perturbed
-                # flatten each feature's eval outputs into 1D of (n_outputs)
-                modified_eval = modified_eval.reshape(-1, n_outputs)
-                # eval_diff in shape (n_feature_perturbed, n_outputs)
-                eval_diff = flattened_initial_eval - modified_eval
-
-                # append the shape of one input example
-                # to make it broadcastable to mask
-                eval_diff = eval_diff.reshape(
-                    eval_diff.shape + (inputs[i].dim() - 1) * (1,)
-                )
-                eval_diff = eval_diff.to(total_attrib[i].device)
+                    all_futures.append(
+                        parsed_out_future.then(
+                            lambda modified_eval_future, current_inputs=current_inputs, current_mask=current_mask, i=i: self.process_ablated_out(  # type: ignore # noqa: E501 line too long
+                                modified_eval_future.value(),
+                                current_inputs,
+                                current_mask,
+                                perturbations_per_eval,
+                                num_examples,
+                                initial_eval,
+                                flattened_initial_eval,
+                                inputs,
+                                n_outputs,
+                                total_attrib,
+                                weights,
+                                i,
+                                attrib_type,
+                            )
+                        )
+                    )
+                    continue

-                if self.use_weights:
-                    weights[i] += current_mask.float().sum(dim=0)
+                modified_eval = self._parse_forward_out(modified_eval)

-                total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(
-                    dim=0
+                self.process_ablated_out(
+                    modified_eval,
+                    current_inputs,
+                    current_mask,
+                    perturbations_per_eval,
+                    num_examples,
+                    initial_eval,
+                    flattened_initial_eval,
+                    inputs,
+                    n_outputs,
+                    total_attrib,
+                    weights,
+                    i,
+                    attrib_type,
                 )

         if show_progress:
             attr_progress.close()

+        if len(all_futures) > 0:
+            # torch.futures.Future.wait_all takes list of torch.futures.Future
+            # but will cast it to torch._C.Future internally.
+            torch.futures.wait_all(cast(List[Future], all_futures))
+
         # Divide total attributions by counts and return formatted attributions
         if self.use_weights:
             attrib = tuple(
@@ -593,13 +608,12 @@ def _get_feature_counts(self, inputs, feature_mask, **kwargs):
             for inp, mask in zip(inputs, feature_mask)
         )

-    def _strict_run_forward(self, *args, **kwargs) -> Tensor:
+    def _parse_forward_out(self, forward_output) -> Tensor:
         """
         A temp wrapper for global _run_forward util to force forward output
         type assertion & conversion.
         Remove after the strict logic is supported by all attr classes
         """
-        forward_output = _run_forward(*args, **kwargs)
         if isinstance(forward_output, Tensor):
             return forward_output
@@ -612,4 +626,67 @@ def _strict_run_forward(self, *args, **kwargs) -> Tensor:
         # using python built-in type as torch dtype
         # int -> torch.int64, float -> torch.float64
         # ref: https://github.com/pytorch/pytorch/pull/21215
-        return torch.tensor(forward_output, dtype=output_type)
+        return torch.tensor(forward_output, dtype=cast(dtype, output_type))
+
+    def process_ablated_out(
+        self,
+        modified_eval,
+        current_inputs,
+        current_mask,
+        perturbations_per_eval,
+        num_examples,
+        initial_eval,
+        flattened_initial_eval,
+        inputs,
+        n_outputs,
+        total_attrib,
+        weights,
+        i,
+        attrib_type,
+    ):
+        # if perturbations_per_eval > 1, the output shape must grow with
+        # input and not be aggregated
+        if perturbations_per_eval > 1 and not self._is_output_shape_valid:
+            current_batch_size = current_inputs[0].shape[0]
+
+            # number of perturbation, which is not the same as
+            # perturbations_per_eval when not enough features to perturb
+            n_perturb = current_batch_size / num_examples
+
+            current_output_shape = modified_eval.shape
+
+            # use initial_eval as the forward of perturbations_per_eval = 1
+            initial_output_shape = initial_eval.shape
+
+            assert (
+                # check if the output is not a scalar
+                current_output_shape
+                and initial_output_shape
+                # check if the output grow in same ratio, i.e., not agg
+                and current_output_shape[0] == n_perturb * initial_output_shape[0]
+            ), (
+                "When perturbations_per_eval > 1, forward_func's output "
+                "should be a tensor whose 1st dim grow with the input "
+                f"batch size: when input batch size is {num_examples}, "
+                f"the output shape is {initial_output_shape}; "
+                f"when input batch size is {current_batch_size}, "
+                f"the output shape is {current_output_shape}"
+            )
+
+            self._is_output_shape_valid = True
+
+        # reshape the leading dim for n_feature_perturbed
+        # flatten each feature's eval outputs into 1D of (n_outputs)
+        modified_eval = modified_eval.reshape(-1, n_outputs)
+        # eval_diff in shape (n_feature_perturbed, n_outputs)
+        eval_diff = flattened_initial_eval - modified_eval
+
+        # append the shape of one input example
+        # to make it broadcastable to mask
+        eval_diff = eval_diff.reshape(eval_diff.shape + (inputs[i].dim() - 1) * (1,))
+        eval_diff = eval_diff.to(total_attrib[i].device)
+
+        if self.use_weights:
+            weights[i] += current_mask.float().sum(dim=0)
+
+        total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(dim=0)
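
One detail worth noting in the hunks above: the callback handed to `.then()` binds `current_inputs`, `current_mask`, and `i` as keyword defaults. Python closures capture variables by reference, so without those defaults every callback would see the values from the loop's final iteration by the time the futures complete. A small illustration of the pitfall and the fix (generic Python, not Captum code):

    # Late-binding pitfall: all three lambdas share the same `i`.
    late = [lambda: i for i in range(3)]
    print([f() for f in late])  # [2, 2, 2]

    # Binding the loop variable as a default argument freezes its value
    # per iteration, which is what the .then() callback above relies on.
    bound = [lambda i=i: i for i in range(3)]
    print([f() for f in bound])  # [0, 1, 2]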

captum/attr/_core/layer/layer_feature_permutation.py

Lines changed: 5 additions & 0 deletions
@@ -197,6 +197,11 @@ def forward_hook(module, inp, out=None):
             finally:
                 if hook is not None:
                     hook.remove()
+
+            if isinstance(eval, torch.futures.Future):
+                eval.wait()
+                eval = eval.value()
+
             return eval

         with torch.no_grad():
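
The hook-based paths (here and in layer_integrated_gradients.py and lrp.py below) stay synchronous: if the wrapped forward happens to return a Future, they simply block on it. A quick, illustrative sketch of the `wait()`/`value()` semantics being relied on (not Captum code):

    import torch

    fut: torch.futures.Future = torch.futures.Future()
    fut.set_result(torch.ones(2))

    # wait() blocks until the Future is complete; value() retrieves the
    # result of an already-completed Future (and re-raises any error it
    # completed with).
    fut.wait()
    result = fut.value()
    print(result)  # tensor([1., 1.])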

captum/attr/_core/layer/layer_integrated_gradients.py

Lines changed: 4 additions & 0 deletions
@@ -474,6 +474,10 @@ def layer_forward_hook(
             if hook is not None:
                 hook.remove()

+            if isinstance(output, torch.futures.Future):
+                output.wait()
+                output = output.value()
+
             assert output[0].numel() == 1, (
                 "Target not provided when necessary, cannot"
                 " take gradient with respect to multiple outputs."

captum/attr/_core/lrp.py

Lines changed: 6 additions & 0 deletions
@@ -4,6 +4,8 @@
 from collections import defaultdict
 from typing import Any, cast, List, Tuple, Union

+import torch
+
 import torch.nn as nn
 from captum._utils.common import (
     _format_output,

@@ -358,6 +360,10 @@ def _compute_output_and_change_weights(
         # adjustments as inputs to the layers with adjusted weights. This procedure
         # is important for graph generation in the 2nd forward pass.
         self._register_pre_hooks()
+
+        if isinstance(output, torch.futures.Future):
+            output.wait()
+            output = output.value()
         return output

     def _remove_forward_hooks(self) -> None:

captum/attr/_core/shapley_value.py

Lines changed: 3 additions & 3 deletions
@@ -3,7 +3,7 @@
 import itertools
 import math
 import warnings
-from typing import Any, Callable, Iterable, Sequence, Tuple, Union
+from typing import Any, Callable, cast, Iterable, Sequence, Tuple, Union

 import torch
 from captum._utils.common import (

@@ -27,7 +27,7 @@
     _tensorize_baseline,
 )
 from captum.log import log_usage
-from torch import Tensor
+from torch import dtype, Tensor


 def _all_perm_generator(num_features: int, num_samples: int) -> Iterable[Sequence[int]]:

@@ -551,7 +551,7 @@ def _strict_run_forward(self, *args, **kwargs) -> Tensor:
         # using python built-in type as torch dtype
         # int -> torch.int64, float -> torch.float64
         # ref: https://github.com/pytorch/pytorch/pull/21215
-        return torch.tensor([forward_output], dtype=output_type)
+        return torch.tensor([forward_output], dtype=cast(dtype, output_type))


 class ShapleyValues(ShapleyValueSampling):
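
The behavior here is unchanged; `cast(dtype, output_type)` only quiets the type checker, since `output_type` is a Python builtin type (`int` or `float`) being passed where a `torch.dtype` is expected. PyTorch accepts builtins as dtypes, mapping them as the comment describes; a tiny illustrative check:

    import torch

    # Python builtin types are accepted as dtypes (pytorch/pytorch#21215):
    # int -> torch.int64, float -> torch.float64
    print(torch.tensor([3], dtype=int).dtype)      # torch.int64
    print(torch.tensor([3.0], dtype=float).dtype)  # torch.float64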
