2
2
3
3
# pyre-strict
4
4
5
- from typing import Optional , Protocol , Tuple , Type
5
+ from typing import Any , Dict , Optional , Protocol , Tuple , Type
6
6
7
7
import torch
8
8
9
+ from packaging .version import Version
10
+ from torch import nn
11
+
9
12
10
13
class CacheLike (Protocol ):
11
14
"""Protocol for cache-like objects."""
@@ -21,12 +24,91 @@ def from_legacy_cache(
21
24
) -> "DynamicCacheLike" : ...
22
25
23
26
27
# Module-level probe for the optional `transformers` dependency.
#
# After this section runs:
#   * transformers_installed  -- whether `import transformers` succeeded
#   * Cache / DynamicCache    -- the cache classes from transformers.cache_utils,
#                                or None when unavailable
#   * _transformers_version   -- parsed package version, or None when not installed
transformers_installed: bool
Cache: Optional[Type[CacheLike]]
DynamicCache: Optional[Type[DynamicCacheLike]]
_transformers_version: Optional[Version]

try:
    import transformers  # noqa: F401
except ImportError:
    transformers_installed = False
else:
    transformers_installed = True

# Pessimistic defaults; overridden below when transformers is present.
Cache = DynamicCache = None
_transformers_version = None

if transformers_installed:
    _transformers_version = Version(transformers.__version__)
    # cache_utils may be missing on older transformers releases, so this
    # import gets its own guard.
    try:
        from transformers.cache_utils import (  # noqa: F401
            Cache as _Cache,
            DynamicCache as _DynamicCache,
        )
    except ImportError:
        pass
    else:
        Cache = _Cache
        # pyre-ignore[9]: Incompatible variable type: DynamicCache is declared
        # to have type `Optional[Type[DynamicCacheLike]]` but is used as type
        # `Type[_DynamicCache]`
        DynamicCache = _DynamicCache

# GenerationMixin._update_model_kwargs_for_generation signature history:
# "cache_position" at v4.39.0 (only needed for models that support cache class)
# "use_cache" at v4.41.0 (optional, default is True)
# "cache_position" is mandatory at v4.43.0 ("use_cache" is still optional, default True)
_mandated_cache_version = Version("4.43.0")
_use_cache_version = Version("4.41.0")
_cache_position_version = Version("4.39.0")
68
+
69
+
70
def update_model_kwargs(
    model_kwargs: Dict[str, Any],
    model: nn.Module,
    input_ids: torch.Tensor,
    caching: bool,
) -> None:
    """Mutate `model_kwargs` in place to enable or disable KV caching.

    No-op for models that do not support transformers cache handling
    (see `supports_caching`). Which keys are written depends on the
    installed transformers version.
    """
    if not supports_caching(model):
        return
    # supports_caching returning True implies transformers is installed.
    assert _transformers_version is not None

    if not caching:
        # Disable caching; "use_cache" is recognized from v4.41.0 onward.
        # pyre-ignore[58]: Unsupported operand `>=` is not supported for operand
        # types `Optional[Version]` and `Version`.
        if _transformers_version >= _use_cache_version:
            model_kwargs["use_cache"] = False
        return

    # Enable caching. "cache_position" is understood from v4.39.0 onward;
    # positions cover the full prompt length.
    if _transformers_version >= _cache_position_version:
        model_kwargs["cache_position"] = torch.arange(
            input_ids.shape[1], dtype=torch.int64, device=input_ids.device
        )
    if _transformers_version >= _use_cache_version:
        model_kwargs["use_cache"] = True
94
+
30
95
31
- Cache : Optional [Type [CacheLike ]] = _Cache
32
- DynamicCache : Optional [Type [DynamicCacheLike ]] = _DynamicCache
96
def supports_caching(model: nn.Module) -> bool:
    """Return True if `model` supports transformers cache-class handling.

    A model qualifies when transformers is installed, the model is a
    `GenerationMixin`, and either the installed version mandates the cache
    class (>= 4.43.0) or the model opts in via `_supports_cache_class`.
    """
    if not transformers_installed:
        # Not a transformers model
        return False
    # Cache may be optional or unsupported depending on model/version
    try:
        from transformers.generation.utils import GenerationMixin
    except ImportError:
        return False
    if not isinstance(model, GenerationMixin):
        # Model isn't a GenerationMixin, we don't support additional caching logic
        # for it
        return False
    assert _transformers_version is not None
    if _transformers_version >= _mandated_cache_version:
        # Cache is mandatory
        return True
    # Fallback on _supports_cache_class attribute. Coerce to bool: the
    # attribute may hold any truthy/falsy value, and this function's
    # signature promises a bool.
    return bool(getattr(model, "_supports_cache_class", False))
0 commit comments