
Commit 1a40d4a

zucchini-nlp authored and gante committed
Rename supports_static_cache to can_compile_fullgraph (huggingface#39505)
* update all
* Apply suggestions from code review
* apply suggestions
* fix copies

Co-authored-by: Joao Gante <[email protected]>
1 parent 00d2ddf commit 1a40d4a
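The rename makes the attribute say what it actually gates: whether the model's forward pass can be captured end to end by `torch.compile(fullgraph=True)`, rather than merely whether it accepts a static cache. A minimal sketch of how the renamed flag is consumed after this commit (the checkpoint name and prompt are illustrative, not from this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any decoder-only model with the flag set works the same way.
model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# The renamed class attribute: True when the whole graph can be captured.
if model._can_compile_fullgraph:
    inputs = tokenizer("Hello", return_tensors="pt")
    # A static cache has fixed-shape tensors, which is what fullgraph capture needs.
    out = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")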

128 files changed (+141, −146 lines)

Note: this is a large commit; only a subset of the 128 changed files is shown below.

examples/modular-transformers/modeling_my_new_model2.py

Lines changed: 1 addition & 1 deletion
@@ -294,7 +294,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
+    _can_compile_fullgraph = True
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": MyNewModel2DecoderLayer,

examples/modular-transformers/modeling_new_task_model.py

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
+    _can_compile_fullgraph = True
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True

examples/modular-transformers/modeling_super.py

Lines changed: 1 addition & 1 deletion
@@ -293,7 +293,7 @@ class SuperPreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
+    _can_compile_fullgraph = True
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": SuperDecoderLayer,

src/transformers/generation/utils.py

Lines changed: 3 additions & 2 deletions
@@ -2059,7 +2059,7 @@ def _prepare_cache_for_generation(
             )
         if generation_config.cache_implementation is not None:
             if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
-                if generation_config.cache_implementation == "static" and not self._supports_static_cache:
+                if generation_config.cache_implementation == "static" and not self._can_compile_fullgraph:
                     raise ValueError(
                         "This model does not support `cache_implementation='static'`. Please check the following "
                         "issue: https://github.com/huggingface/transformers/issues/28981"

@@ -2215,7 +2215,8 @@ def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: Ge
         using_compilable_cache = (
             isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable
         )
-        can_compile = valid_hardware and using_compilable_cache and self._supports_static_cache
+        # TODO @raushan `self._can_compile_fullgraph` can be removed and inferred from model arch (e.g. MoE doesn't support compile)
+        can_compile = valid_hardware and using_compilable_cache and self._can_compile_fullgraph

         # Exception 1: Some quantization methods do not support compilation
         if getattr(self, "hf_quantizer", None) is not None:
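Condensed, the renamed guard amounts to the following sketch (a paraphrase of the first hunk above, not the exact library code):

from transformers import GenerationConfig

def check_static_cache(model, generation_config: GenerationConfig) -> None:
    # Same condition as in _prepare_cache_for_generation, with the renamed flag.
    if generation_config.cache_implementation == "static" and not model._can_compile_fullgraph:
        raise ValueError(
            "This model does not support `cache_implementation='static'`. Please check the following "
            "issue: https://github.com/huggingface/transformers/issues/28981"
        )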

src/transformers/modeling_utils.py

Lines changed: 1 addition & 2 deletions
@@ -2063,8 +2063,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
     # Flex Attention support
     _supports_flex_attn = False

-    # Has support `torch.compile(fullgraph=True)`
-    _supports_static_cache = False
+    _can_compile_fullgraph = False

     # A tensor parallel plan to be applied to the model when TP is enabled. For
     # top-level models, this attribute is currently defined in respective model
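Individual model classes opt back in by overriding this new default, exactly as the per-model hunks below do; a minimal sketch with a hypothetical class (not part of this commit):

from transformers import PreTrainedModel

class MyCompilablePreTrainedModel(PreTrainedModel):
    # Opt in: the forward pass contains no graph breaks, so
    # torch.compile(fullgraph=True) can capture it end to end.
    _can_compile_fullgraph = True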

src/transformers/models/arcee/modeling_arcee.py

Lines changed: 1 addition & 1 deletion
@@ -313,7 +313,7 @@ class ArceePreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_flex_attn = True

-    _supports_static_cache = True
+    _can_compile_fullgraph = True
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": ArceeDecoderLayer,

src/transformers/models/aria/modeling_aria.py

Lines changed: 1 addition & 1 deletion
@@ -654,7 +654,7 @@ class AriaPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _supports_static_cache = False  # MoE models don't work with torch.compile (dynamic slicing)
+    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (dynamic slicing)
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": AriaTextDecoderLayer,

src/transformers/models/aria/modular_aria.py

Lines changed: 1 addition & 1 deletion
@@ -1302,7 +1302,7 @@ def _init_weights(self, module):
 class AriaPreTrainedModel(LlamaPreTrainedModel):
     config: AriaConfig
     base_model_prefix = ""
-    _supports_static_cache = False  # MoE models don't work with torch.compile (dynamic slicing)
+    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (dynamic slicing)
     _supports_attention_backend = True

     def _init_weights(self, module):
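The MoE comment refers to data-dependent expert routing: token-to-expert dispatch produces tensors whose shapes are only known at runtime, which `torch.compile(fullgraph=True)` refuses to capture. A standalone illustration of that failure mode (not transformers code; exact behavior may vary across torch versions):

import torch

def route_tokens(hidden, router_logits):
    mask = router_logits.argmax(-1) == 0
    # Boolean-mask indexing yields a runtime-dependent shape,
    # which forces a graph break under fullgraph=True.
    return hidden[mask]

compiled = torch.compile(route_tokens, fullgraph=True)
try:
    compiled(torch.randn(8, 16), torch.randn(8, 2))
except Exception as err:
    print(type(err).__name__)  # typically a torch._dynamo capture error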

src/transformers/models/aya_vision/modeling_aya_vision.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ class AyaVisionPreTrainedModel(PreTrainedModel):

     _supports_flash_attn = True
     _supports_sdpa = True
-    _supports_static_cache = False
+    _can_compile_fullgraph = False
     _supports_flex_attn = True
     _supports_attention_backend = True

src/transformers/models/aya_vision/modular_aya_vision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def pixel_shuffle(self, image_features): # B, S, D
9090

9191

9292
class AyaVisionPreTrainedModel(LlavaPreTrainedModel):
93-
_supports_static_cache = False
93+
_can_compile_fullgraph = False
9494

9595

9696
class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
