
Delete deprecated stuff #38838


Merged
merged 27 commits into huggingface:main on Jul 10, 2025

Conversation

zucchini-nlp
Member

What does this PR do?

As per title. Removes:

  • the deprecated legacy (tuple) cache format (a rough migration sketch follows below)
  • the deprecations we had for the new processor API
  • **rope_kwargs from the RoPE API
  • _seen_tokens in cache classes

First review @gante as most modifications are around cache/generation
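
For anyone migrating downstream code, here is a minimal sketch of what moving off the removed pieces can look like. It assumes a transformers release containing this PR; the checkpoint name and prompt are placeholders, not part of the PR.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Before: past_key_values could be the legacy tuple(tuple(torch.FloatTensor)) format.
# Now: pass a Cache instance explicitly (or let the model create one when use_cache=True).
cache = DynamicCache()
with torch.no_grad():
    out = model(**inputs, past_key_values=cache, use_cache=True)

# Before: the private `_seen_tokens` attribute tracked how many tokens were processed.
# Now: use the public Cache API instead.
print(out.past_key_values.get_seq_length())
```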

@zucchini-nlp requested a review from gante on June 16, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante (Member) left a comment

Yay cleanups! 🧹 🧹

I left a few minor comments to address. I also have very low confidence in the processing-side changes, so another reviewer for those parts would be great 🤗

@@ -93,7 +93,6 @@ def _compute_default_rope_parameters(
     config: Optional[PretrainedConfig] = None,
     device: Optional["torch.device"] = None,
     seq_len: Optional[int] = None,
-    **rope_kwargs,
Member

(I'm not 100% sure we won't get custom-model-on-the-hub-related BC issues.

On one hand, _compute_default_rope_parameters is a post-refactor function, so if a user started their model from one of our models, they wouldn't have used rope_kwargs. On the other hand, we didn't add an explicit deprecation cycle, my bad 😢 It's a very low likelihood, so I'm happy with the deletion, I don't think adding a deprecation cycle would be worth the extra work)
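
For custom models on the Hub that copied this helper, a post-removal RoPE init function would look roughly like the sketch below. It mirrors the signature shown in the diff; the body is an approximation of the default rope computation, not the exact library source.

```python
from typing import Optional

import torch
from transformers import PretrainedConfig


def _compute_default_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, float]:
    # Everything is read from the config now; **rope_kwargs is gone.
    base = config.rope_theta
    dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim))
    attention_factor = 1.0  # the default (unscaled) rope applies no attention scaling
    return inv_freq, attention_factor
```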

@@ -315,7 +315,8 @@ def forward(
    query_states = query_states.view(*q_input_shape)
    query_states = query_states.transpose(1, 2).contiguous()

    if past_key_value is not None:
        # Check if an encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
Member

I like the diff here, it simplified things, and I see parts of it are already present in bart.

Two notes:

  • The type hints/docstrings need updates. e.g. on L838 we say past_key_values can only be an EncoderDecoderCache (we can accept any Cache) or a tuple(tuple(torch.FloatTensor)) (no longer supported). This comment likely applies to all models in this diff, and even beyond it 👀 Given the wide extent of the changes, happy to leave it to a follow-up PR.
  • @vasqu has been making changes to whisper recently, so we should review this file as well :D
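
As a rough illustration of the kind of hint/docstring update meant here (a hypothetical snippet, not the actual whisper source):

```python
from typing import Optional

from transformers.cache_utils import Cache


def forward(
    self,
    # Before (sketch): Optional[Union[EncoderDecoderCache, tuple[tuple[torch.FloatTensor]]]]
    # After (sketch): any Cache subclass is accepted; the legacy tuple format no longer is.
    past_key_values: Optional[Cache] = None,
    use_cache: Optional[bool] = None,
):
    ...
```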

    past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
elif past_key_values is None:
    past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
if self.is_decoder and use_cache and past_key_values is None:
Member

I think all encoder-decoder models will have to copy the pattern from whisper (if it's a decoder-only model, instantiate a DynamicCache) to preserve BC -- we often allow configuring encoder-decoder models as decoder-only, when config.is_encoder_decoder=False
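
A sketch of that pattern, paraphrased from the description above (the helper name is made up for illustration; it is not the actual whisper code):

```python
from typing import Optional

from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache


def _maybe_init_cache(config, use_cache: bool, past_key_values: Optional[Cache]) -> Optional[Cache]:
    # Lazily create the right cache type so that models configured as decoder-only
    # (config.is_encoder_decoder=False) keep working without an EncoderDecoderCache.
    if use_cache and past_key_values is None:
        if config.is_encoder_decoder:
            return EncoderDecoderCache(DynamicCache(), DynamicCache())
        return DynamicCache()
    return past_key_values
```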

@@ -583,7 +583,7 @@ def generate(
     logger.warning_once(
         "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
         "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
-        "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+        "Using processors without these attributes in the config is deprecated and will throw an error in v4.54."
@zucchini-nlp (Member, Author) Jun 23, 2025

Depends on #35560, which became stale. Flagging for the core maintainers' attention; it awaits review so we can remove the warnings :)

@gante (Member) left a comment

LGTM 🧹 🧹 🧹

@julien-c
Member

julien-c commented Jul 7, 2025

nice, love those PRs! 🎉

@zucchini-nlp
Member Author

run-slow: align, aria, aya_vision, bart, bigbird_pegasus, blenderbot, blenderbot_small, blip_2, bloom, chameleon, codegen, colpali, colqwen2, csm, dbrx, falcon

Contributor

github-actions bot commented Jul 7, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/align', 'models/aria', 'models/aya_vision', 'models/bart', 'models/bigbird_pegasus', 'models/blenderbot', 'models/blenderbot_small', 'models/blip_2', 'models/bloom', 'models/chameleon', 'models/codegen', 'models/colpali', 'models/colqwen2', 'models/csm', 'models/dbrx', 'models/falcon']
quantizations: [] ...

@zucchini-nlp
Member Author

The failing slow tests look to be the same as the failing tests on the main branch. @Cyrilvallez can you take one last look and then I'll merge?

@zucchini-nlp requested a review from Cyrilvallez on July 7, 2025
@Cyrilvallez (Member) left a comment

All right, this is god's work 🤗❤️
My main comment is about old models that use patterns like next_cache = output[-1] on each iteration over the layers: can we remove that everywhere? We force cache classes now anyway, so it no longer makes sense. We should even stop returning the cache from the Attention and Layer modules, to simplify much further (it's a bit breaking, but aligned with what we already did!)
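
A small sketch of the suggested simplification (the helper and layer list are hypothetical; only the pattern matters):

```python
import torch

from transformers.cache_utils import Cache


def run_decoder_layers(layers, hidden_states: torch.Tensor, past_key_values: Cache):
    # Cache objects are updated in place by each layer, so there is no need for the
    # old per-layer `next_cache = output[-1]` re-collection.
    for layer in layers:
        hidden_states = layer(hidden_states, past_key_value=past_key_values)
    # Return the same cache object that was passed in; nothing was re-collected.
    return hidden_states, past_key_values
```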

self.static_cache = StaticCache(
    config=self.config,
    max_batch_size=batch_size,
    max_cache_len=max_static_cache_length,
    device="cpu",
    dtype=torch.float32,
)
self.cache = EncoderDecoderCache(self.static_cache, DynamicCache())
Member

This export recipe only uses the decoder part, so it should not need this change, no?

@zucchini-nlp (Member, Author) Jul 8, 2025

The model is encoder-decoder and needs a cache for both parts. Previously we would wrap the static cache in an EncoderDecoderCache in the model's forward, but as of this PR we don't do any legacy handling or hacks. Instead we expect users to pass the correct new cache class.

Export indeed doesn't really care about the encoder cache; it's needed for the generation code to work.
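
A minimal usage sketch of what "pass the correct new cache class" can look like for an encoder-decoder model (the checkpoint, sizes, and the commented generate call are placeholders):

```python
import torch

from transformers import AutoModelForSpeechSeq2Seq
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, StaticCache

model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")  # placeholder checkpoint

# Static cache for decoder self-attention, dynamic cache for cross-attention.
self_attention_cache = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=64,
    device="cpu",
    dtype=torch.float32,
)
cache = EncoderDecoderCache(self_attention_cache, DynamicCache())

# The cache is passed explicitly; the model's forward no longer wraps a bare
# StaticCache into an EncoderDecoderCache for you.
# outputs = model.generate(**inputs, past_key_values=cache)
```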

Member

So the decoder-only part of the model also expects an EncoderDecoderCache?

Member Author

Yes, because it does cross-attention and would still look for a cache to store encoder_hidden_states. Most of these models actually have no option to be run as a "CausalLM" type with no cross-attention, as per the code (e.g. T5 had failing tests).
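
For context, a tiny sketch of the two halves of an EncoderDecoderCache (attribute names as in transformers; the rest is illustrative):

```python
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

# Self-attention keys/values grow with every generated token ...
print(cache.self_attention_cache.get_seq_length())  # 0 before any forward pass
# ... while the cross-attention cache stores the key/value projections computed from
# encoder_hidden_states once and reuses them at every decoding step.
print(cache.cross_attention_cache.get_seq_length())  # 0 before any forward pass
```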

@zucchini-nlp
Member Author

run-slow: align, aria, aya_vision, bart, bigbird_pegasus, blenderbot, blenderbot_small, blip_2, bloom, chameleon, codegen, colpali, colqwen2, csm, dbrx, falcon

Contributor

github-actions bot commented Jul 8, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/align', 'models/aria', 'models/aya_vision', 'models/bart', 'models/bigbird_pegasus', 'models/blenderbot', 'models/blenderbot_small', 'models/blip_2', 'models/bloom', 'models/chameleon', 'models/codegen', 'models/colpali', 'models/colqwen2', 'models/csm', 'models/dbrx', 'models/falcon']
quantizations: [] ...

@zucchini-nlp enabled auto-merge (squash) on July 10, 2025
@zucchini-nlp
Member Author

Merging, looks like there are no more questions left. It will unblock me for another clean-up PR.

Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: align, aria, aya_vision, bart, bigbird_pegasus, blenderbot, blenderbot_small, blip_2, bloom, chameleon, codegen, colpali, colqwen2, csm, dbrx, falcon

@zucchini-nlp merged commit bc161d5 into huggingface:main on Jul 10, 2025
25 checks passed
rjgleaton pushed a commit to rjgleaton/transformers that referenced this pull request Jul 17, 2025
* delete deprecated stuff

* fix copies

* remove unused tests

* fix modernbert and fuyu

* Update src/transformers/cache_utils.py

Co-authored-by: Joao Gante <[email protected]>

* bye bye `seen_tokens`

* address comments

* update typings

* encoder decoder models follow same pattern as whisper

* fix copies

* why is it set to False?

* fix switch transformers

* fix encoder decoder models shared weight

* fix copies and RAG

* remove `next_cache`

* fix gptj/git

* fix copies

* fix copies

* style...

* another forgotten docstring

---------

Co-authored-by: Joao Gante <[email protected]>