Skip to content

Commit 9b14f36

Browse files
Fulton Wangfacebook-github-bot
authored andcommitted
update tracin influence API
Summary: This diff changes the API for implementations of `TracInCPBase` as discussed in https://fb.quip.com/JbpnAiWluZmI. In particular, the arguments representing test data of the `influence` method are changed from `inputs: Tuple, targets: Optional[Tensor]` to `inputs: Union[Tuple[Any], DataLoader]`, which is either a single batch, or a dataloader yielding batches. In both cases, `model(*batch)` is assumed to produce the predictions for a batch, and `batch[-1]` is assumed to be the labels for a batch. This is the same format assumed of the batches yielded by `train_dataloader`. We make this change for 2 reasons - it unifies the assumptions made of the test data and the assumptions made of the training data - for some implementations, we want to allow the test data to be represented by a dataloader. with the old API, there was no clean way to allow both a single as well as a dataloader to be passed in, since a batch required 2 arguments, but a dataloader only requires 1. For now, all implementations only allow `inputs` to be a tuple (and not a dataloader). This is okay due to inheritance rules. Later on, we will allow some implementations (i.e. `TracInCP`) to accept a dataloader as `inputs`. Other changes: - changes to make documentation. for example, documentation in `TracInCPBase.influence` now refers to the "test dataset" instead of test batch. - the `unpack_inputs` argument is no longer needed for the `influence` methods, and is removed - the usage of `influence` in all the tests is changed to match new API. - signature of helper methods `_influence_batch_tracincp` and `_influence_batch_tracincp_fast` are changed to match new representation of batches. Reviewed By: cyrjano Differential Revision: D41324297 fbshipit-source-id: ee06ef12c79c7bc143b786ae2c5c15d5c8c11684
1 parent ada8c0d commit 9b14f36

10 files changed

+243
-281
lines changed

captum/influence/_core/tracincp.py

Lines changed: 98 additions & 109 deletions
Large diffs are not rendered by default.

captum/influence/_core/tracincp_fast_rand_proj.py

Lines changed: 96 additions & 118 deletions
Large diffs are not rendered by default.

