From 03cbff1cb18d92a7e41081b0e31eef419dc025b6 Mon Sep 17 00:00:00 2001 From: Fulton Wang Date: Sun, 31 Jul 2022 16:14:35 -0700 Subject: [PATCH 1/2] modify tracin self influence helpers (#994) Summary: Pull Request resolved: https://github.com/pytorch/captum/pull/994 change `TracInCP._self_influence_batch_tracincp` and `TracInCP._self_influence_batch_tracincp` `TracInCP._self_influence_batches_tracincp_fast` to be named `self_influence`, which is now public, and now accept a DataLoader yielding batches (as well as a single batch, as before). The modified helper function can be called by external functions to compute self influence. The helper itself is also changed to improve efficiency, by reducing the number of times checkpoints are loaded. The modified helper, despite being able to compute self influence scores for a dataloader yielding batches, still only loads each checkpoint once, per call. This is because the modified helper now has an outer iteration over checkpoints, and an inner iteration over batches (the order of iteration is reversed compared to before). This helper is called by `influence` when running it in self influence mode. The reason we cannot just increase the batch size to reduce the number of checkpoint loadings is that for large models (precisely those for which loading checkpoints is expensive), the model takes up too much memory, so that the batch size cannot be too large. Minor change: for `influence_src_dataset` argument of all `__init__`'s, add description of what assumptions we make of the batches yielded by the dataloader. Differential Revision: D35603078 fbshipit-source-id: 87063052e68441b82514489f4d9f9ad29b396da4 --- captum/influence/_core/tracincp.py | 434 +++++++++++------- .../_core/tracincp_fast_rand_proj.py | 386 ++++++++++------ captum/influence/_utils/common.py | 12 + .../_core/test_tracin_self_influence.py | 70 ++- .../_core/test_tracin_show_progress.py | 217 +++++---- 5 files changed, 739 insertions(+), 380 deletions(-) diff --git a/captum/influence/_core/tracincp.py b/captum/influence/_core/tracincp.py index d5acc2dfef..78fa32738f 100644 --- a/captum/influence/_core/tracincp.py +++ b/captum/influence/_core/tracincp.py @@ -26,6 +26,7 @@ from captum._utils.progress import progress from captum.influence._core.influence import DataInfluence from captum.influence._utils.common import ( + _format_inputs_dataset, _get_k_most_influential_helper, _gradient_dot_product, _load_flexible_state_dict, @@ -95,7 +96,7 @@ class TracInCPBase(DataInfluence): def __init__( self, model: Module, - influence_src_dataset: Union[Dataset, DataLoader], + train_dataset: Union[Dataset, DataLoader], checkpoints: Union[str, List[str], Iterator], checkpoints_load_func: Callable = _load_flexible_state_dict, loss_fn: Optional[Union[Module, Callable]] = None, @@ -105,7 +106,7 @@ def __init__( Args: model (torch.nn.Module): An instance of pytorch model. This model should define all of its layers as attributes of the model. - influence_src_dataset (torch.utils.data.Dataset or torch.utils.DataLoader): + train_dataset (torch.utils.data.Dataset or torch.utils.DataLoader): In the `influence` method, we either compute the influence score of training examples on examples in a test batch, or self influence scores for those training examples, depending on which mode is used. @@ -120,9 +121,15 @@ def __init__( DataLoader used for processing should be as large as possible, but not too large, so that certain intermediate quantities created from a batch still fit in memory. Therefore, if - `influence_src_dataset` is a Dataset, `batch_size` should be large. - If `influence_src_dataset` was already a DataLoader to begin with, - it should have been constructed to have a large batch size. + `train_dataset` is a Dataset, `batch_size` should be large. + If `train_dataset` was already a DataLoader to begin with, + it should have been constructed to have a large batch size. It is + assumed that the Dataloader (regardless of whether it is created + from a Pytorch Dataset or not) yields tuples. For a `batch` that is + yielded, of length `L`, it is assumed that the forward function of + `model` accepts `L-1` arguments, and the last element of `batch` is + the label. In other words, `model(*batch[:-1])` gives the output of + `model`, and `batch[-1]` are the labels for the batch. checkpoints (str or List of str or Iterator): Either the directory of the path to store and retrieve model checkpoints, a list of filepaths with checkpoints from which to load, or an iterator which @@ -140,12 +147,12 @@ def __init__( loss_fn (Callable, optional): The loss function applied to model. Default: None batch_size (int or None, optional): Batch size of the DataLoader created to - iterate through `influence_src_dataset`, if it is a Dataset. + iterate through `train_dataset`, if it is a Dataset. `batch_size` should be chosen as large as possible so that certain intermediate quantities created from a batch still fit in memory. Specific implementations of `TracInCPBase` will detail the size of the intermediate quantities. `batch_size` must be an int if - `influence_src_dataset` is a Dataset. If `influence_src_dataset` + `train_dataset` is a Dataset. If `train_dataset` is a DataLoader, then `batch_size` is ignored as an argument. Default: 1 """ @@ -165,44 +172,80 @@ def __init__( self.loss_fn = loss_fn self.batch_size = batch_size - if not isinstance(influence_src_dataset, DataLoader): + if not isinstance(train_dataset, DataLoader): assert isinstance(batch_size, int), ( - "since the `influence_src_dataset` argument was a `Dataset`, " + "since the `train_dataset` argument was a `Dataset`, " "`batch_size` must be an int." ) - self.influence_src_dataloader = DataLoader( - influence_src_dataset, batch_size, shuffle=False - ) + self.train_dataloader = DataLoader(train_dataset, batch_size, shuffle=False) else: - self.influence_src_dataloader = influence_src_dataset + self.train_dataloader = train_dataset - self.influence_src_dataloader_len: Optional[int] = None + self.train_dataloader_len: Optional[int] = None try: # since we will calculate the number of batches in - # `self.influence_src_dataloader` whenever we use progress bar, calculate + # `self.train_dataloader` whenever we use progress bar, calculate # it once in initialization, for re-use. - self.influence_src_dataloader_len = len(self.influence_src_dataloader) - except AttributeError: - pass + self.train_dataloader_len = len(self.train_dataloader) + except TypeError: + warnings.warn( + "Unable to determine the number of batches in training dataset " + "`train_dataset`. Therefore, if showing the progress of computations, " + "only the number of batches processed can be displayed, and not the " + "percentage completion of the computation, nor any time estimates." + ) @abstractmethod - def _self_influence(self, show_progress: bool = False): + def self_influence( + self, + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + show_progress: bool = False, + ) -> Tensor: """ - Returns: - self influence scores (tensor): 1D tensor containing self influence - scores for all examples in training dataset - `influence_src_dataset`. - show_progress (bool, optional): To compute the self influence scores for - all examples in training dataset `influence_src_dataset`, we - compute the self influence scores for each batch. If + Computes self influence scores for the examples in `inputs_dataset`, which is + either a single batch or a Pytorch `DataLoader` that yields batches. Therefore, + the computed self influence scores are *not* for the examples in training + dataset `train_dataset` (unlike when computing self influence scores using the + `influence` method). Note that if `inputs_dataset` is a single batch, this + will call `model` on that single batch, and if `inputs_dataset` yields + batches, this will call `model` on each batch that is yielded. Therefore, + please ensure that for both cases, the batch(es) that `model` is called + with are not too large, so that there will not be an out-of-memory error. + + Args: + batches (Tuple, or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, + and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset`. Please see documentation for the + `train_dataset` argument to `TracInCP.__init__` for + more details on the assumed structure of a batch. + show_progress (bool, optional): Computation of self influence scores can + take a long time if `inputs_dataset` represents many examples. If `show_progress`is true, the progress of this computation will be - displayed. In particular, the number of batches for which self - influence scores have been computed will be displayed. It will - try to use tqdm if available for advanced features (e.g. time - estimation). Otherwise, it will fallback to a simple output of - progress. + displayed. In more detail, this computation will iterate over all + checkpoints (provided as the `checkpoints` initialization argument) + in an outer loop, and iterate over all batches that + `inputs_dataset` represents in an inner loop. Therefore, the + total number of (checkpoint, batch) combinations that need to be + iterated over is + (# of checkpoints x # of batches that `inputs_dataset` represents). + If `show_progress` is True, the total progress of both the outer + iteration over checkpoints and the inner iteration over batches is + displayed. It will try to use tqdm if available for advanced + features (e.g. time estimation). Otherwise, it will fallback to a + simple output of progress. Default: False + + Returns: + self_influence_scores (Tensor): This is a 1D tensor containing the self + influence scores of all examples in `inputs_dataset`, regardless of + whether it represents a single batch or a `DataLoader` that yields + batches. """ pass @@ -230,7 +273,7 @@ def _get_k_most_influential( Default: True show_progress (bool, optional): To compute the proponents (or opponents) for the batch of examples, we perform computation for each batch in - training dataset `influence_src_dataset`, If `show_progress`is + training dataset `train_dataset`, If `show_progress`is true, the progress of this computation will be displayed. In particular, the number of batches for which the computation has been performed will be displayed. It will try to use tqdm if @@ -244,13 +287,13 @@ def _get_k_most_influential( test example. Its dimension is `(inputs_batch_size, k)`, where `inputs_batch_size` is the number of examples in `inputs`. For example, if `proponents==True`, `indices[i][j]` is the index of the - example in training dataset `influence_src_dataset` with the + example in training dataset `train_dataset` with the k-th highest influence score for the j-th example in `inputs`. `indices` is a `torch.long` tensor so that it can directly be used to index other tensors. Each row of `influence_scores` contains the influence scores for a different test example, in sorted order. In particular, `influence_scores[i][j]` is the influence score of - example `indices[i][j]` in training dataset `influence_src_dataset` + example `indices[i][j]` in training dataset `train_dataset` on example `i` in the test batch represented by `inputs` and `targets`. """ @@ -267,7 +310,7 @@ def _influence( Args: inputs (Tuple of Any): A batch of examples. Does not represent labels, which are passed as `targets`. The assumption is that - `self.model(*inputs)` produces the predictions for the batch. + `model(*inputs)` produces the predictions for the batch. targets (tensor, optional): If computing influence scores on a loss function, these are the labels corresponding to the batch `inputs`. @@ -275,12 +318,12 @@ def _influence( Returns: influence_scores (tensor): Influence scores over the entire - training dataset `influence_src_dataset`. Dimensionality is + training dataset `train_dataset`. Dimensionality is (inputs_batch_size, src_dataset_size). For example: influence_scores[i][j] = the influence score for the j-th training example to the i-th input example. show_progress (bool, optional): To compute the influence of examples in - training dataset `influence_src_dataset`, we compute the influence + training dataset `train_dataset`, we compute the influence of each batch. If `show_progress`is true, the progress of this computation will be displayed. In particular, the number of batches for which influence has been computed will be displayed. It will @@ -307,17 +350,17 @@ def influence( # type: ignore[override] - self influence mode: This mode is used if `inputs` is None. This mode computes the self influence scores for every example in - the training dataset `influence_src_dataset`. + the training dataset `train_dataset`. - influence score mode: This mode is used if `inputs` is not None, and `k` is None. This mode computes the influence score of every example in - training dataset `influence_src_dataset` on every example in the test + training dataset `train_dataset` on every example in the test batch represented by `inputs` and `targets`. - k-most influential mode: This mode is used if `inputs` is not None, and `k` is not None, and an int. This mode computes the proponents or opponents of every example in the test batch represented by `inputs` and `targets`. In particular, for each test example in the test batch, this mode computes its proponents (resp. opponents), which are the - indices in the training dataset `influence_src_dataset` of the training + indices in the training dataset `train_dataset` of the training examples with the `k` highest (resp. lowest) influence scores on the test example. Proponents are computed if `proponents` is True. Otherwise, opponents are computed. For each test example, this method @@ -329,12 +372,12 @@ def influence( # type: ignore[override] will be run. Otherwise, `inputs` is the test batch that will be used when running in either influence score or k-most influential mode. If the argument `unpack_inputs` is False, the - assumption is that `self.model(inputs)` produces the predictions + assumption is that `model(inputs)` produces the predictions for a batch, and `inputs` can be of any type. Otherwise if the argument `unpack_inputs` is True, the assumption is that - `self.model(*inputs)` produces the predictions for a batch, and + `model(*inputs)` produces the predictions for a batch, and `inputs` will need to be a tuple. In other words, `inputs` will be - unpacked as an argument when passing to `self.model`. + unpacked as an argument when passing to `model`. Default: None targets (tensor, optional): If computing influence scores on a loss function, these are the labels corresponding to the batch `inputs`. @@ -354,7 +397,7 @@ def influence( # type: ignore[override] Default: True show_progress (bool, optional): For all modes, computation of results requires "training dataset computations": computations for each - batch in the training dataset `influence_src_dataset`, which may + batch in the training dataset `train_dataset`, which may take a long time. If `show_progress`is true, the progress of "training dataset computations" will be displayed. In particular, the number of batches for which computations have been performed @@ -368,29 +411,29 @@ def influence( # type: ignore[override] - self influence mode: if this mode is run (`inputs` is None), returns a 1D tensor of self influence scores over training dataset - `influence_src_dataset`. The length of this tensor is the number of - examples in `influence_src_dataset`, regardless of whether it is a + `train_dataset`. The length of this tensor is the number of + examples in `train_dataset`, regardless of whether it is a Dataset or DataLoader. - influence score mode: if this mode is run (`inputs is not None, `k` is None), returns a 2D tensor `influence_scores` of shape - `(input_size, influence_src_dataset_size)`, where `input_size` is + `(input_size, train_dataset_size)`, where `input_size` is the number of examples in the test batch, and - `influence_src_dataset_size` is the number of examples in - training dataset `influence_src_dataset`. In other words, + `train_dataset_size` is the number of examples in + training dataset `train_dataset`. In other words, `influence_scores[i][j]` is the influence score of the `j`-th - example in `influence_src_dataset` on the `i`-th example in the + example in `train_dataset` on the `i`-th example in the test batch. - k-most influential mode: if this mode is run (`inputs` is not None, `k` is an int), returns a namedtuple `(indices, influence_scores)`. `indices` is a 2D tensor of shape `(input_size, k)`, where `input_size` is the number of examples in the test batch. If computing proponents (resp. opponents), `indices[i][j]` is the - index in training dataset `influence_src_dataset` of the example + index in training dataset `train_dataset` of the example with the `j`-th highest (resp. lowest) influence score (out of the - examples in `influence_src_dataset`) on the `i`-th example in the + examples in `train_dataset`) on the `i`-th example in the test batch. `influence_scores` contains the corresponding influence scores. In particular, `influence_scores[i][j]` is the influence - score of example `indices[i][j]` in `influence_src_dataset` on + score of example `indices[i][j]` in `train_dataset` on example `i` in the test batch represented by `inputs` and `targets`. """ @@ -431,7 +474,9 @@ def _influence_route_to_helpers( _inputs = _format_inputs(inputs, unpack_inputs) if inputs is None: - return influence_instance._self_influence(show_progress) + return influence_instance.self_influence( + influence_instance.train_dataloader, show_progress + ) elif k is None: return influence_instance._influence(_inputs, targets, show_progress) else: @@ -444,7 +489,7 @@ class TracInCP(TracInCPBase): def __init__( self, model: Module, - influence_src_dataset: Union[Dataset, DataLoader], + train_dataset: Union[Dataset, DataLoader], checkpoints: Union[str, List[str], Iterator], checkpoints_load_func: Callable = _load_flexible_state_dict, layers: Optional[List[str]] = None, @@ -456,7 +501,7 @@ def __init__( Args: model (torch.nn.Module): An instance of pytorch model. This model should define all of its layers as attributes of the model. - influence_src_dataset (torch.utils.data.Dataset or torch.utils.DataLoader): + train_dataset (torch.utils.data.Dataset or torch.utils.DataLoader): In the `influence` method, we either compute the influence score of training examples on examples in a test batch, or self influence scores for those training examples, depending on which mode is used. @@ -471,9 +516,15 @@ def __init__( DataLoader used for processing should be as large as possible, but not too large, so that certain intermediate quantities created from a batch still fit in memory. Therefore, if - `influence_src_dataset` is a Dataset, `batch_size` should be large. - If `influence_src_dataset` was already a DataLoader to begin with, - it should have been constructed to have a large batch size. + `train_dataset` is a Dataset, `batch_size` should be large. + If `train_dataset` was already a DataLoader to begin with, + it should have been constructed to have a large batch size. It is + assumed that the Dataloader (regardless of whether it is created + from a Pytorch Dataset or not) yields tuples. For a `batch` that is + yielded, of length `L`, it is assumed that the forward function of + `model` accepts `L-1` arguments, and the last element of `batch` is + the label. In other words, `model(*batch[:-1])` gives the output of + `model`, and `batch[-1]` are the labels for the batch. checkpoints (str or List of str or Iterator): Either the directory of the path to store and retrieve model checkpoints, a list of filepaths with checkpoints from which to load, or an iterator which @@ -507,12 +558,12 @@ def __init__( to "mean", i.e. `loss_fn.reduction = "mean"`. Default: None batch_size (int or None, optional): Batch size of the DataLoader created to - iterate through `influence_src_dataset`, if it is a Dataset. + iterate through `train_dataset`, if it is a Dataset. `batch_size` should be chosen as large as possible so that certain intermediate quantities created from a batch still fit in memory. Specific implementations of `TracInCPBase` will detail the size of the intermediate quantities. `batch_size` must be an int if - `influence_src_dataset` is a Dataset. If `influence_src_dataset` + `train_dataset` is a Dataset. If `train_dataset` is a DataLoader, then `batch_size` is ignored as an argument. Default: 1 sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient @@ -539,7 +590,7 @@ def __init__( TracInCPBase.__init__( self, model, - influence_src_dataset, + train_dataset, checkpoints, checkpoints_load_func, loss_fn, @@ -627,17 +678,17 @@ def influence( # type: ignore[override] - self influence mode: This mode is used if `inputs` is None. This mode computes the self influence scores for every example in - the training dataset `influence_src_dataset`. + the training dataset `train_dataset`. - influence score mode: This mode is used if `inputs` is not None, and `k` is None. This mode computes the influence score of every example in - training dataset `influence_src_dataset` on every example in the test + training dataset `train_dataset` on every example in the test batch represented by `inputs` and `targets`. - k-most influential mode: This mode is used if `inputs` is not None, and `k` is not None, and an int. This mode computes the proponents or opponents of every example in the test batch represented by `inputs` and `targets`. In particular, for each test example in the test batch, this mode computes its proponents (resp. opponents), which are the - indices in the training dataset `influence_src_dataset` of the training + indices in the training dataset `train_dataset` of the training examples with the `k` highest (resp. lowest) influence scores on the test example. Proponents are computed if `proponents` is True. Otherwise, opponents are computed. For each test example, this method @@ -649,12 +700,12 @@ def influence( # type: ignore[override] will be run. Otherwise, `inputs` is the test batch that will be used when running in either influence score or k-most influential mode. If the argument `unpack_inputs` is False, the - assumption is that `self.model(inputs)` produces the predictions + assumption is that `model(inputs)` produces the predictions for a batch, and `inputs` can be of any type. Otherwise if the argument `unpack_inputs` is True, the assumption is that - `self.model(*inputs)` produces the predictions for a batch, and + `model(*inputs)` produces the predictions for a batch, and `inputs` will need to be a tuple. In other words, `inputs` will be - unpacked as an argument when passing to `self.model`. + unpacked as an argument when passing to `model`. Default: None targets (tensor, optional): If computing influence scores on a loss function, these are the labels corresponding to the batch `inputs`. @@ -674,7 +725,7 @@ def influence( # type: ignore[override] Default: True show_progress (bool, optional): For all modes, computation of results requires "training dataset computations": computations for each - batch in the training dataset `influence_src_dataset`, which may + batch in the training dataset `train_dataset`, which may take a long time. If `show_progress`is true, the progress of "training dataset computations" will be displayed. In particular, the number of batches for which computations have been performed @@ -688,29 +739,29 @@ def influence( # type: ignore[override] - self influence mode: if this mode is run (`inputs` is None), returns a 1D tensor of self influence scores over training dataset - `influence_src_dataset`. The length of this tensor is the number of - examples in `influence_src_dataset`, regardless of whether it is a + `train_dataset`. The length of this tensor is the number of + examples in `train_dataset`, regardless of whether it is a Dataset or DataLoader. - influence score mode: if this mode is run (`inputs is not None, `k` is None), returns a 2D tensor `influence_scores` of shape - `(input_size, influence_src_dataset_size)`, where `input_size` is + `(input_size, train_dataset_size)`, where `input_size` is the number of examples in the test batch, and - `influence_src_dataset_size` is the number of examples in - training dataset `influence_src_dataset`. In other words, + `train_dataset_size` is the number of examples in + training dataset `train_dataset`. In other words, `influence_scores[i][j]` is the influence score of the `j`-th - example in `influence_src_dataset` on the `i`-th example in the + example in `train_dataset` on the `i`-th example in the test batch. - k-most influential mode: if this mode is run (`inputs` is not None, `k` is an int), returns a namedtuple `(indices, influence_scores)`. `indices` is a 2D tensor of shape `(input_size, k)`, where `input_size` is the number of examples in the test batch. If computing proponents (resp. opponents), `indices[i][j]` is the - index in training dataset `influence_src_dataset` of the example + index in training dataset `train_dataset` of the example with the `j`-th highest (resp. lowest) influence score (out of the - examples in `influence_src_dataset`) on the `i`-th example in the + examples in `train_dataset`) on the `i`-th example in the test batch. `influence_scores` contains the corresponding influence scores. In particular, `influence_scores[i][j]` is the influence - score of example `indices[i][j]` in `influence_src_dataset` on + score of example `indices[i][j]` in `train_dataset` on example `i` in the test batch represented by `inputs` and `targets`. """ @@ -769,7 +820,7 @@ def _influence( show_progress: bool = False, ) -> Tensor: r""" - Computes the influence of examples in training dataset `influence_src_dataset` + Computes the influence of examples in training dataset `train_dataset` on the examples in the test batch represented by `inputs` and `targets`. This implementation does not require knowing the number of training examples in advance. Instead, the number of training examples is inferred from the @@ -778,12 +829,12 @@ def _influence( Args: inputs (Tuple of Any): A test batch of examples. Does not represent labels, which are passed as `targets`. The assumption is that - `self.model(*inputs)` produces the predictions for the batch. + `model(*inputs)` produces the predictions for the batch. targets (tensor, optional): If computing influence scores on a loss function, these are the labels corresponding to the batch `inputs`. Default: None show_progress (bool, optional): To compute the influence of examples in - training dataset `influence_src_dataset`, we compute the influence + training dataset `train_dataset`, we compute the influence of each batch. If `show_progress`is true, the progress of this computation will be displayed. In particular, the number of batches for which influence has been computed will be displayed. It will @@ -794,29 +845,29 @@ def _influence( Returns: influence_scores (tensor): Influence scores from the TracInCP method. - Its shape is `(input_size, influence_src_dataset_size)`, where `input_size` + Its shape is `(input_size, train_dataset_size)`, where `input_size` is the number of examples in the test batch, and - `influence_src_dataset_size` is the number of examples in - training dataset `influence_src_dataset`. For example: + `train_dataset_size` is the number of examples in + training dataset `train_dataset`. For example: `influence_scores[i][j]` is the influence score for the j-th training example to the i-th input example. """ - influence_src_dataloader = self.influence_src_dataloader + train_dataloader = self.train_dataloader if show_progress: - influence_src_dataloader = progress( - influence_src_dataloader, + train_dataloader = progress( + train_dataloader, desc=( f"Using {self.get_name()} to compute " "influence for training batches" ), - total=self.influence_src_dataloader_len, + total=self.train_dataloader_len, ) return torch.cat( [ self._influence_batch_tracincp(inputs, targets, batch) - for batch in influence_src_dataloader + for batch in train_dataloader ], dim=1, ) @@ -844,7 +895,7 @@ def _get_k_most_influential( Default: True show_progress (bool, optional): To compute the proponents (or opponents) for the batch of examples, we perform computation for each batch in - training dataset `influence_src_dataset`, If `show_progress`is + training dataset `train_dataset`, If `show_progress`is true, the progress of this computation will be displayed. In particular, the number of batches for which the computation has been performed will be displayed. It will try to use tqdm if @@ -858,13 +909,13 @@ def _get_k_most_influential( test example. Its dimension is `(inputs_batch_size, k)`, where `inputs_batch_size` is the number of examples in `inputs`. For example, if `proponents==True`, `indices[i][j]` is the index of the - example in training dataset `influence_src_dataset` with the + example in training dataset `train_dataset` with the k-th highest influence score for the j-th example in `inputs`. `indices` is a `torch.long` tensor so that it can directly be used to index other tensors. Each row of `influence_scores` contains the influence scores for a different test example, in sorted order. In particular, `influence_scores[i][j]` is the influence score of - example `indices[i][j]` in training dataset `influence_src_dataset` + example `indices[i][j]` in training dataset `train_dataset` on example `i` in the test batch represented by `inputs` and `targets`. """ @@ -881,7 +932,7 @@ def _get_k_most_influential( ) return KMostInfluentialResults( *_get_k_most_influential_helper( - self.influence_src_dataloader, + self.train_dataloader, self._influence_batch_tracincp, inputs, targets, @@ -892,86 +943,159 @@ def _get_k_most_influential( ) ) - def _self_influence_batch_tracincp(self, batch: Tuple[Any, ...]): + def self_influence( + self, + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + show_progress: bool = False, + ) -> Tensor: """ - Computes self influence scores for a single batch + Computes self influence scores for the examples in `inputs_dataset`, which is + either a single batch or a Pytorch `DataLoader` that yields batches. Therefore, + the computed self influence scores are *not* for the examples in training + dataset `train_dataset` (unlike when computing self influence scores using the + `influence` method). Note that if `inputs_dataset` is a single batch, this + will call `model` on that single batch, and if `inputs_dataset` yields + batches, this will call `model` on each batch that is yielded. Therefore, + please ensure that for both cases, the batch(es) that `model` is called + with are not too large, so that there will not be an out-of-memory error. + + Args: + batches (Tuple, or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, + and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset`. Please see documentation for the + `train_dataset` argument to `TracInCP.__init__` for + more details on the assumed structure of a batch. + show_progress (bool, optional): Computation of self influence scores can + take a long time if `inputs_dataset` represents many examples. If + `show_progress`is true, the progress of this computation will be + displayed. In more detail, this computation will iterate over all + checkpoints (provided as the `checkpoints` initialization argument) + in an outer loop, and iterate over all batches that + `inputs_dataset` represents in an inner loop. Therefore, the + total number of (checkpoint, batch) combinations that need to be + iterated over is + (# of checkpoints x # of batches that `inputs_dataset` represents). + If `show_progress` is True, the total progress of both the outer + iteration over checkpoints and the inner iteration over batches is + displayed. It will try to use tqdm if available for advanced + features (e.g. time estimation). Otherwise, it will fallback to a + simple output of progress. + Default: False + + Returns: + self_influence_scores (Tensor): This is a 1D tensor containing the self + influence scores of all examples in `inputs_dataset`, regardless of + whether it represents a single batch or a `DataLoader` that yields + batches. """ + # If `inputs_dataset` is not a `DataLoader`, turn it into one. + inputs_dataset = _format_inputs_dataset(inputs_dataset) - def get_checkpoint_contribution(checkpoint): + # If `show_progress` is true, create an outer progress bar that keeps track of + # how many checkpoints have been processed + if show_progress: + checkpoints_progress = progress( + desc=( + f"Using {self.get_name()} to compute self " + "influence. Processing checkpoint" + ), + total=len(self.checkpoints), + ) + # Try to determine length of inner progress bar if possible, with a default + # of `None`. + inputs_dataset_len = None + try: + inputs_dataset_len = len(inputs_dataset) + except TypeError: + warnings.warn( + "Unable to determine the number of batches in `inputs_dataset`. " + "Therefore, if showing the progress of the computation of self " + "influence scores, only the number of batches processed can be " + "displayed, and not the percentage completion of the computation, " + "nor any time estimates." + ) + def get_checkpoint_contribution(checkpoint): + # This function returns a 1D tensor representing the contribution to the + # self influence score for the given checkpoint, for all batches in + # `inputs_dataset`. The length of the 1D tensor is the total number of + # examples in `inputs_dataset`. assert ( checkpoint is not None ), "None returned from `checkpoints`, cannot load." learning_rate = self.checkpoints_load_func(self.model, checkpoint) - layer_jacobians = self._basic_computation_tracincp(batch[0:-1], batch[-1]) + # This will store a list of the contribution of the self influence score + # from each batch. Each element is a 1D tensor of length batch_size - the + # batch size of each batch in `inputs_dataset` (they do not need to be all + # the same) + checkpoint_contribution = [] + + _inputs_dataset = inputs_dataset + # If `show_progress` is true, create an inner progress bar that keeps track + # of how many batches have been processed for the current checkpoint + if show_progress: + _inputs_dataset = progress( + inputs_dataset, + desc=( + f"Using {self.get_name()} to compute self " + "influence. Processing batch" + ), + total=inputs_dataset_len, + ) - # note that all variables in this function are for an entire batch. - # each `layer_jacobian` in `layer_jacobians` corresponds to a different - # layer. `layer_jacobian` is the jacobian w.r.t to a given layer's - # parameters. if the given layer's parameters are of shape *, then - # `layer_jacobian` is of shape (batch_size, *). for each layer, we need - # the squared jacobian for each example. so we square the jacobian and - # sum over all dimensions except the 0-th (the batch dimension). We then - # sum the contribution over all layers. - return ( - torch.sum( - torch.stack( - [ - torch.sum(layer_jacobian.flatten(start_dim=1) ** 2, dim=1) - for layer_jacobian in layer_jacobians - ], + for batch in _inputs_dataset: + + layer_jacobians = self._basic_computation_tracincp( + batch[0:-1], batch[-1] + ) + + # Note that all variables in this function are for an entire batch. + # Each `layer_jacobian` in `layer_jacobians` corresponds to a different + # layer. `layer_jacobian` is the jacobian w.r.t to a given layer's + # parameters. If the given layer's parameters are of shape *, then + # `layer_jacobian` is of shape (batch_size, *). For each layer, we need + # the squared jacobian for each example. So we square the jacobian and + # sum over all dimensions except the 0-th (the batch dimension). We then + # sum the contribution over all layers. + checkpoint_contribution.append( + torch.sum( + torch.stack( + [ + torch.sum( + layer_jacobian.flatten(start_dim=1) ** 2, dim=1 + ) + for layer_jacobian in layer_jacobians + ], + dim=0, + ), dim=0, - ), - dim=0, + ) + * learning_rate ) - * learning_rate - ) - batch_self_tracin_scores = get_checkpoint_contribution(self.checkpoints[0]) + # We concatenate the contributions from each batch into a single 1D tensor, + # which represents the contributions for all batches in `inputs_dataset` - for checkpoint in self.checkpoints[1:]: - batch_self_tracin_scores += get_checkpoint_contribution(checkpoint) + if show_progress: + checkpoints_progress.update() - return batch_self_tracin_scores + return torch.cat(checkpoint_contribution, dim=0) - def _self_influence(self, show_progress: bool = False): - """ - Returns: - self influence scores (tensor): 1D tensor containing self influence - scores for all examples in training dataset - `influence_src_dataset`. - show_progress (bool, optional): To compute the self influence scores for - all examples in training dataset `influence_src_dataset`, we - compute the self influence scores for each batch. If - `show_progress`is true, the progress of this computation will be - displayed. In particular, the number of batches for which self - influence scores have been computed will be displayed. It will - try to use tqdm if available for advanced features (e.g. time - estimation). Otherwise, it will fallback to a simple output of - progress. - Default: False - """ - influence_src_dataloader = self.influence_src_dataloader + batches_self_tracin_scores = get_checkpoint_contribution(self.checkpoints[0]) - if show_progress: - influence_src_dataloader = progress( - influence_src_dataloader, - desc=( - f"Using {self.get_name()} to compute self " - "influence for training batches" - ), - total=self.influence_src_dataloader_len, - ) + # The self influence score for all examples is the sum of contributions from + # each checkpoint + for checkpoint in self.checkpoints[1:]: + batches_self_tracin_scores += get_checkpoint_contribution(checkpoint) - return torch.cat( - [ - self._self_influence_batch_tracincp(batch) - for batch in influence_src_dataloader - ], - dim=0, - ) + return batches_self_tracin_scores def _basic_computation_tracincp( self, @@ -987,7 +1111,7 @@ def _basic_computation_tracincp( inputs (Tuple of Any): A batch of examples, which could be a training batch or test batch, depending which method is the caller. Does not represent labels, which are passed as `targets`. The assumption is - that `self.model(*inputs)` produces the predictions for the batch. + that `model(*inputs)` produces the predictions for the batch. targets (tensor or None): If computing influence scores on a loss function, these are the labels corresponding to the batch `inputs`. """ diff --git a/captum/influence/_core/tracincp_fast_rand_proj.py b/captum/influence/_core/tracincp_fast_rand_proj.py index cfbf7b47d4..71fe3b45a0 100644 --- a/captum/influence/_core/tracincp_fast_rand_proj.py +++ b/captum/influence/_core/tracincp_fast_rand_proj.py @@ -13,6 +13,7 @@ ) from captum.influence._utils.common import ( _DatasetFromList, + _format_inputs_dataset, _get_k_most_influential_helper, _jacobian_loss_wrt_inputs, _load_flexible_state_dict, @@ -77,7 +78,7 @@ def __init__( self, model: Module, final_fc_layer: Union[Module, str], - influence_src_dataset: Union[Dataset, DataLoader], + train_dataset: Union[Dataset, DataLoader], checkpoints: Union[str, List[str], Iterator], checkpoints_load_func: Callable = _load_flexible_state_dict, loss_fn: Optional[Union[Module, Callable]] = None, @@ -93,7 +94,7 @@ def __init__( projection method. Can be either the layer module itself, or the fully qualified name of the layer if it is a defined attribute of the passed `model`. - influence_src_dataset (torch.utils.data.Dataset or torch.utils.DataLoader): + train_dataset (torch.utils.data.Dataset or torch.utils.DataLoader): In the `influence` method, we either compute the influence score of training examples on examples in a test batch, or self influence scores for those training examples, depending on which mode is used. @@ -108,9 +109,15 @@ def __init__( DataLoader used for processing should be as large as possible, but not too large, so that certain intermediate quantities created from a batch still fit in memory. Therefore, if - `influence_src_dataset` is a Dataset, `batch_size` should be large. - If `influence_src_dataset` was already a DataLoader to begin with, - it should have been constructed to have a large batch size. + `train_dataset` is a Dataset, `batch_size` should be large. + If `train_dataset` was already a DataLoader to begin with, + it should have been constructed to have a large batch size. It is + assumed that the Dataloader (regardless of whether it is created + from a Pytorch Dataset or not) yields tuples. For a `batch` that is + yielded, of length `L`, it is assumed that the forward function of + `model` accepts `L-1` arguments, and the last element of `batch` is + the label. In other words, `model(*batch[:-1])` gives the output of + `model`, and `batch[-1]` are the labels for the batch. checkpoints (str or List of str or Iterator): Either the directory of the path to store and retrieve model checkpoints, a list of filepaths with checkpoints from which to load, or an iterator which @@ -132,12 +139,12 @@ def __init__( to "mean", i.e. `loss_fn.reduction = "mean"`. Default: None batch_size (int or None, optional): Batch size of the DataLoader created to - iterate through `influence_src_dataset`, if it is a Dataset. + iterate through `train_dataset`, if it is a Dataset. `batch_size` should be chosen as large as possible so that certain intermediate quantities created from a batch still fit in memory. Specific implementations of `TracInCPBase` will detail the size of the intermediate quantities. `batch_size` must be an int if - `influence_src_dataset` is a Dataset. If `influence_src_dataset` + `train_dataset` is a Dataset. If `train_dataset` is a DataLoader, then `batch_size` is ignored as an argument. Default: 1 vectorize (bool, optional): Flag to use experimental vectorize functionality @@ -147,7 +154,7 @@ def __init__( TracInCPBase.__init__( self, model, - influence_src_dataset, + train_dataset, checkpoints, checkpoints_load_func, loss_fn, @@ -206,17 +213,17 @@ def influence( # type: ignore[override] - self influence mode: This mode is used if `inputs` is None. This mode computes the self influence scores for every example in - the training dataset `influence_src_dataset`. + the training dataset `train_dataset`. - influence score mode: This mode is used if `inputs` is not None, and `k` is None. This mode computes the influence score of every example in - training dataset `influence_src_dataset` on every example in the test + training dataset `train_dataset` on every example in the test batch represented by `inputs` and `targets`. - k-most influential mode: This mode is used if `inputs` is not None, and `k` is not None, and an int. This mode computes the proponents or opponents of every example in the test batch represented by `inputs` and `targets`. In particular, for each test example in the test batch, this mode computes its proponents (resp. opponents), which are the - indices in the training dataset `influence_src_dataset` of the training + indices in the training dataset `train_dataset` of the training examples with the `k` highest (resp. lowest) influence scores on the test example. Proponents are computed if `proponents` is True. Otherwise, opponents are computed. For each test example, this method @@ -228,12 +235,12 @@ def influence( # type: ignore[override] will be run. Otherwise, `inputs` is the test batch that will be used when running in either influence score or k-most influential mode. If the argument `unpack_inputs` is False, the - assumption is that `self.model(inputs)` produces the predictions + assumption is that `model(inputs)` produces the predictions for a batch, and `inputs` can be of any type. Otherwise if the argument `unpack_inputs` is True, the assumption is that - `self.model(*inputs)` produces the predictions for a batch, and + `model(*inputs)` produces the predictions for a batch, and `inputs` will need to be a tuple. In other words, `inputs` will be - unpacked as an argument when passing to `self.model`. + unpacked as an argument when passing to `model`. Default: None targets (tensor, optional): The labels corresponding to the batch `inputs`. This method is designed to be applied for a loss function, so @@ -254,7 +261,7 @@ def influence( # type: ignore[override] Default: True show_progress (bool, optional): For all modes, computation of results requires "training dataset computations": computations for each - batch in the training dataset `influence_src_dataset`, which may + batch in the training dataset `train_dataset`, which may take a long time. If `show_progress`is true, the progress of "training dataset computations" will be displayed. In particular, the number of batches for which computations have been performed @@ -268,29 +275,29 @@ def influence( # type: ignore[override] - self influence mode: if this mode is run (`inputs` is None), returns a 1D tensor of self influence scores over training dataset - `influence_src_dataset`. The length of this tensor is the number of - examples in `influence_src_dataset`, regardless of whether it is a + `train_dataset`. The length of this tensor is the number of + examples in `train_dataset`, regardless of whether it is a Dataset or DataLoader. - influence score mode: if this mode is run (`inputs is not None, `k` is None), returns a 2D tensor `influence_scores` of shape - `(input_size, influence_src_dataset_size)`, where `input_size` is + `(input_size, train_dataset_size)`, where `input_size` is the number of examples in the test batch, and - `influence_src_dataset_size` is the number of examples in - training dataset `influence_src_dataset`. In other words, + `train_dataset_size` is the number of examples in + training dataset `train_dataset`. In other words, `influence_scores[i][j]` is the influence score of the `j`-th - example in `influence_src_dataset` on the `i`-th example in the + example in `train_dataset` on the `i`-th example in the test batch. - k-most influential mode: if this mode is run (`inputs` is not None, `k` is an int), returns a namedtuple `(indices, influence_scores)`. `indices` is a 2D tensor of shape `(input_size, k)`, where `input_size` is the number of examples in the test batch. If computing proponents (resp. opponents), `indices[i][j]` is the - index in training dataset `influence_src_dataset` of the example + index in training dataset `train_dataset` of the example with the `j`-th highest (resp. lowest) influence score (out of the - examples in `influence_src_dataset`) on the `i`-th example in the + examples in `train_dataset`) on the `i`-th example in the test batch. `influence_scores` contains the corresponding influence scores. In particular, `influence_scores[i][j]` is the influence - score of example `indices[i][j]` in `influence_src_dataset` on + score of example `indices[i][j]` in `train_dataset` on example `i` in the test batch represented by `inputs` and `targets`. """ @@ -351,7 +358,7 @@ def _influence( # type: ignore[override] show_progress: bool = False, ) -> Tensor: r""" - Computes the influence of examples in training dataset `influence_src_dataset` + Computes the influence of examples in training dataset `train_dataset` on the examples in the test batch represented by `inputs` and `targets`. This implementation does not require knowing the number of training examples in advance. Instead, the number of training examples is inferred from the @@ -360,12 +367,12 @@ def _influence( # type: ignore[override] Args: inputs (Tuple of Any): A batch of examples. Does not represent labels, which are passed as `targets`. The assumption is that - `self.model(*inputs)` produces the predictions for the batch. + `model(*inputs)` produces the predictions for the batch. targets (tensor): The labels corresponding to the batch `inputs`. This method is designed to be applied for a loss function, so labels are required. show_progress (bool, optional): To compute the influence of examples in - training dataset `influence_src_dataset`, we compute the influence + training dataset `train_dataset`, we compute the influence of each batch. If `show_progress`is true, the progress of this computation will be displayed. In particular, the number of batches for which influence has been computed will be displayed. It will @@ -376,31 +383,31 @@ def _influence( # type: ignore[override] Returns: influence_scores (tensor): Influence scores from the TracInCPFast method. - Its shape is `(input_size, influence_src_dataset_size)`, where `input_size` + Its shape is `(input_size, train_dataset_size)`, where `input_size` is the number of examples in the test batch, and - `influence_src_dataset_size` is the number of examples in - training dataset `influence_src_dataset`. For example: + `train_dataset_size` is the number of examples in + training dataset `train_dataset`. For example: `influence_scores[i][j]` is the influence score for the j-th training example to the i-th input example. """ assert targets is not None - influence_src_dataloader = self.influence_src_dataloader + train_dataloader = self.train_dataloader if show_progress: - influence_src_dataloader = progress( - influence_src_dataloader, + train_dataloader = progress( + train_dataloader, desc=( f"Using {self.get_name()} to compute " "influence for training batches" ), - total=self.influence_src_dataloader_len, + total=self.train_dataloader_len, ) return torch.cat( [ self._influence_batch_tracincp_fast(inputs, targets, batch) - for batch in influence_src_dataloader + for batch in train_dataloader ], dim=1, ) @@ -428,7 +435,7 @@ def _get_k_most_influential( # type: ignore[override] Default: True show_progress (bool, optional): To compute the proponents (or opponents) for the batch of examples, we perform computation for each batch in - training dataset `influence_src_dataset`, If `show_progress`is + training dataset `train_dataset`, If `show_progress`is true, the progress of this computation will be displayed. In particular, the number of batches for which the computation has been performed will be displayed. It will try to use tqdm if @@ -442,13 +449,13 @@ def _get_k_most_influential( # type: ignore[override] test example. Its dimension is `(inputs_batch_size, k)`, where `inputs_batch_size` is the number of examples in `inputs`. For example, if `proponents==True`, `indices[i][j]` is the index of the - example in training dataset `influence_src_dataset` with the + example in training dataset `train_dataset` with the k-th highest influence score for the j-th example in `inputs`. `indices` is a `torch.long` tensor so that it can directly be used to index other tensors. Each row of `influence_scores` contains the influence scores for a different test example, in sorted order. In particular, `influence_scores[i][j]` is the influence score of - example `indices[i][j]` in training dataset `influence_src_dataset` + example `indices[i][j]` in training dataset `train_dataset` on example `i` in the test batch represented by `inputs` and `targets`. """ @@ -465,7 +472,7 @@ def _get_k_most_influential( # type: ignore[override] ) return KMostInfluentialResults( *_get_k_most_influential_helper( - self.influence_src_dataloader, + self.train_dataloader, self._influence_batch_tracincp_fast, inputs, targets, @@ -476,72 +483,141 @@ def _get_k_most_influential( # type: ignore[override] ) ) - def _self_influence_batch_tracincp_fast(self, batch: Tuple[Any, ...]): + def self_influence( + self, + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + show_progress: bool = False, + ) -> Tensor: """ - Computes self influence scores for a single batch + Computes self influence scores for the examples in `inputs_dataset`, which is + either a single batch or a Pytorch `DataLoader` that yields batches. Therefore, + the computed self influence scores are *not* for the examples in training + dataset `train_dataset` (unlike when computing self influence scores using the + `influence` method). Note that if `inputs_dataset` is a single batch, this + will call `model` on that single batch, and if `inputs_dataset` yields + batches, this will call `model` on each batch that is yielded. Therefore, + please ensure that for both cases, the batch(es) that `model` is called + with are not too large, so that there will not be an out-of-memory error. + + Args: + batches (Tuple, or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, + and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset`. Please see documentation for the + `train_dataset` argument to `TracInCP.__init__` for + more details on the assumed structure of a batch. + show_progress (bool, optional): Computation of self influence scores can + take a long time if `inputs_dataset` represents many examples. If + `show_progress`is true, the progress of this computation will be + displayed. In more detail, this computation will iterate over all + checkpoints (provided as the `checkpoints` initialization argument) + in an outer loop, and iterate over all batches that + `inputs_dataset` represents in an inner loop. Therefore, the + total number of (checkpoint, batch) combinations that need to be + iterated over is + (# of checkpoints x # of batches that `inputs_dataset` represents). + If `show_progress` is True, the total progress of both the outer + iteration over checkpoints and the inner iteration over batches is + displayed. It will try to use tqdm if available for advanced + features (e.g. time estimation). Otherwise, it will fallback to a + simple output of progress. + Default: False + + Returns: + self_influence_scores (Tensor): This is a 1D tensor containing the self + influence scores of all examples in `inputs_dataset`, regardless of + whether it represents a single batch or a `DataLoader` that yields + batches. """ + # If `inputs_dataset` is not a `DataLoader`, turn it into one. + inputs_dataset = _format_inputs_dataset(inputs_dataset) - def get_checkpoint_contribution(checkpoint): + # If `show_progress` is true, create an outer progress bar that keeps track of + # how many checkpoints have been processed + if show_progress: + checkpoints_progress = progress( + desc=( + f"Using {self.get_name()} to compute self " + "influence. Processing checkpoint" + ), + total=len(self.checkpoints), + ) + # Try to determine length of inner progress bar if possible, with a default + # of `None`. + inputs_dataset_len = None + try: + inputs_dataset_len = len(inputs_dataset) + except TypeError: + warnings.warn( + "Unable to determine the number of batches in `inputs_dataset`. " + "Therefore, if showing the progress of the computation of self " + "influence scores, only the number of batches processed can be " + "displayed, and not the percentage completion of the computation, " + "nor any time estimates." + ) + def get_checkpoint_contribution(checkpoint): + # This function returns a 1D tensor representing the contribution to the + # self influence score for the given checkpoint, for all batches in + # `inputs_dataset`. The length of the 1D tensor is the total number of + # examples in `inputs_dataset`. assert ( checkpoint is not None ), "None returned from `checkpoints`, cannot load." learning_rate = self.checkpoints_load_func(self.model, checkpoint) - batch_jacobian, batch_layer_input = _basic_computation_tracincp_fast( - self, batch[0:-1], batch[-1] - ) + # This will store a list of the contribution of the self influence score + # from each batch. Each element is a 1D tensor of length batch_size - the + # batch size of each batch in `inputs_dataset` (they do not need to be all + # the same) + checkpoint_contribution = [] + + _inputs_dataset = inputs_dataset + # If `show_progress` is true, create an inner progress bar that keeps track + # of how many batches have been processed for the current checkpoint + if show_progress: + _inputs_dataset = progress( + inputs_dataset, + desc=( + f"Using {self.get_name()} to compute self " + "influence. Processing batch" + ), + total=inputs_dataset_len, + ) - return ( - torch.sum(batch_jacobian**2, dim=1) - * torch.sum(batch_layer_input**2, dim=1) - * learning_rate - ) + for batch in _inputs_dataset: - batch_self_tracin_scores = get_checkpoint_contribution(self.checkpoints[0]) + batch_jacobian, batch_layer_input = _basic_computation_tracincp_fast( + self, batch[0:-1], batch[-1] + ) - for checkpoint in self.checkpoints[1:]: - batch_self_tracin_scores += get_checkpoint_contribution(checkpoint) + checkpoint_contribution.append( + torch.sum(batch_jacobian**2, dim=1) + * torch.sum(batch_layer_input**2, dim=1) + * learning_rate + ) - return batch_self_tracin_scores + # We concatenate the contributions from each batch into a single 1D tensor, + # which represents the contributions for all batches in `inputs_dataset` - def _self_influence(self, show_progress: bool = False): - """ - Returns: - self influence scores (tensor): 1D tensor containing self influence - scores for all examples in training dataset - `influence_src_dataset`. - show_progress (bool, optional): To compute the self influence scores for - all examples in training dataset `influence_src_dataset`, we - compute the self influence scores for each batch. If - `show_progress`is true, the progress of this computation will be - displayed. In particular, the number of batches for which self - influence scores have been computed will be displayed. It will - try to use tqdm if available for advanced features (e.g. time - estimation). Otherwise, it will fallback to a simple output of - progress. - Default: False - """ - influence_src_dataloader = self.influence_src_dataloader + if show_progress: + checkpoints_progress.update() - if show_progress: - influence_src_dataloader = progress( - influence_src_dataloader, - desc=( - f"Using {self.get_name()} to compute self " - "influence for training batches" - ), - total=self.influence_src_dataloader_len, - ) + return torch.cat(checkpoint_contribution, dim=0) - return torch.cat( - [ - self._self_influence_batch_tracincp_fast(batch) - for batch in influence_src_dataloader - ], - dim=0, - ) + batches_self_tracin_scores = get_checkpoint_contribution(self.checkpoints[0]) + + # The self influence score for all examples is the sum of contributions from + # each checkpoint + for checkpoint in self.checkpoints[1:]: + batches_self_tracin_scores += get_checkpoint_contribution(checkpoint) + + return batches_self_tracin_scores def _basic_computation_tracincp_fast( @@ -564,7 +640,7 @@ def _basic_computation_tracincp_fast( inputs (Tuple of Any): A batch of examples, which could be a training batch or test batch, depending which method is the caller. Does not represent labels, which are passed as `targets`. The assumption is - that `self.model(*inputs)` produces the predictions for the batch. + that `model(*inputs)` produces the predictions for the batch. targets (tensor): If computing influence scores on a loss function, these are the labels corresponding to the batch `inputs`. """ @@ -599,7 +675,7 @@ def __init__( self, model: Module, final_fc_layer: Union[Module, str], - influence_src_dataset: Union[Dataset, DataLoader], + train_dataset: Union[Dataset, DataLoader], checkpoints: Union[str, List[str], Iterator], checkpoints_load_func: Callable = _load_flexible_state_dict, loss_fn: Optional[Union[Module, Callable]] = None, @@ -620,10 +696,10 @@ def __init__( interactive use cases. It should not be used if `influence` will only be called once, because to enable fast calls to `influence`, time and memory intensive preprocessing is required in `__init__`. Furthermore, it should not - be used to calculate self influencs scores - `TracInCPFast` should be used + be used to calculate self influence scores - `TracInCPFast` should be used instead for that purpose. To enable interactive analysis, this implementation - saves pre-computed vectors for all training examples in - `influence_src_dataset`. Crucially, the influence score of a training + computes and saves "embedding" vectors for all training examples in + `train_dataset`. Crucially, the influence score of a training example on a test example is simply the dot-product of their corresponding vectors, and proponents / opponents can be found by first storing vectors for training examples in a nearest-neighbor data structure, and then finding the @@ -631,7 +707,7 @@ def __init__( of the TracIn paper). This class should only be used if calls to `influence` to obtain proponents / opponents or influence scores will be made in an "interactive" manner, and there is sufficient memory to store vectors for the - entire `influence_src_dataset`. This is because in order to enable interactive + entire `train_dataset`. This is because in order to enable interactive analysis, this implementation incures overhead in ``__init__` to setup the nearest-neighbors data structure, which is both time and memory intensive, as vectors corresponding to all training examples needed to be stored. To reduce @@ -647,7 +723,7 @@ def __init__( projection method. Can be either the layer module itself, or the fully qualified name of the layer if it is a defined attribute of the passed `model`. - influence_src_dataset (torch.utils.data.Dataset or torch.utils.DataLoader): + train_dataset (torch.utils.data.Dataset or torch.utils.DataLoader): In the `influence` method, we either compute the influence score of training examples on examples in a test batch, or self influence scores for those training examples, depending on which mode is used. @@ -662,9 +738,15 @@ def __init__( DataLoader used for processing should be as large as possible, but not too large, so that certain intermediate quantities created from a batch still fit in memory. Therefore, if - `influence_src_dataset` is a Dataset, `batch_size` should be large. - If `influence_src_dataset` was already a DataLoader to begin with, - it should have been constructed to have a large batch size. + `train_dataset` is a Dataset, `batch_size` should be large. + If `train_dataset` was already a DataLoader to begin with, + it should have been constructed to have a large batch size. It is + assumed that the Dataloader (regardless of whether it is created + from a Pytorch Dataset or not) yields tuples. For a `batch` that is + yielded, of length `L`, it is assumed that the forward function of + `model` accepts `L-1` arguments, and the last element of `batch` is + the label. In other words, `model(*batch[:-1])` gives the output of + `model`, and `batch[-1]` are the labels for the batch. checkpoints (str or List of str or Iterator): Either the directory of the path to store and retrieve model checkpoints, a list of filepaths with checkpoints from which to load, or an iterator which @@ -682,12 +764,12 @@ def __init__( `nn.BCELoss(reduction="mean")` is *not* acceptable. Default: None batch_size (int or None, optional): Batch size of the DataLoader created to - iterate through `influence_src_dataset`, if it is a Dataset. + iterate through `train_dataset`, if it is a Dataset. `batch_size` should be chosen as large as possible so that certain intermediate quantities created from a batch still fit in memory. Specific implementations of `TracInCPBase` will detail the size of the intermediate quantities. `batch_size` must be an int if - `influence_src_dataset` is a Dataset. If `influence_src_dataset` + `train_dataset` is a Dataset. If `train_dataset` is a DataLoader, then `batch_size` is ignored as an argument. Default: 1 vectorize (bool): Flag to use experimental vectorize functionality @@ -728,7 +810,7 @@ def __init__( self, model, final_fc_layer, - influence_src_dataset, + train_dataset, checkpoints, checkpoints_load_func, loss_fn, @@ -739,7 +821,7 @@ def __init__( warnings.warn( ( "WARNING: Using this implementation stores quantities related to the " - "entire `influence_src_dataset` in memory, and may results in running " + "entire `train_dataset` in memory, and may results in running " "out of memory. If this happens, consider using %s instead, for which " "each call to `influence` to compute influence scores or proponents " "will be slower, but may avoid running out of memory." @@ -755,12 +837,12 @@ def __init__( torch.manual_seed(seed) # for reproducibility self.projection_quantities = self._set_projections_tracincp_fast_rand_proj( - self.influence_src_dataloader, + self.train_dataloader, ) self.src_intermediate_quantities = ( self._get_intermediate_quantities_tracincp_fast_rand_proj( - self.influence_src_dataloader, + self.train_dataloader, self.projection_quantities, ) ) @@ -778,7 +860,7 @@ def _influence( # type: ignore[override] Args: inputs (tuple of Any): A batch of examples. Does not represent labels, which are passed as `targets`. The assumption is that - `self.model(*inputs)` produces the predictions for the batch. + `model(*inputs)` produces the predictions for the batch. targets (tensor): The labels corresponding to the batch `inputs`. This method is designed to be applied for a loss function, so labels are required. @@ -786,9 +868,9 @@ def _influence( # type: ignore[override] Returns: influence_scores (tensor): Influence scores from the TracInCPFastRandProj method. Its shape is - `(input_size, influence_src_dataset_size)`, where `input_size` is the - number of examples in the test batch, and `influence_src_dataset_size` is - the number of examples in training dataset `influence_src_dataset`. For + `(input_size, train_dataset_size)`, where `input_size` is the + number of examples in the test batch, and `train_dataset_size` is + the number of examples in training dataset `train_dataset`. For example, `influence_scores[i][j]` is the influence score for the j-th training example to the i-th input example. """ @@ -831,13 +913,13 @@ def _get_k_most_influential( # type: ignore[override] test example. Its dimension is `(inputs_batch_size, k)`, where `inputs_batch_size` is the number of examples in `inputs`. For example, if `proponents==True`, `indices[i][j]` is the index of the - example in training dataset `influence_src_dataset` with the + example in training dataset `train_dataset` with the k-th highest influence score for the j-th example in `inputs`. `indices` is a `torch.long` tensor so that it can directly be used to index other tensors. Each row of `influence_scores` contains the influence scores for a different test example, in sorted order. In particular, `influence_scores[i][j]` is the influence score of - example `indices[i][j]` in training dataset `influence_src_dataset` + example `indices[i][j]` in training dataset `train_dataset` on example `i` in the test batch represented by `inputs` and `targets`. """ @@ -860,17 +942,55 @@ def _get_k_most_influential( # type: ignore[override] return KMostInfluentialResults(indices, distances) - def _self_influence(self): + def self_influence( + self, + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + show_progress: bool = False, + ) -> Tensor: """ - NOT IMPLEMENTED - no need to implement `TracInCPFastRandProj._self_influence`, - as `TracInCPFast._self_influence` is sufficient - the latter does not benefit + NOT IMPLEMENTED - no need to implement `TracInCPFastRandProj.self_influence`, + as `TracInCPFast.self_influence` is sufficient - the latter does not benefit from random projections, since no quantities associated with a training example are stored (other than its self influence score) + Computes self influence scores for a single batch or a Pytorch `DataLoader` + that yields batches. Note that if `inputs_dataset` is a single batch, this + will call `model` on that single batch, and if `inputs_dataset` yields + batches, this will call `model` on each batch that is yielded. Therefore, + please ensure that for both cases, the batch(es) that `model` is called + with are not too large, so that there will not be an out-of-memory error. + + Args: + batches (Tuple, or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, + and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset`. Please see documentation for the + `train_dataset` argument to `TracInCP.__init__` for + more details on the assumed structure of a batch. + show_progress (bool, optional): Computation of self influence scores can + take a long time if `inputs_dataset` represents many examples. If + `show_progress`is true, the progress of this computation will be + displayed. In more detail, this computation will iterate over all + checkpoints (provided as the `checkpoints` initialization argument) + and all batches that `inputs_dataset` represents. Therefore, the + total number of (checkpoint, batch) combinations that need to be + iterated over is + (# of checkpoints x # of batches that `inputs_dataset` represents). + If `show_progress` is True, the total number of such combinations + that have been iterated over is displayed. It will try to use tqdm + if available for advanced features (e.g. time estimation). + Otherwise, it will fallback to a simple output of progress. + Default: False + Returns: - self influence scores (Tensor): 1-d Tensor containing self influence - scores for all examples in training dataset - `influence_src_dataset`. + self_influence_scores (Tensor): This is a 1D tensor containing the self + influence scores of all examples in `inputs_dataset`, regardless of + whether it represents a single batch or a `DataLoader` that yields + batches. """ warnings.warn( ( @@ -883,7 +1003,7 @@ def _self_influence(self): "`TracInCPFastRandProj`needed. Further considering the fact that " "random projections results only in approximate self influence " "scores, there is no reason to use `TracInCPFastRandProj` when " - "calculating self-influence scores." + "calculating self influence scores." ) ) raise NotImplementedError @@ -903,7 +1023,7 @@ def influence( # type: ignore[override] - influence score mode: This mode is used if `inputs` is not None, and `k` is None. This mode computes the influence score of every example in - training dataset `influence_src_dataset` on every example in the test + training dataset `train_dataset` on every example in the test batch represented by `inputs` and `targets`. - k-most influential mode: This mode is used if `inputs` is not None, and @@ -911,7 +1031,7 @@ def influence( # type: ignore[override] opponents of every example in the test batch represented by `inputs` and `targets`. In particular, for each test example in the test batch, this mode computes its proponents (resp. opponents), which are the - indices in the training dataset `influence_src_dataset` of the training + indices in the training dataset `train_dataset` of the training examples with the `k` highest (resp. lowest) influence scores on the test example. Proponents are computed if `proponents` is True. Otherwise, opponents are computed. For each test example, this method @@ -927,12 +1047,12 @@ def influence( # type: ignore[override] will be run. Otherwise, `inputs` is the test batch that will be used when running in either influence score or k-most influential mode. If the argument `unpack_inputs` is False, the - assumption is that `self.model(inputs)` produces the predictions + assumption is that `model(inputs)` produces the predictions for a batch, and `inputs` can be of any type. Otherwise if the argument `unpack_inputs` is True, the assumption is that - `self.model(*inputs)` produces the predictions for a batch, and + `model(*inputs)` produces the predictions for a batch, and `inputs` will need to be a tuple. In other words, `inputs` will be - unpacked as an argument when passing to `self.model`. + unpacked as an argument when passing to `model`. Default: None targets (tensor): The labels corresponding to the batch `inputs`. This method is designed to be applied for a loss function, so `targets` @@ -957,24 +1077,24 @@ def influence( # type: ignore[override] - influence score mode: if this mode is run (`inputs is not None, `k` is None), returns a 2D tensor `influence_scores` of shape - `(input_size, influence_src_dataset_size)`, where `input_size` is + `(input_size, train_dataset_size)`, where `input_size` is the number of examples in the test batch, and - `influence_src_dataset_size` is the number of examples in - training dataset `influence_src_dataset`. In other words, + `train_dataset_size` is the number of examples in + training dataset `train_dataset`. In other words, `influence_scores[i][j]` is the influence score of the `j`-th - example in `influence_src_dataset` on the `i`-th example in the + example in `train_dataset` on the `i`-th example in the test batch. - k-most influential mode: if this mode is run (`inputs` is not None, `k` is an int), returns a namedtuple `(indices, influence_scores)`. `indices` is a 2D tensor of shape `(input_size, k)`, where `input_size` is the number of examples in the test batch. If computing proponents (resp. opponents), `indices[i][j]` is the - index in training dataset `influence_src_dataset` of the example + index in training dataset `train_dataset` of the example with the `j`-th highest (resp. lowest) influence score (out of the - examples in `influence_src_dataset`) on the `i`-th example in the + examples in `train_dataset`) on the `i`-th example in the test batch. `influence_scores` contains the corresponding influence scores. In particular, `influence_scores[i][j]` is the influence - score of example `indices[i][j]` in `influence_src_dataset` on + score of example `indices[i][j]` in `train_dataset` on example `i` in the test batch represented by `inputs` and `targets`. """ @@ -990,7 +1110,7 @@ def influence( # type: ignore[override] _inputs = _format_inputs(inputs, unpack_inputs) if inputs is None: - return self._self_influence() + return self.self_influence(self.train_dataloader) elif k is None: return self._influence(_inputs, targets) else: @@ -1014,7 +1134,7 @@ def _set_projections_tracincp_fast_rand_proj( dataloader (DataLoader): determining the projection requires knowing the dimensionality of the last layer's parameters (`jacobian_dim` below) and its input (`layer_input_dim` below). These are - determined by passing a batch to `self.model`. `dataloader` + determined by passing a batch to `model`. `dataloader` provides that batch. Returns: @@ -1096,7 +1216,7 @@ def _process_src_intermediate_quantities_tracincp_fast_rand_proj( Args: src_intermediate_quantities (tensor): the output of the `_get_intermediate_quantities_tracin_fast_rand_proj` function when - applied to training dataset `influence_src_dataset`. This + applied to training dataset `train_dataset`. This output is the vector representation of all training examples. The dot product between the representation of a training example and the representation of a test example gives the influence score @@ -1143,6 +1263,8 @@ def _get_intermediate_quantities_tracincp_fast_rand_proj( the variable d in the top of page 15 of the TracIn paper: https://arxiv.org/pdf/2002.08484.pdf. """ + # for each checkpoint, this stores a list of projections for a batch + # each element in this list will be of shape (batch_size, projection_dim) checkpoint_projections: List[Any] = [[] for _ in self.checkpoints] if projection_quantities is None: diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py index b86ddf9f93..d6f1c99f20 100644 --- a/captum/influence/_utils/common.py +++ b/captum/influence/_utils/common.py @@ -313,3 +313,15 @@ def __getitem__(self, i: int) -> Any: def __len__(self) -> int: return len(self._l) + + +def _format_inputs_dataset(inputs_dataset: Union[Tuple[Any, ...], DataLoader]): + # if `inputs_dataset` is not a `DataLoader`, turn it into one. + # `_DatasetFromList` turns a list into a `Dataset` where `__getitem__` + # returns an element in the list, and using it to construct a `DataLoader` + # with `batch_size=None` gives a `DataLoader` that yields a single batch. + if not isinstance(inputs_dataset, DataLoader): + inputs_dataset = DataLoader( + _DatasetFromList([inputs_dataset]), shuffle=False, batch_size=None + ) + return inputs_dataset diff --git a/tests/influence/_core/test_tracin_self_influence.py b/tests/influence/_core/test_tracin_self_influence.py index 60f0be2678..9448982a58 100644 --- a/tests/influence/_core/test_tracin_self_influence.py +++ b/tests/influence/_core/test_tracin_self_influence.py @@ -12,6 +12,7 @@ DataInfluenceConstructor, get_random_model_and_data, ) +from torch.utils.data import DataLoader class TestTracInSelfInfluence(BaseTest): @@ -33,7 +34,7 @@ class TestTracInSelfInfluence(BaseTest): ("mean", DataInfluenceConstructor(TracInCPFast)), ] ], - name_func=build_test_name_func(args_to_skip=["reduction"]), + name_func=build_test_name_func(), ) def test_tracin_self_influence( self, reduction: str, tracin_constructor: Callable, unpack_inputs: bool @@ -73,3 +74,70 @@ def test_tracin_self_influence( delta=0.01, mode="max", ) + + @parameterized.expand( + [ + (reduction, constructor, unpack_inputs) + for unpack_inputs in [True, False] + for (reduction, constructor) in [ + ("none", DataInfluenceConstructor(TracInCP)), + ( + "sum", + DataInfluenceConstructor( + TracInCP, + sample_wise_grads_per_batch=True, + ), + ), + ("sum", DataInfluenceConstructor(TracInCPFast)), + ("mean", DataInfluenceConstructor(TracInCPFast)), + ] + ], + name_func=build_test_name_func(), + ) + def test_tracin_self_influence_dataloader_vs_single_batch( + self, reduction: str, tracin_constructor: Callable, unpack_inputs: bool + ) -> None: + # tests that the result of calling the public method `self_influence` for a + # DataLoader of batches is the same as when the batches are collated into a + # single batch + with tempfile.TemporaryDirectory() as tmpdir: + ( + net, + train_dataset, + ) = get_random_model_and_data(tmpdir, unpack_inputs, return_test_data=False) + + # create a single batch representing the entire dataset + single_batch = next( + iter(DataLoader(train_dataset, batch_size=len(train_dataset))) + ) + + # create a dataloader that yields batches from the dataset + dataloader = DataLoader(train_dataset, batch_size=5) + + # create tracin instance + criterion = nn.MSELoss(reduction=reduction) + batch_size = 5 + tracin = tracin_constructor( + net, + train_dataset, + tmpdir, + batch_size, + criterion, + ) + + # compute self influence using `self_influence` when passing in a single + # batch + single_batch_self_influence = tracin.self_influence(single_batch) + + # compute self influence using `self_influence` when passing in a + # dataloader with the same examples + dataloader_self_influence = tracin.self_influence(dataloader) + + # the two self influences should be equal + assertTensorAlmostEqual( + self, + single_batch_self_influence, + dataloader_self_influence, + delta=0.01, # due to numerical issues, we can't set this to 0.0 + mode="max", + ) diff --git a/tests/influence/_core/test_tracin_show_progress.py b/tests/influence/_core/test_tracin_show_progress.py index 5b35352880..17b9065458 100644 --- a/tests/influence/_core/test_tracin_show_progress.py +++ b/tests/influence/_core/test_tracin_show_progress.py @@ -49,115 +49,148 @@ class TestTracInShowProgress(BaseTest): ], name_func=build_test_name_func(args_to_skip=["reduction"]), ) - @unittest.mock.patch("sys.stderr", new_callable=io.StringIO) def test_tracin_show_progress( self, reduction: str, tracin_constructor: Callable, mode: str, - mock_stderr, ) -> None: - with tempfile.TemporaryDirectory() as tmpdir: + with unittest.mock.patch("sys.stderr", new_callable=io.StringIO) as mock_stderr: - batch_size = 5 + with tempfile.TemporaryDirectory() as tmpdir: - ( - net, - train_dataset, - test_samples, - test_labels, - ) = get_random_model_and_data( - tmpdir, unpack_inputs=False, return_test_data=True - ) + batch_size = 5 - self.assertTrue(isinstance(reduction, str)) - criterion = nn.MSELoss(reduction=reduction) + ( + net, + train_dataset, + test_samples, + test_labels, + ) = get_random_model_and_data( + tmpdir, unpack_inputs=False, return_test_data=True + ) - self.assertTrue(callable(tracin_constructor)) - tracin = tracin_constructor( - net, - train_dataset, - tmpdir, - batch_size, - criterion, - ) + self.assertTrue(isinstance(reduction, str)) + criterion = nn.MSELoss(reduction=reduction) - if mode == "self influence": - tracin.influence(show_progress=True) - output = mock_stderr.getvalue() - self.assertTrue( - ( - ( - f"Using {tracin.get_name()} to compute self influence " - "for training batches: 100%" - ) - in output - ), - f"Error progress output: {repr(output)}", + self.assertTrue(callable(tracin_constructor)) + tracin = tracin_constructor( + net, + train_dataset, + tmpdir, + batch_size, + criterion, ) - elif mode == "influence": - tracin.influence( - test_samples, - test_labels, - k=None, - show_progress=True, - ) - output = mock_stderr.getvalue() - self.assertTrue( - ( - ( - f"Using {tracin.get_name()} to compute influence " - "for training batches: 100%" + if mode == "self influence": + + # For self influence, displaying progress involves nested progress + # bars, which are not currently supported by the backup + # `SimpleProgress` that is used if `tqdm` is not installed. + # Therefore, we skip the test in this case. + # TODO: support nested progress bars for `SimpleProgress` + try: + import tqdm # noqa + except ModuleNotFoundError: + raise unittest.SkipTest( + ( + "Skipping self influence progress bar tests for " + f"{tracin.get_name()}, because proper displaying " + "requires the tqdm module, which is not installed." + ) ) - in output - ), - f"Error progress output: {repr(output)}", - ) - elif mode == "k-most": - tracin.influence( - test_samples, - test_labels, - k=2, - proponents=True, - show_progress=True, - ) - output = mock_stderr.getvalue() - self.assertTrue( - ( + tracin.influence(show_progress=True) + output = mock_stderr.getvalue() + # We are showing nested progress bars for the `self_influence` + # method, with the outer progress bar over checkpoints, and + # the inner progress bar over batches. First, we check that + # the outer progress bar reaches 100% once + self.assertEqual( + output.count( + ( + f"Using {tracin.get_name()} to compute self influence. " + "Processing checkpoint: 100%" + ) + ), + 1, + f"Error in progress of batches with output: {repr(output)}", + ) + # Second, we check that the inner progress bar reaches 100% + # once for each checkpoint in `tracin.checkpoints` + self.assertEqual( + output.count( + ( + f"Using {tracin.get_name()} to compute self influence. " + "Processing batch: 100%" + ) + ), + len(tracin.checkpoints), + f"Error in progress of checkpoints with output: {repr(output)}", + ) + elif mode == "influence": + + tracin.influence( + test_samples, + test_labels, + k=None, + show_progress=True, + ) + output = mock_stderr.getvalue() + self.assertTrue( ( - f"Using {tracin.get_name()} to perform computation for " - "getting proponents. Processing training batches: 100%" - ) - in output - ), - f"Error progress output: {repr(output)}", - ) - mock_stderr.seek(0) - mock_stderr.truncate(0) + ( + f"Using {tracin.get_name()} to compute influence " + "for training batches: 100%" + ) + in output + ), + f"Error progress output: {repr(output)}", + ) + elif mode == "k-most": - tracin.influence( - test_samples, - test_labels, - k=2, - proponents=False, - show_progress=True, - ) - output = mock_stderr.getvalue() - self.assertTrue( - ( + tracin.influence( + test_samples, + test_labels, + k=2, + proponents=True, + show_progress=True, + ) + output = mock_stderr.getvalue() + self.assertTrue( ( - f"Using {tracin.get_name()} to perform computation for " - "getting opponents. Processing training batches: 100%" - ) - in output - ), - f"Error progress output: {repr(output)}", - ) - else: - raise Exception("unknown test mode") + ( + f"Using {tracin.get_name()} to perform computation for " + "getting proponents. Processing training batches: 100%" + ) + in output + ), + f"Error progress output: {repr(output)}", + ) + mock_stderr.seek(0) + mock_stderr.truncate(0) - mock_stderr.seek(0) - mock_stderr.truncate(0) + tracin.influence( + test_samples, + test_labels, + k=2, + proponents=False, + show_progress=True, + ) + output = mock_stderr.getvalue() + self.assertTrue( + ( + ( + f"Using {tracin.get_name()} to perform computation for " + "getting opponents. Processing training batches: 100%" + ) + in output + ), + f"Error progress output: {repr(output)}", + ) + else: + raise Exception("unknown test mode") + + mock_stderr.seek(0) + mock_stderr.truncate(0) From d3b1487af8e15db2e1500f5f1e7e47cb07bfe47b Mon Sep 17 00:00:00 2001 From: Fulton Wang Date: Sun, 31 Jul 2022 16:15:09 -0700 Subject: [PATCH 2/2] allow self influence iteration options (#1002) Summary: Pull Request resolved: https://github.com/pytorch/captum/pull/1002 - For self influence computation, there needs to be an iteration over both checkpoints as well as batches. This diff adds a `by_checkpoints` option. If true, the outer iteration is over checkpoints. If false, the outer iteration is over checkpoints. Because self influence computation can be called through the `influence` and `self_influence` methods, this option is added to both methods. Because only `TracInCP` and `TracInCPFast` should be used for self influence computation, only those classes are changed. - The implement this option, the old `self_influence` method, which had the outer iteration over checkpoints, is renamed to be a private `_self_influence_by_checkpoints` method. A new `_self_influence_by_batches` method is added, which has an outer iteration over batches, and re-uses the `_self_influence_by_checkpoints` method to compute self influence scores for a single batch (this method can accept both a single batch, as well as a dataloader yielding batches). Because the logic of this method is the same for all classes, a helper method, `_self_influence_by_batches_helper`, is added to `captum.influence._utils.common`. Finally, the new `self_influence` method simply chooses whether to call `_self_influence_by_checkpoints` or `_self_influence_by_batches`. - Documentation describing the two options for `by_checkpoints` is added to the `self_influence` and `influence` methods. - `test_tracin_show_progress` now differentiates between 2 modes: "self influence by checkpoints" (the original test for progress bar when calculating self influence scores, which checks whether the outer progress bar over checkpoints and inner progress bars over batches both reach 100%), and the newly added mode "self influence by batches", which checks whether the progress bar over batches reaches 100%. - `test_tracin_self_influence` now also checks whether computing self influence scores gives the same result regardless of whether `by_checkpoints` is True or False Reviewed By: NarineK Differential Revision: D37743920 fbshipit-source-id: a4e0c44299b31bf50fe2b5b4cb4d2e62c669208a --- captum/influence/_core/tracincp.py | 102 +++++++++++--- .../_core/tracincp_fast_rand_proj.py | 104 ++++++++++++--- captum/influence/_utils/common.py | 94 +++++++++++++ .../_core/test_tracin_self_influence.py | 21 ++- .../_core/test_tracin_show_progress.py | 126 ++++++++++++------ 5 files changed, 372 insertions(+), 75 deletions(-) diff --git a/captum/influence/_core/tracincp.py b/captum/influence/_core/tracincp.py index 78fa32738f..15811e684b 100644 --- a/captum/influence/_core/tracincp.py +++ b/captum/influence/_core/tracincp.py @@ -30,6 +30,7 @@ _get_k_most_influential_helper, _gradient_dot_product, _load_flexible_state_dict, + _self_influence_by_batches_helper, ) from captum.log import log_usage from torch import Tensor @@ -475,7 +476,8 @@ def _influence_route_to_helpers( if inputs is None: return influence_instance.self_influence( - influence_instance.train_dataloader, show_progress + influence_instance.train_dataloader, + show_progress, ) elif k is None: return influence_instance._influence(_inputs, targets, show_progress) @@ -727,11 +729,9 @@ def influence( # type: ignore[override] requires "training dataset computations": computations for each batch in the training dataset `train_dataset`, which may take a long time. If `show_progress`is true, the progress of - "training dataset computations" will be displayed. In particular, - the number of batches for which computations have been performed - will be displayed. It will try to use tqdm if available for - advanced features (e.g. time estimation). Otherwise, it will - fallback to a simple output of progress. + "training dataset computations" will be displayed. It will try to + use tqdm if available for advanced features (e.g. time estimation). + Otherwise, it will fallback to a simple output of progress. Default: False Returns: @@ -926,7 +926,7 @@ def _get_k_most_influential( ( f"Using {self.get_name()} to perform computation for " f'getting {"proponents" if proponents else "opponents"}. ' - "Processing training batches: 100%" + "Processing training batches" ) ) ) @@ -943,7 +943,7 @@ def _get_k_most_influential( ) ) - def self_influence( + def _self_influence_by_checkpoints( self, inputs_dataset: Union[Tuple[Any, ...], DataLoader], show_progress: bool = False, @@ -957,7 +957,11 @@ def self_influence( will call `model` on that single batch, and if `inputs_dataset` yields batches, this will call `model` on each batch that is yielded. Therefore, please ensure that for both cases, the batch(es) that `model` is called - with are not too large, so that there will not be an out-of-memory error. + with are not too large, so that there will not be an out-of-memory error. This + implementation performs an outer iteration over checkpoints, and an inner + iteration over all batches that `inputs_dataset` represents. The pros of this + implementation are that the checkpoints do not need to be loaded too many + times. Args: batches (Tuple, or DataLoader): Either a single tuple of any, or a @@ -976,13 +980,10 @@ def self_influence( displayed. In more detail, this computation will iterate over all checkpoints (provided as the `checkpoints` initialization argument) in an outer loop, and iterate over all batches that - `inputs_dataset` represents in an inner loop. Therefore, the - total number of (checkpoint, batch) combinations that need to be - iterated over is - (# of checkpoints x # of batches that `inputs_dataset` represents). - If `show_progress` is True, the total progress of both the outer - iteration over checkpoints and the inner iteration over batches is - displayed. It will try to use tqdm if available for advanced + `inputs_dataset` represents in an inner loop. Thus if + `show_progress` is True, the progress of both the outer + iteration and the inner iterations will be displayed. To show + progress, it will try to use tqdm if available for advanced features (e.g. time estimation). Otherwise, it will fallback to a simple output of progress. Default: False @@ -1097,6 +1098,75 @@ def get_checkpoint_contribution(checkpoint): return batches_self_tracin_scores + def self_influence( + self, + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + show_progress: bool = False, + outer_loop_by_checkpoints: bool = False, + ) -> Tensor: + """ + Computes self influence scores for the examples in `inputs_dataset`, which is + either a single batch or a Pytorch `DataLoader` that yields batches. Therefore, + the computed self influence scores are *not* for the examples in training + dataset `train_dataset` (unlike when computing self influence scores using the + `influence` method). Note that if `inputs_dataset` is a single batch, this + will call `model` on that single batch, and if `inputs_dataset` yields + batches, this will call `model` on each batch that is yielded. Therefore, + please ensure that for both cases, the batch(es) that `model` is called + with are not too large, so that there will not be an out-of-memory error. + Internally, this computation requires iterating both over the batches in + `inputs_dataset`, as well as different model checkpoints. There are two ways + this iteration can be done. If `outer_loop_by_checkpoints` is False, the outer + iteration will be over batches, and the inner iteration will be over + checkpoints. This has the pro that displaying the progress of the computation + is more intuitive, involving displaying the number of batches for which self + influence scores have been computed. If `outer_loop_by_checkpoints` is True, + the outer iteration will be over checkpoints, and the inner iteration will be + over batches. This has the pro that the checkpoints do not need to be loaded + for each batch. For large models, loading checkpoints can be time-intensive. + + Args: + batches (Tuple, or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, + and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset`. Please see documentation for the + `train_dataset` argument to `TracInCP.__init__` for + more details on the assumed structure of a batch. + show_progress (bool, optional): Computation of self influence scores can + take a long time if `inputs_dataset` represents many examples. If + `show_progress`is true, the progress of this computation will be + displayed. In more detail, if `outer_loop_by_checkpoints` is False, + this computation will iterate over all batches in an outer loop. + Thus if `show_progress` is True, the number of batches for which + self influence scores have been computed will be displayed. If + `outer_loop_by_checkpoints` is True, this computation will iterate + over all checkpoints (provided as the `checkpoints` initialization + argument) in an outer loop, and iterate over all batches that + `inputs_dataset` represents in an inner loop. Thus if + `show_progress` is True, the progress of both the outer + iteration and the inner iterations will be displayed. To show + progress, it will try to use tqdm if available for advanced + features (e.g. time estimation). Otherwise, it will fallback to a + simple output of progress. + Default: False + outer_loop_by_checkpoints (bool, optional): If performing an outer + iteration over checkpoints; see method description for more + details. + Default: False + """ + if outer_loop_by_checkpoints: + return self._self_influence_by_checkpoints(inputs_dataset, show_progress) + return _self_influence_by_batches_helper( + self._self_influence_by_checkpoints, + self.get_name(), + inputs_dataset, + show_progress, + ) + def _basic_computation_tracincp( self, inputs: Tuple[Any, ...], diff --git a/captum/influence/_core/tracincp_fast_rand_proj.py b/captum/influence/_core/tracincp_fast_rand_proj.py index 71fe3b45a0..f42dbd1527 100644 --- a/captum/influence/_core/tracincp_fast_rand_proj.py +++ b/captum/influence/_core/tracincp_fast_rand_proj.py @@ -17,6 +17,7 @@ _get_k_most_influential_helper, _jacobian_loss_wrt_inputs, _load_flexible_state_dict, + _self_influence_by_batches_helper, _tensor_batch_dot, ) from captum.influence._utils.nearest_neighbors import ( @@ -263,11 +264,9 @@ def influence( # type: ignore[override] requires "training dataset computations": computations for each batch in the training dataset `train_dataset`, which may take a long time. If `show_progress`is true, the progress of - "training dataset computations" will be displayed. In particular, - the number of batches for which computations have been performed - will be displayed. It will try to use tqdm if available for - advanced features (e.g. time estimation). Otherwise, it will - fallback to a simple output of progress. + "training dataset computations" will be displayed. It will try to + use tqdm if available for advanced features (e.g. time estimation). + Otherwise, it will fallback to a simple output of progress. Default: False Returns: @@ -466,7 +465,7 @@ def _get_k_most_influential( # type: ignore[override] ( f"Using {self.get_name()} to perform computation for " f'getting {"proponents" if proponents else "opponents"}. ' - "Processing training batches: 100%" + "Processing training batches" ) ) ) @@ -483,7 +482,7 @@ def _get_k_most_influential( # type: ignore[override] ) ) - def self_influence( + def _self_influence_by_checkpoints( self, inputs_dataset: Union[Tuple[Any, ...], DataLoader], show_progress: bool = False, @@ -497,7 +496,11 @@ def self_influence( will call `model` on that single batch, and if `inputs_dataset` yields batches, this will call `model` on each batch that is yielded. Therefore, please ensure that for both cases, the batch(es) that `model` is called - with are not too large, so that there will not be an out-of-memory error. + with are not too large, so that there will not be an out-of-memory error. This + implementation performs an outer iteration over checkpoints, and an inner + iteration over all batches that `inputs_dataset` represents. The pros of this + implementation are that the checkpoints do not need to be loaded too many + times. Args: batches (Tuple, or DataLoader): Either a single tuple of any, or a @@ -516,13 +519,10 @@ def self_influence( displayed. In more detail, this computation will iterate over all checkpoints (provided as the `checkpoints` initialization argument) in an outer loop, and iterate over all batches that - `inputs_dataset` represents in an inner loop. Therefore, the - total number of (checkpoint, batch) combinations that need to be - iterated over is - (# of checkpoints x # of batches that `inputs_dataset` represents). - If `show_progress` is True, the total progress of both the outer - iteration over checkpoints and the inner iteration over batches is - displayed. It will try to use tqdm if available for advanced + `inputs_dataset` represents in an inner loop. Thus if + `show_progress` is True, the progress of both the outer + iteration and the inner iterations will be displayed. To show + progress, it will try to use tqdm if available for advanced features (e.g. time estimation). Otherwise, it will fallback to a simple output of progress. Default: False @@ -619,6 +619,75 @@ def get_checkpoint_contribution(checkpoint): return batches_self_tracin_scores + def self_influence( + self, + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + show_progress: bool = False, + outer_loop_by_checkpoints: bool = False, + ) -> Tensor: + """ + Computes self influence scores for the examples in `inputs_dataset`, which is + either a single batch or a Pytorch `DataLoader` that yields batches. Therefore, + the computed self influence scores are *not* for the examples in training + dataset `train_dataset` (unlike when computing self influence scores using the + `influence` method). Note that if `inputs_dataset` is a single batch, this + will call `model` on that single batch, and if `inputs_dataset` yields + batches, this will call `model` on each batch that is yielded. Therefore, + please ensure that for both cases, the batch(es) that `model` is called + with are not too large, so that there will not be an out-of-memory error. + Internally, this computation requires iterating both over the batches in + `inputs_dataset`, as well as different model checkpoints. There are two ways + this iteration can be done. If `outer_loop_by_checkpoints` is False, the outer + iteration will be over batches, and the inner iteration will be over + checkpoints. This has the pro that displaying the progress of the computation + is more intuitive, involving displaying the number of batches for which self + influence scores have been computed. If `outer_loop_by_checkpoints` is True, + the outer iteration will be over checkpoints, and the inner iteration will be + over batches. This has the pro that the checkpoints do not need to be loaded + for each batch. For large models, loading checkpoints can be time-intensive. + + Args: + batches (Tuple, or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, + and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset`. Please see documentation for the + `train_dataset` argument to `TracInCP.__init__` for + more details on the assumed structure of a batch. + show_progress (bool, optional): Computation of self influence scores can + take a long time if `inputs_dataset` represents many examples. If + `show_progress`is true, the progress of this computation will be + displayed. In more detail, if `outer_loop_by_checkpoints` is False, + this computation will iterate over all batches in an outer loop. + Thus if `show_progress` is True, the number of batches for which + self influence scores have been computed will be displayed. If + `outer_loop_by_checkpoints` is True, this computation will iterate + over all checkpoints (provided as the `checkpoints` initialization + argument) in an outer loop, and iterate over all batches that + `inputs_dataset` represents in an inner loop. Thus if + `show_progress` is True, the progress of both the outer + iteration and the inner iterations will be displayed. To show + progress, it will try to use tqdm if available for advanced + features (e.g. time estimation). Otherwise, it will fallback to a + simple output of progress. + Default: False + outer_loop_by_checkpoints (bool, optional): If performing an outer + iteration over checkpoints; see method description for more + details. + Default: False + """ + if outer_loop_by_checkpoints: + return self._self_influence_by_checkpoints(inputs_dataset, show_progress) + return _self_influence_by_batches_helper( + self._self_influence_by_checkpoints, + self.get_name(), + inputs_dataset, + show_progress, + ) + def _basic_computation_tracincp_fast( influence_instance: TracInCPFast, @@ -946,6 +1015,7 @@ def self_influence( self, inputs_dataset: Union[Tuple[Any, ...], DataLoader], show_progress: bool = False, + outer_loop_by_checkpoints: bool = False, ) -> Tensor: """ NOT IMPLEMENTED - no need to implement `TracInCPFastRandProj.self_influence`, @@ -985,6 +1055,10 @@ def self_influence( if available for advanced features (e.g. time estimation). Otherwise, it will fallback to a simple output of progress. Default: False + outer_loop_by_checkpoints (bool, optional): If performing an outer + iteration over checkpoints; see method description for more + details. + Default: False Returns: self_influence_scores (Tensor): This is a 1D tensor containing the self diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py index d6f1c99f20..131f8964b8 100644 --- a/captum/influence/_utils/common.py +++ b/captum/influence/_utils/common.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import warnings from typing import Any, Callable, List, Optional, Tuple, Union import torch @@ -325,3 +326,96 @@ def _format_inputs_dataset(inputs_dataset: Union[Tuple[Any, ...], DataLoader]): _DatasetFromList([inputs_dataset]), shuffle=False, batch_size=None ) return inputs_dataset + + +def _self_influence_by_batches_helper( + self_influence_batch_fn: Callable, + instance_name: str, + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + show_progress: bool = False, +) -> Tensor: + """ + Computes self influence scores for the examples in `inputs_dataset`, which is + either a single batch or a Pytorch `DataLoader` that yields batches. The self + influence scores for a single batch are computed using the + `self_influence_batch_fn` input. Note that if `inputs_dataset` is a single batch, + this will call `model` on that single batch, where `model` is the model used to + compute self influence scores by `self_influence_batch_fn`, and if `inputs_dataset` + yields batches, this will call `model` on each batch that is yielded. Therefore, + please ensure that for both cases, the batch(es) that `model` is called + with are not too large, so that there will not be an out-of-memory error. This + implementation performs an outer iteration over all batches that + `inputs_dataset` represents, and an inner iteration over checkpoints. The pros + of this implementation are that showing the progress of the computation is + straightforward. + + Args: + self_influence_batch_fn (Callable): This is the function that computes self + influence scores for a single batch. + instance_name (str): This is the name of the implementation class that + `self_influence_batch_fn` is a method of. This is used for displaying + warning messages. + batches (Tuple, or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, + and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset`. Please see documentation for the + `train_dataset` argument to `TracInCP.__init__` for + more details on the assumed structure of a batch. + show_progress (bool, optional): Computation of self influence scores can + take a long time if `inputs_dataset` represents many examples. If + `show_progress`is true, the progress of this computation will be + displayed. In particular, the number of batches for which self + influence scores have been computed will be displayed. It will try + to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + + Returns: + self_influence_scores (Tensor): This is a 1D tensor containing the self + influence scores of all examples in `inputs_dataset`, regardless of + whether it represents a single batch or a `DataLoader` that yields + batches. + """ + # If `inputs_dataset` is not a `DataLoader`, turn it into one. + inputs_dataset = _format_inputs_dataset(inputs_dataset) + + # If `show_progress` is true, create a progress bar that keeps track of how + # many batches have been processed + if show_progress: + # First, try to determine length of progress bar if possible, with a + # default of `None` + inputs_dataset_len = None + try: + inputs_dataset_len = len(inputs_dataset) + except TypeError: + warnings.warn( + "Unable to determine the number of batches in `inputs_dataset`. " + "Therefore, if showing the progress of the computation of self " + "influence scores, only the number of batches processed can be " + "displayed, and not the percentage completion of the computation, " + "nor any time estimates." + ) + # then create the progress bar + inputs_dataset = progress( + inputs_dataset, + desc=f"Using {instance_name} to compute self influence. Processing batch", + total=inputs_dataset_len, + ) + + # To compute self influence scores for each batch, we use + # `_self_influence_by_checkpoints`, which can accept a tuple representing a + # single batch as the `inputs_dataset` argument (as well as a DataLoader). + # Because we are already displaying progress in terms of number of batches + # processed in this method, we will not show progress for the call to + # `_self_influence_by_checkpoints`. + return torch.cat( + [ + self_influence_batch_fn(batch, show_progress=False) + for batch in inputs_dataset + ] + ) diff --git a/tests/influence/_core/test_tracin_self_influence.py b/tests/influence/_core/test_tracin_self_influence.py index 9448982a58..0f327ce3fb 100644 --- a/tests/influence/_core/test_tracin_self_influence.py +++ b/tests/influence/_core/test_tracin_self_influence.py @@ -57,6 +57,7 @@ def test_tracin_self_influence( criterion, ) + # calculate influence scores, using the training data as the test batch train_scores = tracin.influence( train_dataset.samples, train_dataset.labels, @@ -65,8 +66,12 @@ def test_tracin_self_influence( ) # calculate self_tracin_scores - self_tracin_scores = tracin.influence() + self_tracin_scores = tracin.self_influence( + DataLoader(train_dataset, batch_size=batch_size), + outer_loop_by_checkpoints=False, + ) + # check that self_tracin scores equals the diagonal of influence scores assertTensorAlmostEqual( self, torch.diagonal(train_scores), @@ -75,6 +80,20 @@ def test_tracin_self_influence( mode="max", ) + # check that setting `outer_loop_by_checkpoints=False` and + # `outer_loop_by_checkpoints=True` gives the same self influence scores + self_tracin_scores_by_checkpoints = tracin.self_influence( + DataLoader(train_dataset, batch_size=batch_size), + outer_loop_by_checkpoints=True, + ) + assertTensorAlmostEqual( + self, + self_tracin_scores_by_checkpoints, + self_tracin_scores, + delta=0.01, + mode="max", + ) + @parameterized.expand( [ (reduction, constructor, unpack_inputs) diff --git a/tests/influence/_core/test_tracin_show_progress.py b/tests/influence/_core/test_tracin_show_progress.py index 17b9065458..e940e2ed66 100644 --- a/tests/influence/_core/test_tracin_show_progress.py +++ b/tests/influence/_core/test_tracin_show_progress.py @@ -14,6 +14,7 @@ DataInfluenceConstructor, get_random_model_and_data, ) +from torch.utils.data import DataLoader class TestTracInShowProgress(BaseTest): @@ -28,6 +29,18 @@ class TestTracInShowProgress(BaseTest): in `TracInCPFastRandProj.__init__`). """ + def _check_error_msg_multiplicity(self, mock_stderr, msg, msg_multiplicity): + """ + checks that in `mock_stderr`, the error msg `msg` occurs `msg_multiplicity` + times + """ + output = mock_stderr.getvalue() + self.assertEqual( + output.count(msg), + msg_multiplicity, + f"Error in progress of batches with output: {repr(output)}", + ) + @parameterized.expand( [ ( @@ -45,7 +58,12 @@ class TestTracInShowProgress(BaseTest): DataInfluenceConstructor(TracInCPFast), ), ] - for mode in ["self influence", "influence", "k-most"] + for mode in [ + "self influence by checkpoints", + "self influence by batches", + "influence", + "k-most", + ] ], name_func=build_test_name_func(args_to_skip=["reduction"]), ) @@ -83,9 +101,13 @@ def test_tracin_show_progress( criterion, ) - if mode == "self influence": + if mode == "self influence by checkpoints": + # this tests progress for computing self influence scores, when + # `outer_loop_by_checkpoints` is True. In this case, we should see a + # single outer progress bar over checkpoints, and for every + # checkpoints, a separate progress bar over batches - # For self influence, displaying progress involves nested progress + # In this case, displaying progress involves nested progress # bars, which are not currently supported by the backup # `SimpleProgress` that is used if `tqdm` is not installed. # Therefore, we skip the test in this case. @@ -101,33 +123,50 @@ def test_tracin_show_progress( ) ) - tracin.influence(show_progress=True) - output = mock_stderr.getvalue() + tracin.self_influence( + DataLoader(train_dataset, batch_size=batch_size), + show_progress=True, + outer_loop_by_checkpoints=True, + ) + # We are showing nested progress bars for the `self_influence` # method, with the outer progress bar over checkpoints, and # the inner progress bar over batches. First, we check that # the outer progress bar reaches 100% once - self.assertEqual( - output.count( - ( - f"Using {tracin.get_name()} to compute self influence. " - "Processing checkpoint: 100%" - ) + self._check_error_msg_multiplicity( + mock_stderr, + ( + f"Using {tracin.get_name()} to compute self influence. " + "Processing checkpoint: 100%" ), 1, - f"Error in progress of batches with output: {repr(output)}", ) # Second, we check that the inner progress bar reaches 100% # once for each checkpoint in `tracin.checkpoints` - self.assertEqual( - output.count( - ( - f"Using {tracin.get_name()} to compute self influence. " - "Processing batch: 100%" - ) + self._check_error_msg_multiplicity( + mock_stderr, + ( + f"Using {tracin.get_name()} to compute self influence. " + "Processing batch: 100%" ), len(tracin.checkpoints), - f"Error in progress of checkpoints with output: {repr(output)}", + ) + elif mode == "self influence by batches": + # This tests progress for computing self influence scores, when + # `outer_loop_by_checkpoints` is False. In this case, we should see + # a single outer progress bar over batches. + tracin.self_influence( + DataLoader(train_dataset, batch_size=batch_size), + show_progress=True, + outer_loop_by_checkpoints=False, + ) + self._check_error_msg_multiplicity( + mock_stderr, + ( + f"Using {tracin.get_name()} to compute self influence. " + "Processing batch: 100%" + ), + 1, ) elif mode == "influence": @@ -137,16 +176,15 @@ def test_tracin_show_progress( k=None, show_progress=True, ) - output = mock_stderr.getvalue() - self.assertTrue( + # Since the computation iterates once over training batches, we + # check that the progress bar over batches reaches 100% once + self._check_error_msg_multiplicity( + mock_stderr, ( - ( - f"Using {tracin.get_name()} to compute influence " - "for training batches: 100%" - ) - in output + f"Using {tracin.get_name()} to compute influence " + "for training batches: 100%" ), - f"Error progress output: {repr(output)}", + 1, ) elif mode == "k-most": @@ -157,16 +195,17 @@ def test_tracin_show_progress( proponents=True, show_progress=True, ) - output = mock_stderr.getvalue() - self.assertTrue( + + # Since the computation iterates once over training batches, we + # check that the progress bar over batches reaches 100% once, and + # that the message is specific for finding proponents. + self._check_error_msg_multiplicity( + mock_stderr, ( - ( - f"Using {tracin.get_name()} to perform computation for " - "getting proponents. Processing training batches: 100%" - ) - in output + f"Using {tracin.get_name()} to perform computation for " + "getting proponents. Processing training batches: 100%" ), - f"Error progress output: {repr(output)}", + 1, ) mock_stderr.seek(0) mock_stderr.truncate(0) @@ -178,16 +217,17 @@ def test_tracin_show_progress( proponents=False, show_progress=True, ) - output = mock_stderr.getvalue() - self.assertTrue( + + # Since the computation iterates once over training batches, we + # check that the progress bar over batches reaches 100% once, and + # that the message is specific for finding opponents. + self._check_error_msg_multiplicity( + mock_stderr, ( - ( - f"Using {tracin.get_name()} to perform computation for " - "getting opponents. Processing training batches: 100%" - ) - in output + f"Using {tracin.get_name()} to perform computation for " + "getting opponents. Processing training batches: 100%" ), - f"Error progress output: {repr(output)}", + 1, ) else: raise Exception("unknown test mode")