Skip to content

Neuron Conductance Tuple Support #602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 44 additions & 22 deletions captum/attr/_core/neuron/neuron_conductance.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
import warnings
from typing import Any, Callable, List, Tuple, Union

import torch
Expand All @@ -11,7 +12,7 @@
_format_additional_forward_args,
_format_output,
_is_tuple,
_verify_select_column,
_verify_select_neuron,
)
from captum._utils.gradient import compute_layer_gradients_and_eval
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
Expand Down Expand Up @@ -93,7 +94,7 @@ def __init__(
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
neuron_selector: Union[int, Tuple[int, ...]],
neuron_selector: Union[int, Tuple[int, ...], Callable],
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: Any = None,
Expand All @@ -113,13 +114,38 @@ def attribute(
that for all given input tensors, dimension 0 corresponds
to the number of examples, and if multiple input tensors
are provided, the examples must be aligned appropriately.
neuron_selector (int or tuple): Index of neuron in output of given
layer for which attribution is desired. Length of
this tuple must be one less than the number of
dimensions in the output of the given layer (since
dimension 0 corresponds to number of examples).
An integer may be provided instead of a tuple of
length 1.
neuron_selector (int, callable, or tuple of ints or slices):
Selector for neuron
in given layer for which attribution is desired.
Neuron selector can be provided as:

- a single integer, if the layer output is 2D. This integer
selects the appropriate neuron column in the layer input
or output

- a tuple of integers. Length of this
tuple must be one less than the number of dimensions
in the input / output of the given layer (since
dimension 0 corresponds to number of examples).
This can be used as long as the layer input / output
is a single tensor.

- a callable, which should
take the target layer as input (single tensor or tuple
if multiple tensors are in layer) and return a selected
neuron - output shape should be 1D with length equal to
batch_size (one scalar per input example)

NOTE: Callables applicable for neuron conductance are
less general than those of other methods and should
NOT aggregate values of the layer, only return a specific
output. This option should only be used in cases where the
layer input / output is a tuple of tensors, where the other
options would not suffice. This limitation is necessary since
neuron conductance, unlike other neuron methods, also utilizes
the gradient of output with respect to the intermediate neuron,
which cannot be computed for aggregations of multiple
intermediate neurons.
baselines (scalar, tensor, tuple of scalars or tensors, optional):
Baselines define the starting point from which integral
is computed and can be provided as:
Expand Down Expand Up @@ -249,6 +275,13 @@ def attribute(
>>> # index (4,1,2).
>>> attribution = neuron_cond.attribute(input, (4,1,2))
"""
if callable(neuron_selector):
warnings.warn(
"The neuron_selector provided is a callable. Please ensure that this"
" function only selects neurons from the given layer; aggregating"
" or performing other operations on the tensor may lead to inaccurate"
" results."
)
is_inputs_tuple = _is_tuple(inputs)

inputs, baselines = _format_input_baseline(inputs, baselines)
Expand Down Expand Up @@ -287,7 +320,7 @@ def attribute(
def _attribute(
self,
inputs: Tuple[Tensor, ...],
neuron_selector: Union[int, Tuple[int, ...]],
neuron_selector: Union[int, Tuple[int, ...], Callable],
baselines: Tuple[Union[Tensor, int, float], ...],
target: TargetType = None,
additional_forward_args: Any = None,
Expand Down Expand Up @@ -343,18 +376,7 @@ def _attribute(
attribute_to_layer_input=attribute_to_neuron_input,
)

# Layer gradients and eval
assert (
len(layer_gradients) == 1 and len(layer_eval) == 1
), "Layer can only have 1 output tensor for neuron attribution!"
layer_gradients = layer_gradients[0]
layer_eval = layer_eval[0]

# Multiplies by appropriate gradient of output with respect to hidden neurons
# mid_grads is a 1D Tensor of length num_steps*internal_batch_size,
# containing mid layer gradient for each input step.
mid_grads = _verify_select_column(layer_gradients, neuron_selector)

mid_grads = _verify_select_neuron(layer_gradients, neuron_selector)
scaled_input_gradients = tuple(
input_grad
* mid_grads.reshape((total_batch,) + (1,) * (len(input_grad.shape) - 1))
Expand Down
42 changes: 39 additions & 3 deletions tests/attr/neuron/test_neuron_conductance.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

import unittest
from typing import Any, List, Tuple, Union, cast
from typing import Any, Callable, List, Tuple, Union, cast

import torch
from torch import Tensor
Expand Down Expand Up @@ -43,6 +43,13 @@ def test_simple_conductance_input_linear1(self) -> None:
inp = torch.tensor([[0.0, 100.0, 0.0]])
self._conductance_input_test_assert(net, net.linear1, inp, 0, [0.0, 90.0, 0.0])

def test_simple_conductance_input_linear1_selector_fn(self) -> None:
net = BasicModel_MultiLayer()
inp = torch.tensor([[0.0, 100.0, 0.0]])
self._conductance_input_test_assert(
net, net.linear1, inp, lambda x: x[:, 0], [0.0, 90.0, 0.0]
)

Copy link
Contributor

@NarineK NarineK Feb 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is a test model that has a custom layer which returns a tuple, could we, please, add a test for that too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think both the other 2 added tests cover this case. Passing multi_input_module=True uses a custom layer MultiRelu, which returns a tuple.

def test_simple_conductance_input_relu(self) -> None:
net = BasicModel_MultiLayer()
inp = torch.tensor([[0.0, 70.0, 30.0]], requires_grad=True)
Expand Down Expand Up @@ -93,6 +100,13 @@ def test_simple_conductance_multi_input_batch_relu(self) -> None:
(inp3, 5),
)

def test_layer_tuple_selector_fn(self) -> None:
net = BasicModel_MultiLayer(multi_input_module=True)
inp = torch.tensor([[0.0, 6.0, 0.0]])
self._conductance_input_test_assert(
net, net.multi_relu, inp, lambda x: x[0][:, 1], [0.0, 6.0, 0.0]
)

def test_matching_conv2_multi_input_conductance(self) -> None:
net = BasicModel_ConvNet()
inp = 100 * torch.randn(2, 1, 10, 10)
Expand All @@ -118,12 +132,34 @@ def test_matching_pool2_multi_input_conductance(self) -> None:
baseline = 20 * torch.randn(1, 1, 10, 10, requires_grad=True)
self._conductance_input_sum_test_assert(net, net.pool2, inp, baseline)

def test_matching_layer_tuple_selector_fn(self) -> None:
net = BasicModel_MultiLayer(multi_input_module=True)
inp = torch.tensor([[0.0, 6.0, 0.0]])

lc = LayerConductance(net, net.multi_relu)
layer_attr = lc.attribute(inp, target=0, n_steps=500, method="gausslegendre")
nc = NeuronConductance(net, net.multi_relu)
for i in range(len(layer_attr)):
for j in range(layer_attr[i].shape[1]):
neuron_attr = nc.attribute(
inp,
lambda x: x[i][:, j],
target=0,
n_steps=500,
method="gausslegendre",
)
self.assertAlmostEqual(
neuron_attr.sum().item(),
layer_attr[i][0][j].item(),
delta=0.005,
)

def _conductance_input_test_assert(
self,
model: Module,
target_layer: Module,
test_input: TensorOrTupleOfTensorsGeneric,
test_neuron: Union[int, Tuple[int, ...]],
test_neuron: Union[int, Tuple[int, ...], Callable],
expected_input_conductance: Union[List[float], Tuple[List[List[float]], ...]],
additional_input: Any = None,
multiply_by_inputs: bool = True,
Expand All @@ -134,7 +170,7 @@ def _conductance_input_test_assert(
target_layer,
multiply_by_inputs=multiply_by_inputs,
)
self.assertEquals(cond.multiplies_by_inputs, multiply_by_inputs)
self.assertEqual(cond.multiplies_by_inputs, multiply_by_inputs)
attributions = cond.attribute(
test_input,
test_neuron,
Expand Down