Deprecate and rename conflicting argument names #558
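Summary: this pull request renames two conflicting keyword arguments and adds deprecation shims for the old spellings. `NoiseTunnel.attribute` gains `nt_samples` in place of `n_samples`, and the `attribute` methods of `Lime`, `LimeBase`, and `KernelShap` gain `n_samples` in place of `n_perturb_samples`. The old keywords keep working until Captum 0.4.0 but emit a `DeprecationWarning`.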

Closed · wants to merge 6 commits
2 changes: 1 addition & 1 deletion captum/attr/_core/gradient_shap.py
@@ -274,7 +274,7 @@ def attribute(
nt, # self
inputs,
nt_type="smoothgrad",
- n_samples=n_samples,
+ nt_samples=n_samples,
stdevs=stdevs,
draw_baseline_from_distrib=True,
baselines=baselines,
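(GradientShap invokes NoiseTunnel.attribute internally, so this call site switches to the new `nt_samples` keyword; the identical one-line update appears in layer_gradient_shap.py below.)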
12 changes: 7 additions & 5 deletions captum/attr/_core/kernel_shap.py
@@ -9,6 +9,7 @@
from captum._utils.models.linear_model import SkLearnLinearRegression
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.lime import Lime
+ from captum.attr._utils.common import lime_n_perturb_samples_deprecation_decorator
from captum.log import log_usage


@@ -72,14 +73,15 @@ def __init__(self, forward_func: Callable) -> None:
)

@log_usage()
+ @lime_n_perturb_samples_deprecation_decorator
def attribute( # type: ignore
self,
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: Any = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
- n_perturb_samples: int = 25,
+ n_samples: int = 25,
perturbations_per_eval: int = 1,
return_input_shape: bool = True,
) -> TensorOrTupleOfTensorsGeneric:
@@ -213,9 +215,9 @@ def attribute( # type: ignore
If None, then a feature mask is constructed which assigns
each scalar within a tensor as a separate feature.
Default: None
- n_perturb_samples (int, optional): The number of samples of the original
+ n_samples (int, optional): The number of samples of the original
model used to train the surrogate interpretable model.
- Default: `50` if `n_perturb_samples` is not provided.
+ Default: `25` if `n_samples` is not provided.
perturbations_per_eval (int, optional): Allows multiple samples
to be processed simultaneously in one call to forward_fn.
Each forward pass will contain a maximum of
@@ -266,7 +268,7 @@ def attribute( # type: ignore
>>> ks = KernelShap(net)
>>> # Computes attribution, with each of the 4 x 4 = 16
>>> # features as a separate interpretable feature
- >>> attr = ks.attribute(input, target=1, n_perturb_samples=200)
+ >>> attr = ks.attribute(input, target=1, n_samples=200)

>>> # Alternatively, we can group each 2x2 square of the inputs
>>> # as one 'interpretable' feature and perturb them together.
@@ -299,7 +301,7 @@ def attribute( # type: ignore
target=target,
additional_forward_args=additional_forward_args,
feature_mask=feature_mask,
- n_perturb_samples=n_perturb_samples,
+ n_samples=n_samples,
perturbations_per_eval=perturbations_per_eval,
return_input_shape=return_input_shape,
)
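For reference, a minimal usage sketch of the renamed KernelShap argument (the toy model is hypothetical, and scikit-learn is assumed for the default surrogate model). The old keyword still works through the shim but warns:

```python
import torch
import torch.nn as nn
from captum.attr import KernelShap

net = nn.Sequential(nn.Linear(16, 3))  # hypothetical toy model
ks = KernelShap(net)
inp = torch.randn(1, 16)

# New spelling introduced by this PR:
attr = ks.attribute(inp, target=1, n_samples=200)

# Old spelling keeps working until Captum 0.4.0, but emits a DeprecationWarning:
attr = ks.attribute(inp, target=1, n_perturb_samples=200)
```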
2 changes: 1 addition & 1 deletion captum/attr/_core/layer/layer_gradient_shap.py
@@ -305,7 +305,7 @@ def attribute(
nt, # self
inputs,
nt_type="smoothgrad",
- n_samples=n_samples,
+ nt_samples=n_samples,
stdevs=stdevs,
draw_baseline_from_distrib=True,
baselines=baselines,
27 changes: 14 additions & 13 deletions captum/attr/_core/lime.py
@@ -32,6 +32,7 @@
from captum.attr._utils.common import (
_construct_default_feature_mask,
_format_input_baseline,
+ lime_n_perturb_samples_deprecation_decorator,
)
from captum.log import log_usage

@@ -229,12 +230,13 @@ def __init__(
), "Must provide transform from original input space to interpretable space"

@log_usage()
+ @lime_n_perturb_samples_deprecation_decorator
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
additional_forward_args: Any = None,
- n_perturb_samples: int = 50,
+ n_samples: int = 50,
perturbations_per_eval: int = 1,
**kwargs
) -> Tensor:
@@ -308,9 +310,9 @@ def attribute(
Note that attributions are not computed with respect
to these arguments.
Default: None
- n_perturb_samples (int, optional): The number of samples of the original
+ n_samples (int, optional): The number of samples of the original
model used to train the surrogate interpretable model.
- Default: `50` if `n_perturb_samples` is not provided.
+ Default: `50` if `n_samples` is not provided.
perturbations_per_eval (int, optional): Allows multiple samples
to be processed simultaneously in one call to forward_fn.
Each forward pass will contain a maximum of
@@ -409,7 +411,7 @@ def attribute(
curr_model_inputs = []
expanded_additional_args = None
expanded_target = None
- for i in range(n_perturb_samples):
+ for _ in range(n_samples):
curr_sample = self.perturb_func(inputs, **kwargs)
if self.perturb_interpretable_space:
interpretable_inps.append(curr_sample)
@@ -479,9 +481,7 @@ def attribute(
dataset = TensorDataset(
combined_interp_inps, combined_outputs, combined_sim
)
- self.interpretable_model.fit(
- DataLoader(dataset, batch_size=n_perturb_samples)
- )
+ self.interpretable_model.fit(DataLoader(dataset, batch_size=n_samples))
return self.interpretable_model.representation()

def _evaluate_batch(
@@ -752,14 +752,15 @@ def __init__(
)

@log_usage()
+ @lime_n_perturb_samples_deprecation_decorator
def attribute( # type: ignore
self,
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: Any = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
- n_perturb_samples: int = 25,
+ n_samples: int = 25,
perturbations_per_eval: int = 1,
return_input_shape: bool = True,
) -> TensorOrTupleOfTensorsGeneric:
@@ -893,9 +894,9 @@ def attribute( # type: ignore
If None, then a feature mask is constructed which assigns
each scalar within a tensor as a separate feature.
Default: None
- n_perturb_samples (int, optional): The number of samples of the original
+ n_samples (int, optional): The number of samples of the original
model used to train the surrogate interpretable model.
- Default: `50` if `n_perturb_samples` is not provided.
+ Default: `25` if `n_samples` is not provided.
perturbations_per_eval (int, optional): Allows multiple samples
to be processed simultaneously in one call to forward_fn.
Each forward pass will contain a maximum of
@@ -947,7 +948,7 @@ def attribute( # type: ignore
>>> lime = Lime(net)
>>> # Computes attribution, with each of the 4 x 4 = 16
>>> # features as a separate interpretable feature
- >>> attr = lime.attribute(input, target=1, n_perturb_samples=200)
+ >>> attr = lime.attribute(input, target=1, n_samples=200)

>>> # Alternatively, we can group each 2x2 square of the inputs
>>> # as one 'interpretable' feature and perturb them together.
@@ -1041,7 +1042,7 @@ def attribute( # type: ignore
inputs=curr_inps if is_inputs_tuple else curr_inps[0],
target=curr_target,
additional_forward_args=curr_additional_args,
- n_perturb_samples=n_perturb_samples,
+ n_samples=n_samples,
perturbations_per_eval=perturbations_per_eval,
baselines=curr_baselines
if is_inputs_tuple
@@ -1081,7 +1082,7 @@ def attribute( # type: ignore
inputs=inputs,
target=target,
additional_forward_args=additional_forward_args,
- n_perturb_samples=n_perturb_samples,
+ n_samples=n_samples,
perturbations_per_eval=perturbations_per_eval,
baselines=baselines if is_inputs_tuple else baselines[0],
feature_mask=feature_mask if is_inputs_tuple else feature_mask[0],
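A quick way to confirm the shim's behavior is to capture the warning; a sketch, assuming scikit-learn is available for Lime's default surrogate model:

```python
import warnings

import torch
import torch.nn as nn
from captum.attr import Lime

net = nn.Linear(4, 2)  # hypothetical toy model
lime = Lime(net)
inp = torch.randn(1, 4)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The decorator forwards the old keyword to n_samples, then warns.
    attr = lime.attribute(inp, target=0, n_perturb_samples=40)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```

Note that positional callers are unaffected: the renamed parameter keeps its position in every signature, so only keyword uses of the old name go through the shim.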
30 changes: 17 additions & 13 deletions captum/attr/_core/noise_tunnel.py
@@ -15,7 +15,10 @@
_is_tuple,
)
from captum.attr._utils.attribution import Attribution, GradientAttribution
- from captum.attr._utils.common import _validate_noise_tunnel_type
+ from captum.attr._utils.common import (
+ _validate_noise_tunnel_type,
+ noise_tunnel_n_samples_deprecation_decorator,
+ )
from captum.log import log_usage


@@ -30,7 +33,7 @@ class NoiseTunnelType(Enum):

class NoiseTunnel(Attribution):
r"""
- Adds gaussian noise to each input in the batch `n_samples` times
+ Adds gaussian noise to each input in the batch `nt_samples` times
and applies the given attribution algorithm to each of the samples.
The attributions of the samples are combined based on the given noise
tunnel type (nt_type):
@@ -73,11 +76,12 @@ def multiplies_by_inputs(self):
return self._multiply_by_inputs

@log_usage()
+ @noise_tunnel_n_samples_deprecation_decorator
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
nt_type: str = "smoothgrad",
- n_samples: int = 5,
+ nt_samples: int = 5,
stdevs: Union[float, Tuple[float, ...]] = 1.0,
draw_baseline_from_distrib: bool = False,
**kwargs: Any,
@@ -96,10 +100,10 @@ def attribute(
nt_type (string, optional): Smoothing type of the attributions.
`smoothgrad`, `smoothgrad_sq` or `vargrad`
Default: `smoothgrad` if `type` is not provided.
- n_samples (int, optional): The number of randomly generated examples
+ nt_samples (int, optional): The number of randomly generated examples
per sample in the input batch. Random examples are
generated by adding gaussian random noise to each sample.
- Default: `5` if `n_samples` is not provided.
+ Default: `5` if `nt_samples` is not provided.
stdevs (float, or a tuple of floats optional): The standard deviation
of gaussian noise with zero mean that is added to each
input in the batch. If `stdevs` is a single float value
@@ -154,7 +158,7 @@ def attribute(
>>> # input and averages attributions across all 10
>>> # perturbed inputs per image
>>> attribution = nt.attribute(input, nt_type='smoothgrad',
- >>> n_samples=10, target=3)
+ >>> nt_samples=10, target=3)
"""

def add_noise_to_inputs() -> Tuple[Tensor, ...]:
@@ -182,7 +186,7 @@ def add_noise_to_input(input: Tensor, stdev: float) -> Tensor:
bsz = input.shape[0]

# expand input size by the number of drawn samples
- input_expanded_size = (bsz * n_samples,) + input.shape[1:]
+ input_expanded_size = (bsz * nt_samples,) + input.shape[1:]

# expand stdev for the shape of the input and number of drawn samples
stdev_expanded = torch.tensor(stdev, device=input.device).repeat(
@@ -194,11 +198,11 @@ def add_noise_to_input(input: Tensor, stdev: float) -> Tensor:
# FIXME it looks like it is very difficult to make torch.normal
# deterministic; this needs an investigation
noise = torch.normal(0, stdev_expanded)
- return input.repeat_interleave(n_samples, dim=0) + noise
+ return input.repeat_interleave(nt_samples, dim=0) + noise

def compute_expected_attribution_and_sq(attribution):
- bsz = attribution.shape[0] // n_samples
- attribution_shape = (bsz, n_samples)
+ bsz = attribution.shape[0] // nt_samples
+ attribution_shape = (bsz, nt_samples)
if len(attribution.shape) > 1:
attribution_shape += attribution.shape[1:]

@@ -222,11 +226,11 @@ def compute_expected_attribution_and_sq(attribution):
# additional_forward_args they will be expanded based
# on the n_steps and corresponding kwargs
# variables will be updated accordingly
- _expand_and_update_additional_forward_args(n_samples, kwargs)
- _expand_and_update_target(n_samples, kwargs)
+ _expand_and_update_additional_forward_args(nt_samples, kwargs)
+ _expand_and_update_target(nt_samples, kwargs)
_expand_and_update_baselines(
inputs,
- n_samples,
+ nt_samples,
kwargs,
draw_baseline_from_distrib=draw_baseline_from_distrib,
)
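End to end, the renamed NoiseTunnel argument looks like this (a sketch mirroring the docstring example above, with a hypothetical model):

```python
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients, NoiseTunnel

net = nn.Linear(3, 2)  # hypothetical toy model
nt = NoiseTunnel(IntegratedGradients(net))
inp = torch.randn(4, 3)

# nt_samples no longer collides with any n_samples kwarg that the wrapped
# attribution method itself accepts.
attr = nt.attribute(inp, nt_type="smoothgrad", nt_samples=10, target=1)
```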
53 changes: 49 additions & 4 deletions captum/attr/_utils/common.py
@@ -367,23 +367,68 @@ def _construct_default_feature_mask(
return feature_mask, total_features


- def neuron_index_deprecation_decorator(func):
+ def neuron_index_deprecation_decorator(func: Callable) -> Callable:
r"""
Decorator to deprecate neuron_index parameter for Neuron Attribution methods.
"""

@functools.wraps(func)
- def wrapper(*args, **kwargs):
+ def wrapper(*args: Any, **kwargs: Any):
if "neuron_index" in kwargs:
kwargs["neuron_selector"] = kwargs["neuron_index"]
warnings.warn(
"neuron_index is being deprecated and replaced with neuron_selector "
"to support more general functionality. Please update the parameter "
"to support more general functionality. Please, update the parameter "
"name to neuron_selector. Support for neuron_index will be removed "
"in Captum 0.4.0",
"in Captum 0.4.0.",
DeprecationWarning,
)
del kwargs["neuron_index"]
return func(*args, **kwargs)

return wrapper


+ def noise_tunnel_n_samples_deprecation_decorator(func: Callable) -> Callable:
+ r"""
+ Decorator to deprecate the n_samples parameter of NoiseTunnel's `attribute` method.
+ """
+
+ @functools.wraps(func)
+ def wrapper(*args: Any, **kwargs: Any):
+ if "nt_samples" not in kwargs and "n_samples" in kwargs:
+ kwargs["nt_samples"] = kwargs["n_samples"]
+ warnings.warn(
+ "n_samples is being deprecated and replaced with nt_samples "
+ "to avoid argument naming conflicts in the attribute method. "
+ "Please, update the parameter name to nt_samples. Support for "
+ "n_samples will be removed in Captum 0.4.0.",
+ DeprecationWarning,
+ )
+ del kwargs["n_samples"]
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+ def lime_n_perturb_samples_deprecation_decorator(func: Callable) -> Callable:
+ r"""
+ Decorator to deprecate the n_perturb_samples parameter of Lime's and KernelSHAP's
+ attribute method.
+ """
+
+ @functools.wraps(func)
+ def wrapper(*args: Any, **kwargs: Any):
+ if "n_perturb_samples" in kwargs:
+ kwargs["n_samples"] = kwargs["n_perturb_samples"]
+ warnings.warn(
+ "n_perturb_samples is being deprecated and replaced with n_samples "
+ "to avoid argument naming conflicts in the attribute method. "
+ "Please, update the parameter name to n_samples. Support for "
+ "n_perturb_samples will be removed in Captum 0.4.0.",
+ DeprecationWarning,
+ )
+ del kwargs["n_perturb_samples"]
+ return func(*args, **kwargs)
+
+ return wrapper
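The three shims share one pattern and differ only in the keyword names involved (and in whether they guard against both spellings being passed at once, as the NoiseTunnel variant does). A hedged design sketch, not part of this PR, of a single parameterized factory that could replace them:

```python
import functools
import warnings
from typing import Any, Callable


def rename_kwarg_deprecation(old: str, new: str, removal: str = "Captum 0.4.0") -> Callable:
    """Build a decorator that forwards the deprecated keyword `old` to `new`."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any):
            if old in kwargs and new not in kwargs:
                warnings.warn(
                    f"{old} is being deprecated and replaced with {new}. "
                    f"Support for {old} will be removed in {removal}.",
                    DeprecationWarning,
                )
                kwargs[new] = kwargs.pop(old)
            return func(*args, **kwargs)

        return wrapper

    return decorator


# Hypothetical usage, equivalent to the shims above:
# @rename_kwarg_deprecation("n_samples", "nt_samples")
# def attribute(self, inputs, nt_samples=5, ...): ...
```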
2 changes: 1 addition & 1 deletion tests/attr/helpers/test_config.py
@@ -295,7 +295,7 @@
"inputs": (10 * torch.randn(6, 3), 5 * torch.randn(6, 3)),
"additional_forward_args": (2 * torch.randn(6, 3), 5),
"target": [0, 1, 1, 0, 0, 1],
"n_samples": 20,
"nt_samples": 20,
"stdevs": 0.0,
},
"noise_tunnel": True,