
Commit 0d9ebb4

Fulton Wang authored and facebook-github-bot committed
update tracin influence API (#1072)
Summary: Pull Request resolved: #1072

This diff changes the API for implementations of `TracInCPBase` as discussed in https://fb.quip.com/JbpnAiWluZmI. In particular, the arguments representing test data of the `influence` method are changed from `inputs: Tuple, targets: Optional[Tensor]` to `inputs: Union[Tuple[Any], DataLoader]`, which is either a single batch or a dataloader yielding batches. In both cases, `model(*batch)` is assumed to produce the predictions for a batch, and `batch[-1]` is assumed to be the labels for a batch. This is the same format assumed of the batches yielded by `train_dataloader`.

We make this change for 2 reasons:
- it unifies the assumptions made of the test data and the assumptions made of the training data
- for some implementations, we want to allow the test data to be represented by a dataloader. With the old API, there was no clean way to allow both a single batch as well as a dataloader to be passed in, since a batch required 2 arguments, but a dataloader only requires 1.

For now, all implementations only allow `inputs` to be a tuple (and not a dataloader). This is okay due to inheritance rules. Later on, we will allow some implementations (i.e. `TracInCP`) to accept a dataloader as `inputs`.

Other changes:
- documentation updates; for example, the documentation in `TracInCPBase.influence` now refers to the "test dataset" instead of the test batch
- the `unpack_inputs` argument is no longer needed for the `influence` methods, and is removed
- the usage of `influence` in all the tests is changed to match the new API
- the signatures of the helper methods `_influence_batch_tracincp` and `_influence_batch_tracincp_fast` are changed to match the new representation of batches

Reviewed By: cyrjano

Differential Revision: D41324297

fbshipit-source-id: 9fe108de1a6789d461c19d71b724cd18bbcffbd9
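
As a quick illustration of the API change, here is a minimal sketch with hypothetical names; the toy model, dataset, checkpoint path, loader, and test tensors below are placeholders, not taken from this diff.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from captum.influence import TracInCP

# Hypothetical toy setup, only to illustrate the call-signature change.
net = nn.Linear(10, 2)
train_ds = TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,)))

checkpoint_path = "/tmp/checkpoint.pt"  # placeholder checkpoint location
torch.save(net.state_dict(), checkpoint_path)

def load_checkpoint(model: nn.Module, path: str) -> float:
    # simple loader; returns the checkpoint's weight (learning rate) for TracIn
    model.load_state_dict(torch.load(path))
    return 1.0

tracin = TracInCP(
    net,
    train_ds,
    [checkpoint_path],
    checkpoints_load_func=load_checkpoint,
    loss_fn=nn.CrossEntropyLoss(reduction="none"),  # per-example losses
    batch_size=8,
)

test_features = torch.randn(4, 10)
test_labels = torch.randint(0, 2, (4,))

# Old API (before this diff): features and labels were separate arguments,
# e.g. tracin.influence(test_features, test_labels, k=None).
# New API: a single batch tuple whose last element is assumed to be the
# labels, matching the format of batches yielded by `train_dataloader`.
scores = tracin.influence((test_features, test_labels), k=None)
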
1 parent e9eeac4 commit 0d9ebb4

14 files changed: 386 additions, 492 deletions

captum/influence/_core/influence.py

Lines changed: 3 additions & 5 deletions
@@ -12,21 +12,19 @@ class DataInfluence(ABC):
     An abstract class to define model data influence skeleton.
     """
 
-    def __init_(
-        self, model: Module, influence_src_dataset: Dataset, **kwargs: Any
-    ) -> None:
+    def __init_(self, model: Module, train_dataset: Dataset, **kwargs: Any) -> None:
         r"""
         Args:
             model (torch.nn.Module): An instance of pytorch model.
-            influence_src_dataset (torch.utils.data.Dataset): PyTorch Dataset that is
+            train_dataset (torch.utils.data.Dataset): PyTorch Dataset that is
                 used to create a PyTorch Dataloader to iterate over the dataset and
                 its labels. This is the dataset for which we will be seeking for
                 influential instances. In most cases this is the training dataset.
             **kwargs: Additional key-value arguments that are necessary for specific
                 implementation of `DataInfluence` abstract class.
         """
         self.model = model
-        self.influence_src_dataset = influence_src_dataset
+        self.train_dataset = train_dataset
 
     @abstractmethod
     def influence(self, inputs: Any = None, **kwargs: Any) -> Any:
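
For orientation, a minimal hypothetical subclass of the renamed abstract class could look as follows; the `ConstantInfluence` class and its zero scores are invented for illustration, and only the `train_dataset` argument and attribute come from the diff above.

from typing import Any

import torch
from torch.nn import Module
from torch.utils.data import Dataset

from captum.influence._core.influence import DataInfluence

class ConstantInfluence(DataInfluence):
    # Toy implementation: every training example gets an influence score of 0.
    def __init__(self, model: Module, train_dataset: Dataset, **kwargs: Any) -> None:
        self.model = model
        self.train_dataset = train_dataset

    def influence(self, inputs: Any = None, **kwargs: Any) -> Any:
        # one (zero) score per example in `train_dataset`
        return torch.zeros(len(self.train_dataset))
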

captum/influence/_core/tracincp.py

Lines changed: 153 additions & 190 deletions
Large diffs are not rendered by default.

captum/influence/_core/tracincp_fast_rand_proj.py

Lines changed: 149 additions & 215 deletions
Large diffs are not rendered by default.

captum/influence/_utils/common.py

Lines changed: 6 additions & 8 deletions
@@ -191,7 +191,6 @@ def _get_k_most_influential_helper(
     influence_src_dataloader: DataLoader,
     influence_batch_fn: Callable,
     inputs: Tuple[Any, ...],
-    targets: Optional[Tensor],
     k: int = 5,
     proponents: bool = True,
     show_progress: bool = False,
@@ -206,13 +205,12 @@ def _get_k_most_influential_helper(
         influence_src_dataloader (DataLoader): The DataLoader, representing training
             data, for which we want to compute proponents / opponents.
         influence_batch_fn (Callable): A callable that will be called via
-            `influence_batch_fn(inputs, targets, batch)`, where `batch` is a batch
+            `influence_batch_fn(inputs, batch)`, where `batch` is a batch
             in the `influence_src_dataloader` argument.
-        inputs (tuple[Any, ...]): A batch of examples. Does not represent labels,
-            which are passed as `targets`.
-        targets (Tensor, optional): If computing TracIn scores on a loss function,
-            these are the labels corresponding to the batch `inputs`.
-            Default: None
+        inputs (tuple[Any, ...]): This argument represents the test batch, and is a
+            single tuple of any, where the last element is assumed to be the labels
+            for the batch. That is, `model(*batch[0:-1])` produces the output for
+            `model`, and `batch[-1]` are the labels, if any.
         k (int, optional): The number of proponents or opponents to return per test
             instance.
             Default: 5
@@ -274,7 +272,7 @@ def _get_k_most_influential_helper(
     for batch in influence_src_dataloader:
 
         # calculate tracin_scores for the batch
-        batch_tracin_scores = influence_batch_fn(inputs, targets, batch)
+        batch_tracin_scores = influence_batch_fn(inputs, batch)
         batch_tracin_scores *= multiplier
 
         # get the top-k indices and tracin_scores for the batch
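
The updated docstring above relies on a single batch convention. As a small self-contained sketch of that convention (the two-input model and random tensors are hypothetical, not part of the diff), the last element of the tuple is treated as the labels and everything before it is splatted into the model:

import torch
import torch.nn as nn

# A hypothetical model taking two feature tensors, to show that a batch may
# contain several inputs before the trailing labels element.
class TwoFeatureNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(8, 2)

    def forward(self, feat_a: torch.Tensor, feat_b: torch.Tensor) -> torch.Tensor:
        return self.linear(torch.cat([feat_a, feat_b], dim=1))

model = TwoFeatureNet()
loss_fn = nn.CrossEntropyLoss()

# A batch is a single tuple whose last element is the labels.
batch = (torch.randn(4, 3), torch.randn(4, 5), torch.randint(0, 2, (4,)))

# `model(*batch[0:-1])` produces the model output, and `batch[-1]` are the labels.
out = model(*batch[0:-1])
loss = loss_fn(out, batch[-1])
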

tests/influence/_core/test_dataloader.py

Lines changed: 5 additions & 2 deletions
@@ -10,6 +10,7 @@
 from parameterized import parameterized
 from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
 from tests.influence._utils.common import (
+    _format_batch_into_tuple,
     build_test_name_func,
     DataInfluenceConstructor,
     get_random_model_and_data,
@@ -76,7 +77,8 @@ def test_tracin_dataloader(
         )
 
         train_scores = tracin.influence(
-            test_samples, test_labels, k=None, unpack_inputs=unpack_inputs
+            _format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
+            k=None,
         )
 
         tracin_dataloader = tracin_constructor(
@@ -88,7 +90,8 @@ def test_tracin_dataloader(
         )
 
         train_scores_dataloader = tracin_dataloader.influence(
-            test_samples, test_labels, k=None, unpack_inputs=unpack_inputs
+            _format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
+            k=None,
         )
 
         assertTensorAlmostEqual(

tests/influence/_core/test_tracin_intermediate_quantities.py

Lines changed: 5 additions & 16 deletions
@@ -12,6 +12,7 @@
 from parameterized import parameterized
 from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
 from tests.influence._utils.common import (
+    _format_batch_into_tuple,
     build_test_name_func,
     DataInfluenceConstructor,
     get_random_model_and_data,
@@ -224,25 +225,13 @@ def test_tracin_intermediate_quantities_consistent(
         )
 
         # compute influence scores without using `compute_intermediate_quantities`
+        test_batch = _format_batch_into_tuple(
+            test_features, test_labels, unpack_inputs
+        )
         scores = tracin.influence(
-            test_features, test_labels, unpack_inputs=unpack_inputs
+            test_batch,
         )
 
-        # compute influence scores using `compute_intermediate_quantities`
-        # we combine `test_features` and `test_labels` into a single tuple
-        # `test_batch` to pass to the model, with the assumption that
-        # `model(test_batch[0:-1]` produces the predictions, and `test_batch[-1]`
-        # are the labels. We do this due to the assumptions made by the
-        # `compute_intermediate_quantities` method. Therefore, how we
-        # form `test_batch` depends on whether `unpack_inputs` is True or False
-        if not unpack_inputs:
-            # `test_features` is a Tensor
-            test_batch = (test_features, test_labels)
-        else:
-            # `test_features` is a tuple, so we unpack it to place in tuple,
-            # along with `test_labels`
-            test_batch = (*test_features, test_labels)  # type: ignore[assignment]
-
         # the influence score is the dot product of intermediate quantities
         intermediate_quantities_scores = torch.matmul(
             intermediate_quantities_tracin.compute_intermediate_quantities(
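
The inline logic deleted above is what the shared test helper `_format_batch_into_tuple` now encapsulates. A minimal sketch, assuming the helper simply mirrors that deleted code (the actual helper lives in `tests/influence/_utils/common.py` and may differ in detail):

from typing import Tuple, Union

from torch import Tensor

def _format_batch_into_tuple(
    inputs: Union[Tuple, Tensor], targets: Tensor, unpack_inputs: bool
) -> Tuple:
    # When `unpack_inputs` is True, `inputs` is a tuple of feature tensors that
    # is spread out before appending the labels; otherwise it is a single
    # Tensor. Either way, the result is one tuple whose last element is the
    # labels, matching the batch format `influence` now expects.
    if unpack_inputs:
        return (*inputs, targets)
    else:
        return (inputs, targets)
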

tests/influence/_core/test_tracin_k_most_influential.py

Lines changed: 4 additions & 4 deletions
@@ -8,6 +8,7 @@
 from parameterized import parameterized
 from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
 from tests.influence._utils.common import (
+    _format_batch_into_tuple,
     build_test_name_func,
     DataInfluenceConstructor,
     get_random_model_and_data,
@@ -107,15 +108,14 @@ def test_tracin_k_most_influential(
         )
 
         train_scores = tracin.influence(
-            test_samples, test_labels, k=None, unpack_inputs=unpack_inputs
+            _format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
+            k=None,
         )
         sort_idx = torch.argsort(train_scores, dim=1, descending=proponents)[:, 0:k]
         idx, _train_scores = tracin.influence(
-            test_samples,
-            test_labels,
+            _format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
             k=k,
             proponents=proponents,
-            unpack_inputs=unpack_inputs,
         )
         for i in range(len(idx)):
             # check that idx[i] is correct

tests/influence/_core/test_tracin_regression.py

Lines changed: 14 additions & 14 deletions
@@ -183,19 +183,19 @@ def test_tracin_regression(
             criterion,
         )
 
-        train_scores = tracin.influence(train_inputs, train_labels)
+        train_scores = tracin.influence((train_inputs, train_labels))
         idx, _ = tracin.influence(
-            train_inputs, train_labels, k=len(dataset), proponents=True
+            (train_inputs, train_labels), k=len(dataset), proponents=True
         )
         # check that top influence is one with maximal value
         # (and hence gradient)
         for i in range(len(idx)):
             self.assertEqual(idx[i][0], 15)
 
         # check influence scores of test data
-        test_scores = tracin.influence(test_inputs, test_labels)
+        test_scores = tracin.influence((test_inputs, test_labels))
         idx, _ = tracin.influence(
-            test_inputs, test_labels, k=len(test_inputs), proponents=True
+            (test_inputs, test_labels), k=len(test_inputs), proponents=True
         )
         # check that top influence is one with maximal value
         # (and hence gradient)
@@ -226,17 +226,17 @@ def test_tracin_regression(
             sample_wise_grads_per_batch=True,
         )
 
-        train_scores = tracin.influence(train_inputs, train_labels)
+        train_scores = tracin.influence((train_inputs, train_labels))
         train_scores_sample_wise_trick = tracin_sample_wise_trick.influence(
-            train_inputs, train_labels
+            (train_inputs, train_labels)
         )
         assertTensorAlmostEqual(
             self, train_scores, train_scores_sample_wise_trick
         )
 
-        test_scores = tracin.influence(test_inputs, test_labels)
+        test_scores = tracin.influence((test_inputs, test_labels))
         test_scores_sample_wise_trick = tracin_sample_wise_trick.influence(
-            test_inputs, test_labels
+            (test_inputs, test_labels)
         )
         assertTensorAlmostEqual(
             self, test_scores, test_scores_sample_wise_trick
@@ -288,7 +288,7 @@ def test_tracin_regression_1D_numerical(
             criterion,
         )
 
-        train_scores = tracin.influence(train_inputs, train_labels, k=None)
+        train_scores = tracin.influence((train_inputs, train_labels), k=None)
 
         r"""
         Derivation for gradient / resulting TracIn score:
@@ -382,9 +382,9 @@ def test_tracin_identity_regression(
 
         # check influence scores of training data
 
-        train_scores = tracin.influence(train_inputs, train_labels)
+        train_scores = tracin.influence((train_inputs, train_labels))
         idx, _ = tracin.influence(
-            train_inputs, train_labels, k=len(dataset), proponents=True
+            (train_inputs, train_labels), k=len(dataset), proponents=True
         )
 
         # check that top influence for an instance is itself
@@ -415,9 +415,9 @@ def test_tracin_identity_regression(
             sample_wise_grads_per_batch=True,
         )
 
-        train_scores = tracin.influence(train_inputs, train_labels)
+        train_scores = tracin.influence((train_inputs, train_labels))
         train_scores_tracin_sample_wise_trick = (
-            tracin_sample_wise_trick.influence(train_inputs, train_labels)
+            tracin_sample_wise_trick.influence((train_inputs, train_labels))
         )
         assertTensorAlmostEqual(
             self, train_scores, train_scores_tracin_sample_wise_trick
@@ -496,5 +496,5 @@ def test_loss_fn(input, target):
         )
 
         # check influence scores of training data. they should all be 0
-        train_scores = tracin.influence(train_inputs, train_labels, k=None)
+        train_scores = tracin.influence((train_inputs, train_labels), k=None)
         assertTensorAlmostEqual(self, train_scores, torch.zeros(train_scores.shape))

tests/influence/_core/test_tracin_self_influence.py

Lines changed: 4 additions & 3 deletions
@@ -8,6 +8,7 @@
 from parameterized import parameterized
 from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
 from tests.influence._utils.common import (
+    _format_batch_into_tuple,
     build_test_name_func,
     DataInfluenceConstructor,
     get_random_model_and_data,
@@ -108,10 +109,10 @@ def test_tracin_self_influence(
             criterion,
         )
         train_scores = tracin.influence(
-            train_dataset.samples,
-            train_dataset.labels,
+            _format_batch_into_tuple(
+                train_dataset.samples, train_dataset.labels, unpack_inputs
+            ),
             k=None,
-            unpack_inputs=unpack_inputs,
         )
         # calculate self_tracin_scores
         self_tracin_scores = tracin.self_influence(

tests/influence/_core/test_tracin_show_progress.py

Lines changed: 3 additions & 6 deletions
@@ -178,8 +178,7 @@ def test_tracin_show_progress(
         elif mode == "influence":
 
             tracin.influence(
-                test_samples,
-                test_labels,
+                (test_samples, test_labels),
                 k=None,
                 show_progress=True,
             )
@@ -196,8 +195,7 @@ def test_tracin_show_progress(
         elif mode == "k-most":
 
             tracin.influence(
-                test_samples,
-                test_labels,
+                (test_samples, test_labels),
                 k=2,
                 proponents=True,
                 show_progress=True,
@@ -218,8 +216,7 @@ def test_tracin_show_progress(
             mock_stderr.truncate(0)
 
             tracin.influence(
-                test_samples,
-                test_labels,
+                (test_samples, test_labels),
                 k=2,
                 proponents=False,
                 show_progress=True,