captum/influence/_utils/common.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def _get_k_most_influential_helper(
187187
influence_src_dataloader: DataLoader,
188188
influence_batch_fn: Callable,
189189
inputs: Tuple[Any, ...],
190-
targets: Optional[Tensor],
191190
k: int = 5,
192191
proponents: bool = True,
193192
show_progress: bool = False,
@@ -202,13 +201,12 @@ def _get_k_most_influential_helper(
202201
influence_src_dataloader (DataLoader): The DataLoader, representing training
203202
data, for which we want to compute proponents / opponents.
204203
influence_batch_fn (Callable): A callable that will be called via
205-
`influence_batch_fn(inputs, targets, batch)`, where `batch` is a batch
204+
`influence_batch_fn(inputs, batch)`, where `batch` is a batch
206205
in the `influence_src_dataloader` argument.
207-
inputs (tuple[Any, ...]): A batch of examples. Does not represent labels,
208-
which are passed as `targets`.
209-
targets (Tensor, optional): If computing TracIn scores on a loss function,
210-
these are the labels corresponding to the batch `inputs`.
211-
Default: None
206+
inputs (tuple[Any, ...]): This argument represents the test batch, and is a
207+
single tuple of any, where the last element is assumed to be the labels
208+
for the batch. That is, `model(*batch[0:-1])` produces the output for
209+
`model`, and `batch[-1]` are the labels, if any.
212210
k (int, optional): The number of proponents or opponents to return per test
213211
instance.
214212
Default: 5
@@ -270,7 +268,7 @@ def _get_k_most_influential_helper(
270268
for batch in influence_src_dataloader:
271269

272270
# calculate tracin_scores for the batch
273-
batch_tracin_scores = influence_batch_fn(inputs, targets, batch)
271+
batch_tracin_scores = influence_batch_fn(inputs, batch)
274272
batch_tracin_scores *= multiplier
275273

276274
# get the top-k indices and tracin_scores for the batch

tests/influence/_core/test_tracin_intermediate_quantities.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from parameterized import parameterized
1212
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
1313
from tests.influence._utils.common import (
14+
_format_batch_into_tuple,
1415
build_test_name_func,
1516
DataInfluenceConstructor,
1617
get_random_model_and_data,
@@ -162,25 +163,13 @@ def test_tracin_intermediate_quantities_consistent(
162163
)
163164

164165
# compute influence scores without using `compute_intermediate_quantities`
166+
test_batch = _format_batch_into_tuple(
167+
test_features, test_labels, unpack_inputs
168+
)
165169
scores = tracin.influence(
166-
test_features, test_labels, unpack_inputs=unpack_inputs
170+
test_batch,
167171
)
168172

169-
# compute influence scores using `compute_intermediate_quantities`
170-
# we combine `test_features` and `test_labels` into a single tuple
171-
# `test_batch` to pass to the model, with the assumption that
172-
# `model(test_batch[0:-1]` produces the predictions, and `test_batch[-1]`
173-
# are the labels. We do this due to the assumptions made by the
174-
# `compute_intermediate_quantities` method. Therefore, how we
175-
# form `test_batch` depends on whether `unpack_inputs` is True or False
176-
if not unpack_inputs:
177-
# `test_features` is a Tensor
178-
test_batch = (test_features, test_labels)
179-
else:
180-
# `test_features` is a tuple, so we unpack it to place in tuple,
181-
# along with `test_labels`
182-
test_batch = (*test_features, test_labels) # type: ignore[assignment]
183-
184173
# the influence score is the dot product of intermediate quantities
185174
intermediate_quantities_scores = torch.matmul(
186175
intermediate_quantities_tracin.compute_intermediate_quantities(

tests/influence/_core/test_tracin_k_most_influential.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from parameterized import parameterized
99
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
1010
from tests.influence._utils.common import (
11+
_format_batch_into_tuple,
1112
build_test_name_func,
1213
DataInfluenceConstructor,
1314
get_random_model_and_data,
@@ -107,15 +108,14 @@ def test_tracin_k_most_influential(
107108
)
108109

109110
train_scores = tracin.influence(
110-
test_samples, test_labels, k=None, unpack_inputs=unpack_inputs
111+
_format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
112+
k=None,
111113
)
112114
sort_idx = torch.argsort(train_scores, dim=1, descending=proponents)[:, 0:k]
113115
idx, _train_scores = tracin.influence(
114-
test_samples,
115-
test_labels,
116+
_format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
116117
k=k,
117118
proponents=proponents,
118-
unpack_inputs=unpack_inputs,
119119
)
120120
for i in range(len(idx)):
121121
# check that idx[i] is correct

tests/influence/_core/test_tracin_regression.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -165,19 +165,19 @@ def test_tracin_regression(
165165
criterion,
166166
)
167167

168-
train_scores = tracin.influence(train_inputs, train_labels)
168+
train_scores = tracin.influence((train_inputs, train_labels))
169169
idx, _ = tracin.influence(
170-
train_inputs, train_labels, k=len(dataset), proponents=True
170+
(train_inputs, train_labels), k=len(dataset), proponents=True
171171
)
172172
# check that top influence is one with maximal value
173173
# (and hence gradient)
174174
for i in range(len(idx)):
175175
self.assertEqual(idx[i][0], 15)
176176

177177
# check influence scores of test data
178-
test_scores = tracin.influence(test_inputs, test_labels)
178+
test_scores = tracin.influence((test_inputs, test_labels))
179179
idx, _ = tracin.influence(
180-
test_inputs, test_labels, k=len(test_inputs), proponents=True
180+
(test_inputs, test_labels), k=len(test_inputs), proponents=True
181181
)
182182
# check that top influence is one with maximal value
183183
# (and hence gradient)
@@ -208,17 +208,17 @@ def test_tracin_regression(
208208
sample_wise_grads_per_batch=True,
209209
)
210210

211-
train_scores = tracin.influence(train_inputs, train_labels)
211+
train_scores = tracin.influence((train_inputs, train_labels))
212212
train_scores_sample_wise_trick = tracin_sample_wise_trick.influence(
213-
train_inputs, train_labels
213+
(train_inputs, train_labels)
214214
)
215215
assertTensorAlmostEqual(
216216
self, train_scores, train_scores_sample_wise_trick
217217
)
218218

219-
test_scores = tracin.influence(test_inputs, test_labels)
219+
test_scores = tracin.influence((test_inputs, test_labels))
220220
test_scores_sample_wise_trick = tracin_sample_wise_trick.influence(
221-
test_inputs, test_labels
221+
(test_inputs, test_labels)
222222
)
223223
assertTensorAlmostEqual(
224224
self, test_scores, test_scores_sample_wise_trick
@@ -270,7 +270,7 @@ def test_tracin_regression_1D_numerical(
270270
criterion,
271271
)
272272

273-
train_scores = tracin.influence(train_inputs, train_labels, k=None)
273+
train_scores = tracin.influence((train_inputs, train_labels), k=None)
274274

275275
r"""
276276
Derivation for gradient / resulting TracIn score:
@@ -364,9 +364,9 @@ def test_tracin_identity_regression(
364364

365365
# check influence scores of training data
366366

367-
train_scores = tracin.influence(train_inputs, train_labels)
367+
train_scores = tracin.influence((train_inputs, train_labels))
368368
idx, _ = tracin.influence(
369-
train_inputs, train_labels, k=len(dataset), proponents=True
369+
(train_inputs, train_labels), k=len(dataset), proponents=True
370370
)
371371

372372
# check that top influence for an instance is itself
@@ -397,9 +397,9 @@ def test_tracin_identity_regression(
397397
sample_wise_grads_per_batch=True,
398398
)
399399

400-
train_scores = tracin.influence(train_inputs, train_labels)
400+
train_scores = tracin.influence((train_inputs, train_labels))
401401
train_scores_tracin_sample_wise_trick = (
402-
tracin_sample_wise_trick.influence(train_inputs, train_labels)
402+
tracin_sample_wise_trick.influence((train_inputs, train_labels))
403403
)
404404
assertTensorAlmostEqual(
405405
self, train_scores, train_scores_tracin_sample_wise_trick

tests/influence/_core/test_tracin_self_influence.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from parameterized import parameterized
99
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
1010
from tests.influence._utils.common import (
11+
_format_batch_into_tuple,
1112
build_test_name_func,
1213
DataInfluenceConstructor,
1314
get_random_model_and_data,
@@ -108,10 +109,10 @@ def test_tracin_self_influence(
108109
criterion,
109110
)
110111
train_scores = tracin.influence(
111-
train_dataset.samples,
112-
train_dataset.labels,
112+
_format_batch_into_tuple(
113+
train_dataset.samples, train_dataset.labels, unpack_inputs
114+
),
113115
k=None,
114-
unpack_inputs=unpack_inputs,
115116
)
116117
# calculate self_tracin_scores
117118
self_tracin_scores = tracin.self_influence(

tests/influence/_core/test_tracin_show_progress.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,7 @@ def test_tracin_show_progress(
178178
elif mode == "influence":
179179

180180
tracin.influence(
181-
test_samples,
182-
test_labels,
181+
(test_samples, test_labels),
183182
k=None,
184183
show_progress=True,
185184
)
@@ -196,8 +195,7 @@ def test_tracin_show_progress(
196195
elif mode == "k-most":
197196

198197
tracin.influence(
199-
test_samples,
200-
test_labels,
198+
(test_samples, test_labels),
201199
k=2,
202200
proponents=True,
203201
show_progress=True,
@@ -218,8 +216,7 @@ def test_tracin_show_progress(
218216
mock_stderr.truncate(0)
219217

220218
tracin.influence(
221-
test_samples,
222-
test_labels,
219+
(test_samples, test_labels),
223220
k=2,
224221
proponents=False,
225222
show_progress=True,

tests/influence/_core/test_tracin_xor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def test_tracin_xor(
258258
batch_size,
259259
criterion,
260260
)
261-
test_scores = tracin.influence(testset, testlabels)
261+
test_scores = tracin.influence((testset, testlabels))
262262
idx = torch.argsort(test_scores, dim=1, descending=True)
263263
# check that top 5 influences have matching binary classification
264264
for i in range(len(idx)):
@@ -288,9 +288,9 @@ def test_tracin_xor(
288288
criterion,
289289
sample_wise_grads_per_batch=True,
290290
)
291-
test_scores = tracin.influence(testset, testlabels)
291+
test_scores = tracin.influence((testset, testlabels))
292292
test_scores_sample_wise_trick = tracin_sample_wise_trick.influence(
293-
testset, testlabels
293+
(testset, testlabels)
294294
)
295295
assertTensorAlmostEqual(
296296
self, test_scores, test_scores_sample_wise_trick

tests/influence/_utils/common.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import unittest
44
from functools import partial
5-
from typing import Callable, Iterator, List, Optional, Union
5+
from typing import Callable, Iterator, List, Optional, Tuple, Union
66

77
import torch
88
import torch.nn as nn
@@ -14,6 +14,7 @@
1414
)
1515
from parameterized import parameterized
1616
from parameterized.parameterized import param
17+
from torch import Tensor
1718
from torch.nn import Module
1819
from torch.utils.data import DataLoader, Dataset
1920

@@ -344,3 +345,12 @@ def build_test_name_func(args_to_skip: Optional[List[str]] = None):
344345
"""
345346

346347
return partial(generate_test_name, args_to_skip=args_to_skip)
348+
349+
350+
def _format_batch_into_tuple(
351+
inputs: Union[Tuple, Tensor], targets: Tensor, unpack_inputs: bool
352+
):
353+
if unpack_inputs:
354+
return (*inputs, targets)
355+
else:
356+
return (inputs, targets)

0 commit comments

Comments
 (0)