
Commit 24cd933

Fulton Wang authored and facebook-github-bot committed
update tracin influence API (#1072)
Summary:
Pull Request resolved: #1072

This diff changes the API for implementations of `TracInCPBase`, as discussed in https://fb.quip.com/JbpnAiWluZmI. In particular, the arguments of the `influence` method representing the test data are changed from `inputs: Tuple, targets: Optional[Tensor]` to `inputs: Union[Tuple[Any], DataLoader]`, which is either a single batch or a dataloader yielding batches. In both cases, `model(*batch)` is assumed to produce the predictions for a batch, and `batch[-1]` is assumed to be the labels for a batch. This is the same format assumed of the batches yielded by `train_dataloader`.

We make this change for 2 reasons:
- it unifies the assumptions made of the test data and the assumptions made of the training data.
- for some implementations, we want to allow the test data to be represented by a dataloader. With the old API, there was no clean way to allow both a single batch as well as a dataloader to be passed in, since a batch required 2 arguments, but a dataloader only requires 1.

For now, all implementations only allow `inputs` to be a tuple (and not a dataloader). This is okay due to inheritance rules. Later on, we will allow some implementations (i.e. `TracInCP`) to accept a dataloader as `inputs`.

Other changes:
- documentation updates; for example, the documentation in `TracInCPBase.influence` now refers to the "test dataset" instead of the test batch.
- the `unpack_inputs` argument is no longer needed for the `influence` methods, and is removed.
- the usage of `influence` in all the tests is changed to match the new API.
- the signatures of the helper methods `_influence_batch_tracincp` and `_influence_batch_tracincp_fast` are changed to match the new representation of batches.

Reviewed By: cyrjano

Differential Revision: D41324297

fbshipit-source-id: f0098b83a486b49059c02f359f093ed3b791688c
1 parent 7c228ac commit 24cd933

14 files changed: +495 −605 lines
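To make the API change concrete, here is a minimal call-site sketch. It assumes `tracin` is an already-constructed `TracInCPBase` implementation (e.g. `TracInCP`), and that `test_samples` and `test_labels` are the feature and label tensors of a single test batch; `test_dataloader` is a hypothetical name for a DataLoader of test batches.

# Old API: features and labels were separate arguments, with `unpack_inputs`
# controlling how `test_samples` was interpreted.
# scores = tracin.influence(test_samples, test_labels, unpack_inputs=False)

# New API: a single batch tuple, where `batch[-1]` is the labels and
# `model(*batch[0:-1])` produces the predictions -- the same convention used
# for batches yielded by `train_dataloader`.
scores = tracin.influence((test_samples, test_labels), k=None)

# Planned later for some implementations (e.g. `TracInCP`): a DataLoader
# yielding batches in the same format may also be passed as `inputs`.
# scores = tracin.influence(test_dataloader, k=None)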

captum/influence/_core/influence.py

Lines changed: 3 additions & 5 deletions
@@ -12,21 +12,19 @@ class DataInfluence(ABC):
     An abstract class to define model data influence skeleton.
     """
 
-    def __init_(
-        self, model: Module, influence_src_dataset: Dataset, **kwargs: Any
-    ) -> None:
+    def __init_(self, model: Module, train_dataset: Dataset, **kwargs: Any) -> None:
         r"""
         Args:
             model (torch.nn.Module): An instance of pytorch model.
-            influence_src_dataset (torch.utils.data.Dataset): PyTorch Dataset that is
+            train_dataset (torch.utils.data.Dataset): PyTorch Dataset that is
                 used to create a PyTorch Dataloader to iterate over the dataset and
                 its labels. This is the dataset for which we will be seeking for
                 influential instances. In most cases this is the training dataset.
             **kwargs: Additional key-value arguments that are necessary for specific
                 implementation of `DataInfluence` abstract class.
         """
         self.model = model
-        self.influence_src_dataset = influence_src_dataset
+        self.train_dataset = train_dataset
 
     @abstractmethod
     def influence(self, inputs: Any = None, **kwargs: Any) -> Any:
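For orientation, below is a hypothetical minimal subclass following the renamed constructor argument and the batch convention described in the summary. `ConstantInfluence` and its all-zero scoring rule are made up for illustration and are not part of Captum.

from typing import Any

import torch
from torch.nn import Module
from torch.utils.data import Dataset

from captum.influence._core.influence import DataInfluence


class ConstantInfluence(DataInfluence):
    """Toy implementation: assigns every training example zero influence."""

    def __init__(self, model: Module, train_dataset: Dataset, **kwargs: Any) -> None:
        # mirror what the base class stores
        self.model = model
        self.train_dataset = train_dataset

    def influence(self, inputs: Any = None, **kwargs: Any) -> Any:
        # `inputs` follows the batch convention: `inputs[-1]` is the labels,
        # `inputs[0:-1]` are the arguments to `model`
        labels = inputs[-1]
        # one row per test example, one column per training example
        return torch.zeros(labels.shape[0], len(self.train_dataset))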

captum/influence/_core/tracincp.py

Lines changed: 210 additions & 249 deletions
Large diffs are not rendered by default.

captum/influence/_core/tracincp_fast_rand_proj.py

Lines changed: 201 additions & 269 deletions
Large diffs are not rendered by default.

captum/influence/_utils/common.py

Lines changed: 6 additions & 8 deletions
@@ -189,7 +189,6 @@ def _get_k_most_influential_helper(
     influence_src_dataloader: DataLoader,
     influence_batch_fn: Callable,
     inputs: Tuple[Any, ...],
-    targets: Optional[Tensor],
     k: int = 5,
     proponents: bool = True,
     show_progress: bool = False,
@@ -204,13 +203,12 @@ def _get_k_most_influential_helper(
         influence_src_dataloader (DataLoader): The DataLoader, representing training
             data, for which we want to compute proponents / opponents.
         influence_batch_fn (Callable): A callable that will be called via
-            `influence_batch_fn(inputs, targets, batch)`, where `batch` is a batch
+            `influence_batch_fn(inputs, batch)`, where `batch` is a batch
            in the `influence_src_dataloader` argument.
-        inputs (tuple[Any, ...]): A batch of examples. Does not represent labels,
-            which are passed as `targets`.
-        targets (Tensor, optional): If computing TracIn scores on a loss function,
-            these are the labels corresponding to the batch `inputs`.
-            Default: None
+        inputs (tuple[Any, ...]): This argument represents the test batch, and is a
+            single tuple of any, where the last element is assumed to be the labels
+            for the batch. That is, `model(*batch[0:-1])` produces the output for
+            `model`, and `batch[-1]` are the labels, if any.
         k (int, optional): The number of proponents or opponents to return per test
             instance.
             Default: 5
@@ -272,7 +270,7 @@ def _get_k_most_influential_helper(
     for batch in influence_src_dataloader:
 
         # calculate tracin_scores for the batch
-        batch_tracin_scores = influence_batch_fn(inputs, targets, batch)
+        batch_tracin_scores = influence_batch_fn(inputs, batch)
         batch_tracin_scores *= multiplier
 
         # get the top-k indices and tracin_scores for the batch
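As a rough sketch of the control flow around the loop above, under the new batch convention. This is simplified: the real helper merges top-k results batch by batch rather than concatenating all scores first, and also supports a progress bar; `k_most_influential_sketch` is a made-up name.

from typing import Any, Callable, Tuple

import torch
from torch.utils.data import DataLoader


def k_most_influential_sketch(
    influence_src_dataloader: DataLoader,
    influence_batch_fn: Callable,
    inputs: Tuple[Any, ...],
    k: int = 5,
    proponents: bool = True,
):
    # proponents have the largest scores, opponents the smallest; flipping the
    # sign lets a single top-k selection handle both cases
    multiplier = 1.0 if proponents else -1.0
    batch_scores = []
    for batch in influence_src_dataloader:
        # `inputs` is the test batch: `inputs[-1]` is its labels and
        # `model(*inputs[0:-1])` produces its predictions; each training
        # `batch` follows the same format
        batch_scores.append(influence_batch_fn(inputs, batch) * multiplier)
    # shape: (num test examples, num training examples)
    scores = torch.cat(batch_scores, dim=1)
    topk_scores, topk_indices = torch.topk(scores, k, dim=1)
    # undo the sign flip so the returned scores are the actual influence scores
    return topk_indices, topk_scores * multiplier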

tests/influence/_core/test_dataloader.py

Lines changed: 5 additions & 2 deletions
@@ -10,6 +10,7 @@
 from parameterized import parameterized
 from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
 from tests.influence._utils.common import (
+    _format_batch_into_tuple,
     build_test_name_func,
     DataInfluenceConstructor,
     get_random_model_and_data,
@@ -76,7 +77,8 @@ def test_tracin_dataloader(
         )
 
         train_scores = tracin.influence(
-            test_samples, test_labels, k=None, unpack_inputs=unpack_inputs
+            _format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
+            k=None,
         )
 
         tracin_dataloader = tracin_constructor(
@@ -88,7 +90,8 @@ def test_tracin_dataloader(
         )
 
         train_scores_dataloader = tracin_dataloader.influence(
-            test_samples, test_labels, k=None, unpack_inputs=unpack_inputs
+            _format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
+            k=None,
         )
 
         assertTensorAlmostEqual(
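`_format_batch_into_tuple` is a new test utility in tests/influence/_utils/common.py. Its implementation is not shown in this diff, but based on the inline logic it replaces in test_tracin_intermediate_quantities.py below, it plausibly looks like the following sketch:

def _format_batch_into_tuple(inputs, targets, unpack_inputs):
    # combine features and labels into the single-tuple batch format assumed by
    # the new `influence` API, where the last element is the labels
    if unpack_inputs:
        # `inputs` is already a tuple of feature tensors
        return (*inputs, targets)
    else:
        # `inputs` is a single tensor
        return (inputs, targets)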

tests/influence/_core/test_tracin_intermediate_quantities.py

Lines changed: 5 additions & 16 deletions
@@ -12,6 +12,7 @@
 from parameterized import parameterized
 from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
 from tests.influence._utils.common import (
+    _format_batch_into_tuple,
     build_test_name_func,
     DataInfluenceConstructor,
     get_random_model_and_data,
@@ -224,25 +225,13 @@ def test_tracin_intermediate_quantities_consistent(
         )
 
         # compute influence scores without using `compute_intermediate_quantities`
+        test_batch = _format_batch_into_tuple(
+            test_features, test_labels, unpack_inputs
+        )
         scores = tracin.influence(
-            test_features, test_labels, unpack_inputs=unpack_inputs
+            test_batch,
         )
 
-        # compute influence scores using `compute_intermediate_quantities`
-        # we combine `test_features` and `test_labels` into a single tuple
-        # `test_batch` to pass to the model, with the assumption that
-        # `model(test_batch[0:-1]` produces the predictions, and `test_batch[-1]`
-        # are the labels. We do this due to the assumptions made by the
-        # `compute_intermediate_quantities` method. Therefore, how we
-        # form `test_batch` depends on whether `unpack_inputs` is True or False
-        if not unpack_inputs:
-            # `test_features` is a Tensor
-            test_batch = (test_features, test_labels)
-        else:
-            # `test_features` is a tuple, so we unpack it to place in tuple,
-            # along with `test_labels`
-            test_batch = (*test_features, test_labels)  # type: ignore[assignment]
-
         # the influence score is the dot product of intermediate quantities
         intermediate_quantities_scores = torch.matmul(
             intermediate_quantities_tracin.compute_intermediate_quantities(
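The consistency being checked here is that the influence score of a test example on a training example equals the dot product of their intermediate quantities, so the full score matrix is `test_iq @ train_iq.T`. A minimal sketch of that identity, with made-up tensors and shapes standing in for the quantities the test actually computes:

import torch

# hypothetical intermediate quantities: one row per example
test_iq = torch.randn(3, 8)    # 3 test examples
train_iq = torch.randn(10, 8)  # 10 training examples

# score[i, j] = dot(test_iq[i], train_iq[j])
intermediate_quantities_scores = torch.matmul(test_iq, train_iq.T)  # shape (3, 10)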

tests/influence/_core/test_tracin_k_most_influential.py

Lines changed: 4 additions & 4 deletions
@@ -8,6 +8,7 @@
 from parameterized import parameterized
 from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
 from tests.influence._utils.common import (
+    _format_batch_into_tuple,
     build_test_name_func,
     DataInfluenceConstructor,
     get_random_model_and_data,
@@ -107,15 +108,14 @@ def test_tracin_k_most_influential(
         )
 
         train_scores = tracin.influence(
-            test_samples, test_labels, k=None, unpack_inputs=unpack_inputs
+            _format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
+            k=None,
         )
         sort_idx = torch.argsort(train_scores, dim=1, descending=proponents)[:, 0:k]
         idx, _train_scores = tracin.influence(
-            test_samples,
-            test_labels,
+            _format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
             k=k,
             proponents=proponents,
-            unpack_inputs=unpack_inputs,
         )
         for i in range(len(idx)):
             # check that idx[i] is correct

tests/influence/_core/test_tracin_regression.py

Lines changed: 14 additions & 14 deletions
@@ -183,19 +183,19 @@ def test_tracin_regression(
             criterion,
         )
 
-        train_scores = tracin.influence(train_inputs, train_labels)
+        train_scores = tracin.influence((train_inputs, train_labels))
         idx, _ = tracin.influence(
-            train_inputs, train_labels, k=len(dataset), proponents=True
+            (train_inputs, train_labels), k=len(dataset), proponents=True
         )
         # check that top influence is one with maximal value
         # (and hence gradient)
         for i in range(len(idx)):
             self.assertEqual(idx[i][0], 15)
 
         # check influence scores of test data
-        test_scores = tracin.influence(test_inputs, test_labels)
+        test_scores = tracin.influence((test_inputs, test_labels))
         idx, _ = tracin.influence(
-            test_inputs, test_labels, k=len(test_inputs), proponents=True
+            (test_inputs, test_labels), k=len(test_inputs), proponents=True
         )
         # check that top influence is one with maximal value
         # (and hence gradient)
@@ -226,17 +226,17 @@ def test_tracin_regression(
             sample_wise_grads_per_batch=True,
         )
 
-        train_scores = tracin.influence(train_inputs, train_labels)
+        train_scores = tracin.influence((train_inputs, train_labels))
         train_scores_sample_wise_trick = tracin_sample_wise_trick.influence(
-            train_inputs, train_labels
+            (train_inputs, train_labels)
         )
         assertTensorAlmostEqual(
             self, train_scores, train_scores_sample_wise_trick
         )
 
-        test_scores = tracin.influence(test_inputs, test_labels)
+        test_scores = tracin.influence((test_inputs, test_labels))
         test_scores_sample_wise_trick = tracin_sample_wise_trick.influence(
-            test_inputs, test_labels
+            (test_inputs, test_labels)
         )
         assertTensorAlmostEqual(
             self, test_scores, test_scores_sample_wise_trick
@@ -288,7 +288,7 @@ def test_tracin_regression_1D_numerical(
             criterion,
         )
 
-        train_scores = tracin.influence(train_inputs, train_labels, k=None)
+        train_scores = tracin.influence((train_inputs, train_labels), k=None)
 
         r"""
         Derivation for gradient / resulting TracIn score:
@@ -382,9 +382,9 @@ def test_tracin_identity_regression(
 
         # check influence scores of training data
 
-        train_scores = tracin.influence(train_inputs, train_labels)
+        train_scores = tracin.influence((train_inputs, train_labels))
         idx, _ = tracin.influence(
-            train_inputs, train_labels, k=len(dataset), proponents=True
+            (train_inputs, train_labels), k=len(dataset), proponents=True
         )
 
         # check that top influence for an instance is itself
@@ -415,9 +415,9 @@ def test_tracin_identity_regression(
             sample_wise_grads_per_batch=True,
        )
 
-        train_scores = tracin.influence(train_inputs, train_labels)
+        train_scores = tracin.influence((train_inputs, train_labels))
         train_scores_tracin_sample_wise_trick = (
-            tracin_sample_wise_trick.influence(train_inputs, train_labels)
+            tracin_sample_wise_trick.influence((train_inputs, train_labels))
         )
         assertTensorAlmostEqual(
             self, train_scores, train_scores_tracin_sample_wise_trick
@@ -496,5 +496,5 @@ def test_loss_fn(input, target):
         )
 
         # check influence scores of training data. they should all be 0
-        train_scores = tracin.influence(train_inputs, train_labels, k=None)
+        train_scores = tracin.influence((train_inputs, train_labels), k=None)
         assertTensorAlmostEqual(self, train_scores, torch.zeros(train_scores.shape))

tests/influence/_core/test_tracin_self_influence.py

Lines changed: 4 additions & 3 deletions
@@ -8,6 +8,7 @@
 from parameterized import parameterized
 from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
 from tests.influence._utils.common import (
+    _format_batch_into_tuple,
     build_test_name_func,
     DataInfluenceConstructor,
     get_random_model_and_data,
@@ -108,10 +109,10 @@ def test_tracin_self_influence(
             criterion,
         )
         train_scores = tracin.influence(
-            train_dataset.samples,
-            train_dataset.labels,
+            _format_batch_into_tuple(
+                train_dataset.samples, train_dataset.labels, unpack_inputs
+            ),
             k=None,
-            unpack_inputs=unpack_inputs,
         )
         # calculate self_tracin_scores
         self_tracin_scores = tracin.self_influence(

tests/influence/_core/test_tracin_show_progress.py

Lines changed: 3 additions & 6 deletions
@@ -178,8 +178,7 @@ def test_tracin_show_progress(
         elif mode == "influence":
 
             tracin.influence(
-                test_samples,
-                test_labels,
+                (test_samples, test_labels),
                 k=None,
                 show_progress=True,
             )
@@ -196,8 +195,7 @@ def test_tracin_show_progress(
         elif mode == "k-most":
 
             tracin.influence(
-                test_samples,
-                test_labels,
+                (test_samples, test_labels),
                 k=2,
                 proponents=True,
                 show_progress=True,
@@ -218,8 +216,7 @@ def test_tracin_show_progress(
             mock_stderr.truncate(0)
 
             tracin.influence(
-                test_samples,
-                test_labels,
+                (test_samples, test_labels),
                 k=2,
                 proponents=False,
                 show_progress=True,
