2
2
3
3
# pyre-strict
4
4
5
- from typing import Optional , Protocol , Tuple , Type
5
+ from typing import Any , Dict , Optional , Protocol , Tuple , Type
6
6
7
7
import torch
8
8
9
+ from packaging .version import Version
10
+ from torch import nn
11
+
9
12
10
13
class CacheLike (Protocol ):
11
14
"""Protocol for cache-like objects."""
@@ -21,12 +24,96 @@ def from_legacy_cache(
21
24
) -> "DynamicCacheLike" : ...
22
25
23
26
27
+ transformers_installed : bool
28
+ Cache : Optional [Type [CacheLike ]]
29
+ DynamicCache : Optional [Type [DynamicCacheLike ]]
30
+
24
31
try :
25
- # pyre-ignore[21]: Could not find a module corresponding to import
26
- # `transformers.cache_utils`
27
- from transformers .cache_utils import Cache as _Cache , DynamicCache as _DynamicCache
32
+ # pyre-ignore[21]: Could not find a module corresponding to import `transformers`.
33
+ import transformers # noqa: F401
34
+
35
+ transformers_installed = True
28
36
except ImportError :
29
- _Cache = _DynamicCache = None
37
+ transformers_installed = False
38
+
39
+ if transformers_installed :
40
+ try :
41
+ # pyre-ignore[21]: Could not find a module corresponding to import
42
+ # `transformers.cache_utils`.
43
+ from transformers .cache_utils import ( # noqa: F401
44
+ Cache as _Cache ,
45
+ DynamicCache as _DynamicCache ,
46
+ )
47
+
48
+ Cache = _Cache
49
+ # pyre-ignore[9]: Incompatible variable type: DynamicCache is declared to have
50
+ # type `Optional[Type[DynamicCacheLike]]` but is used as type
51
+ # `Type[_DynamicCache]`
52
+ DynamicCache = _DynamicCache
53
+ except ImportError :
54
+ Cache = DynamicCache = None
55
+ else :
56
+ Cache = DynamicCache = None
57
+
58
+ # GenerationMixin._update_model_kwargs_for_generation
59
+ # "cache_position" at v4.39.0 (only needed for models that support cache class)
60
+ # "use_cache" at v4.41.0 (optional, default is True)
61
+ # "cache_position" is mandatory at v4.43.0 ("use_cache" is still optional, default True)
62
+ _transformers_version : Optional [Version ]
63
+ if transformers_installed :
64
+ _transformers_version = Version (transformers .__version__ )
65
+ else :
66
+ _transformers_version = None
67
+
68
+ _mandated_cache_version = Version ("4.43.0" )
69
+ _use_cache_version = Version ("4.41.0" )
70
+ _cache_position_version = Version ("4.39.0" )
71
+
72
+
73
+ def update_model_kwargs (
74
+ model_kwargs : Dict [str , Any ],
75
+ model : nn .Module ,
76
+ input_ids : torch .Tensor ,
77
+ caching : bool ,
78
+ ) -> None :
79
+ if not supports_caching (model ):
80
+ return
81
+ assert _transformers_version is not None
82
+ if caching :
83
+ # Enable caching
84
+ if _transformers_version >= _cache_position_version :
85
+ cache_position = torch .arange (
86
+ input_ids .shape [1 ], dtype = torch .int64 , device = input_ids .device
87
+ )
88
+ model_kwargs ["cache_position" ] = cache_position
89
+ # pyre-ignore[58]: Unsupported operand `>=` is not supported for operand types
90
+ # `Optional[Version]` and `Version`.
91
+ if _transformers_version >= _use_cache_version :
92
+ model_kwargs ["use_cache" ] = True
93
+ else :
94
+ # Disable caching
95
+ if _transformers_version >= _use_cache_version :
96
+ model_kwargs ["use_cache" ] = False
97
+
30
98
31
- Cache : Optional [Type [CacheLike ]] = _Cache
32
- DynamicCache : Optional [Type [DynamicCacheLike ]] = _DynamicCache
99
+ def supports_caching (model : nn .Module ) -> bool :
100
+ if not transformers_installed :
101
+ # Not a transformers model
102
+ return False
103
+ # Cache may be optional or unsupported depending on model/version
104
+ try :
105
+ # pyre-ignore[21]: Could not find a module corresponding to import
106
+ # `transformers.generation.utils`.
107
+ from transformers .generation .utils import GenerationMixin
108
+ except ImportError :
109
+ return False
110
+ if not isinstance (model , GenerationMixin ):
111
+ # Model isn't a GenerationMixin, we don't support additional caching logic
112
+ # for it
113
+ return False
114
+ assert _transformers_version is not None
115
+ if _transformers_version >= _mandated_cache_version :
116
+ # Cache is mandatory
117
+ return True
118
+ # Fallback on _supports_cache_class attribute
119
+ return getattr (model , "_supports_cache_class" , False )
0 commit comments