
Commit 8f22fa1

craymichael authored and facebook-github-bot committed
Support Cache Class for Even Newer Versions of Transformers Library (#1343)
Summary:
Pull Request resolved: #1343

Supports multiple and newer versions of the transformers library. Also adds the `packaging` dependency to check package versions more robustly.

Differential Revision: D62468332
1 parent 8282b01 commit 8f22fa1

File tree

3 files changed: +123 −11 lines

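As background for the new `packaging` dependency: plain string comparison mis-orders multi-digit version components, which is exactly what the version gates in this commit must avoid. A minimal illustration (not part of the commit):

from packaging.version import Version

# Lexicographic string comparison gets multi-digit parts wrong:
assert not ("4.43.0" > "4.9.0")  # False: "4" < "9" as characters

# packaging parses components numerically and understands pre-releases:
assert Version("4.43.0") > Version("4.9.0")
assert Version("4.43.0.dev0") < Version("4.43.0")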

captum/_utils/transformers_typing.py

Lines changed: 94 additions & 7 deletions
@@ -2,10 +2,13 @@
 
 # pyre-strict
 
-from typing import Optional, Protocol, Tuple, Type
+from typing import Any, Dict, Optional, Protocol, Tuple, Type
 
 import torch
 
+from packaging.version import Version
+from torch import nn
+
 
 class CacheLike(Protocol):
     """Protocol for cache-like objects."""
@@ -21,12 +24,96 @@ def from_legacy_cache(
     ) -> "DynamicCacheLike": ...
 
 
+transformers_installed: bool
+Cache: Optional[Type[CacheLike]]
+DynamicCache: Optional[Type[DynamicCacheLike]]
+
 try:
-    # pyre-ignore[21]: Could not find a module corresponding to import
-    # `transformers.cache_utils`
-    from transformers.cache_utils import Cache as _Cache, DynamicCache as _DynamicCache
+    # pyre-ignore[21]: Could not find a module corresponding to import `transformers`.
+    import transformers  # noqa: F401
+
+    transformers_installed = True
 except ImportError:
-    _Cache = _DynamicCache = None
+    transformers_installed = False
+
+if transformers_installed:
+    try:
+        # pyre-ignore[21]: Could not find a module corresponding to import
+        # `transformers.cache_utils`.
+        from transformers.cache_utils import (  # noqa: F401
+            Cache as _Cache,
+            DynamicCache as _DynamicCache,
+        )
+
+        Cache = _Cache
+        # pyre-ignore[9]: Incompatible variable type: DynamicCache is declared to have
+        # type `Optional[Type[DynamicCacheLike]]` but is used as type
+        # `Type[_DynamicCache]`
+        DynamicCache = _DynamicCache
+    except ImportError:
+        Cache = DynamicCache = None
+else:
+    Cache = DynamicCache = None
+
+# GenerationMixin._update_model_kwargs_for_generation
+# "cache_position" at v4.39.0 (only needed for models that support cache class)
+# "use_cache" at v4.41.0 (optional, default is True)
+# "cache_position" is mandatory at v4.43.0 ("use_cache" is still optional, default True)
+_transformers_version: Optional[Version]
+if transformers_installed:
+    _transformers_version = Version(transformers.__version__)
+else:
+    _transformers_version = None
+
+_mandated_cache_version = Version("4.43.0")
+_use_cache_version = Version("4.41.0")
+_cache_position_version = Version("4.39.0")
+
+
+def update_model_kwargs(
+    model_kwargs: Dict[str, Any],
+    model: nn.Module,
+    input_ids: torch.Tensor,
+    caching: bool,
+) -> None:
+    if not supports_caching(model):
+        return
+    assert _transformers_version is not None
+    if caching:
+        # Enable caching
+        if _transformers_version >= _cache_position_version:
+            cache_position = torch.arange(
+                input_ids.shape[1], dtype=torch.int64, device=input_ids.device
+            )
+            model_kwargs["cache_position"] = cache_position
+        # pyre-ignore[58]: Unsupported operand `>=` is not supported for operand types
+        # `Optional[Version]` and `Version`.
+        if _transformers_version >= _use_cache_version:
+            model_kwargs["use_cache"] = True
+    else:
+        # Disable caching
+        if _transformers_version >= _use_cache_version:
+            model_kwargs["use_cache"] = False
+
 
-Cache: Optional[Type[CacheLike]] = _Cache
-DynamicCache: Optional[Type[DynamicCacheLike]] = _DynamicCache
+def supports_caching(model: nn.Module) -> bool:
+    if not transformers_installed:
+        # Not a transformers model
+        return False
+    # Cache may be optional or unsupported depending on model/version
+    try:
+        # pyre-ignore[21]: Could not find a module corresponding to import
+        # `transformers.generation.utils`.
+        from transformers.generation.utils import GenerationMixin
+    except ImportError:
+        return False
+    if not isinstance(model, GenerationMixin):
+        # Model isn't a GenerationMixin, we don't support additional caching logic
+        # for it
+        return False
+    assert _transformers_version is not None
+    if _transformers_version >= _mandated_cache_version:
+        # Cache is mandatory
+        return True
+    # Fallback on _supports_cache_class attribute
+    return getattr(model, "_supports_cache_class", False)
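
For orientation, here is a minimal sketch (not part of the commit) of how these helpers compose; `model` and `tokenizer` are assumed to be a Hugging Face causal LM and its tokenizer loaded elsewhere:

import torch

from captum._utils.transformers_typing import supports_caching, update_model_kwargs

input_ids = tokenizer("Hello world", return_tensors="pt")["input_ids"]
model_kwargs = {"attention_mask": torch.ones_like(input_ids)}

# update_model_kwargs returns early for models that don't support caching,
# so it is safe to call unconditionally.
update_model_kwargs(
    model_kwargs=model_kwargs,
    model=model,
    input_ids=input_ids,
    caching=True,
)

print(supports_caching(model), model_kwargs.keys())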

captum/attr/_core/llm_attr.py

Lines changed: 19 additions & 2 deletions
@@ -7,7 +7,6 @@
 import numpy as np
 
 import torch
-from captum._utils.transformers_typing import Cache, DynamicCache
 from captum._utils.typing import TokenizerLike
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
@@ -259,6 +258,15 @@ def _forward_func(
         use_cached_outputs: bool = False,
         _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
     ) -> Tensor:
+        # Lazily import transformers_typing to avoid importing transformers package if
+        # it isn't needed
+        from captum._utils.transformers_typing import (
+            Cache,
+            DynamicCache,
+            supports_caching,
+            update_model_kwargs,
+        )
+
         perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
         init_model_inp = perturbed_input
 
@@ -267,16 +275,25 @@ def _forward_func(
             [1, model_inp.shape[1]], dtype=torch.long, device=model_inp.device
         )
         model_kwargs = {"attention_mask": attention_mask}
+        # If applicable, update model kwargs for transformers models
+        update_model_kwargs(
+            model_kwargs=model_kwargs,
+            model=self.model,
+            input_ids=model_inp,
+            caching=use_cached_outputs,
+        )
 
         log_prob_list = []
         outputs = None
         for target_token in target_tokens:
             if use_cached_outputs:
                 if outputs is not None:
+                    # If applicable, convert past_key_values to DynamicCache for
+                    # transformers models
                     if (
                         Cache is not None
                         and DynamicCache is not None
-                        and getattr(self.model, "_supports_cache_class", False)
+                        and supports_caching(self.model)
                         and not isinstance(outputs.past_key_values, Cache)
                     ):
                         outputs.past_key_values = DynamicCache.from_legacy_cache(
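
The conversion guarded by that final condition follows a standard transformers pattern: legacy tuple-of-tuples `past_key_values` is wrapped into a `DynamicCache`. A rough standalone sketch (the helper name `as_cache` is hypothetical; assumes a transformers version that provides `Cache`/`DynamicCache`):

from transformers.cache_utils import Cache, DynamicCache

def as_cache(past_key_values):
    # Legacy past_key_values is a tuple of per-layer (key, value) tensor
    # pairs; newer generation code expects a Cache instance, so wrap it.
    if past_key_values is not None and not isinstance(past_key_values, Cache):
        return DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values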

setup.py

Lines changed: 10 additions & 2 deletions
@@ -82,7 +82,9 @@ def report(*args):
 
 # get version string from module
 with open(os.path.join(os.path.dirname(__file__), "captum/__init__.py"), "r") as f:
-    version = re.search(r"__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M).group(1)
+    version_match = re.search(r"__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M)
+    assert version_match is not None, "Unable to find version string."
+    version = version_match.group(1)
 report("-- Building version " + version)
 
 # read in README.md as the long description
@@ -147,7 +149,13 @@ def get_package_files(root, subdirs):
     long_description=long_description,
     long_description_content_type="text/markdown",
     python_requires=">=3.8",
-    install_requires=["matplotlib", "numpy<2.0", "torch>=1.10", "tqdm"],
+    install_requires=[
+        "matplotlib",
+        "numpy<2.0",
+        "packaging",
+        "torch>=1.10",
+        "tqdm",
+    ],
     packages=find_packages(exclude=("tests", "tests.*")),
     extras_require={
         "dev": DEV_REQUIRES,
