
Delete deprecated stuff #38838


Merged — 27 commits, merged Jul 10, 2025

Changes from all commits

Commits (27)
60b4cbe
delete deprecated stuff
zucchini-nlp Jun 13, 2025
98184f9
fix copies
zucchini-nlp Jun 16, 2025
325cde8
remove unused tests
zucchini-nlp Jun 16, 2025
26a9c28
fix modernbert and fuyu
zucchini-nlp Jun 17, 2025
68a501e
Merge branch 'main' into remove-deprecations-4.52
zucchini-nlp Jun 20, 2025
2a53654
Update src/transformers/cache_utils.py
zucchini-nlp Jun 23, 2025
1f81714
bye bye `seen_tokens`
zucchini-nlp Jun 23, 2025
ca78f07
address comments
zucchini-nlp Jun 23, 2025
4c9bd33
update typings
zucchini-nlp Jun 23, 2025
fcbd79e
encoder decoder models follow same pattern as whisper
zucchini-nlp Jul 1, 2025
00dcc6d
merge main
zucchini-nlp Jul 1, 2025
c8b7099
fix copies
zucchini-nlp Jul 1, 2025
86d470d
why is it set to False?
zucchini-nlp Jul 1, 2025
f19d166
merge main
zucchini-nlp Jul 1, 2025
ab7fac4
fix switch transformers
zucchini-nlp Jul 2, 2025
d9ee03f
fix encoder decoder models shared weight
zucchini-nlp Jul 7, 2025
f06327f
fix copies and RAG
zucchini-nlp Jul 7, 2025
5080a86
Merge branch 'main' into remove-deprecations-4.52
zucchini-nlp Jul 7, 2025
31c5937
remove `next_cache`
zucchini-nlp Jul 8, 2025
cb8be3c
fix gptj/git
zucchini-nlp Jul 8, 2025
7247eed
merge main
zucchini-nlp Jul 8, 2025
48fd132
fix copies
zucchini-nlp Jul 8, 2025
37850e4
fix copies
zucchini-nlp Jul 8, 2025
44e7125
style...
zucchini-nlp Jul 8, 2025
d1f9915
another forgotten docstring
zucchini-nlp Jul 8, 2025
c769d41
Merge branch 'main' into remove-deprecations-4.52
zucchini-nlp Jul 9, 2025
486b21d
Merge branch 'main' into remove-deprecations-4.52
zucchini-nlp Jul 10, 2025
2 changes: 0 additions & 2 deletions docs/source/en/cache_explanation.md
Expand Up @@ -99,8 +99,6 @@ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_stat

2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token.

3. The cache maintains a count of seen tokens through `self._seen_tokens`. This is updated when the first layer processes a new token.

The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token.

```py
Expand Down
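
For readers following the docs change above, here is a minimal sketch of the generation loop that page describes. It is not part of the diff; the checkpoint name ("Qwen/Qwen2.5-0.5B"), greedy decoding, and the 10-step loop are illustrative assumptions.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
model.eval()

inputs = tokenizer("The key-value cache", return_tensors="pt")
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask

past_key_values = DynamicCache()
cache_position = torch.arange(input_ids.shape[1])
generated = []

for _ in range(10):
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            cache_position=cache_position,
            use_cache=True,
        )
    next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick
    generated.append(next_token)
    # The cache already holds every past key/value, so only the new token is fed back in.
    input_ids = next_token
    # The attention mask is the concatenation of past and current token values...
    attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
    # ...and `1` is added to the cache position for the next token.
    cache_position = cache_position[-1:] + 1

print(tokenizer.decode(torch.cat(generated, dim=-1)[0]))
```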
68 changes: 24 additions & 44 deletions src/transformers/cache_utils.py
Expand Up @@ -185,17 +185,6 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

@property
def seen_tokens(self):
logger.warning_once(
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
"model input instead."
)
if hasattr(self, "_seen_tokens"):
return self._seen_tokens
else:
return None

def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
Expand Down Expand Up @@ -472,7 +461,6 @@ class DynamicCache(Cache):

def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None:
super().__init__()
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
self.key_cache: list[torch.Tensor] = []
self.value_cache: list[torch.Tensor] = []

Expand Down Expand Up @@ -535,10 +523,6 @@ def update(
Return:
A tuple containing the updated key and value states.
"""
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]

# Update the cache
if key_states is not None:
if len(self.key_cache) <= layer_idx:
Expand Down Expand Up @@ -605,7 +589,6 @@ def crop(self, max_length: int):
if self.get_seq_length() <= max_length:
return

self._seen_tokens = max_length
for idx in range(len(self.key_cache)):
if self.key_cache[idx].numel():
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
Expand All @@ -617,7 +600,6 @@ def batch_split(self, full_batch_size: int, split_size: int) -> list["DynamicCac
out = []
for i in range(0, full_batch_size, split_size):
current_split = DynamicCache()
current_split._seen_tokens = self._seen_tokens
current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
out.append(current_split)
Expand Down Expand Up @@ -815,10 +797,6 @@ def update(
Return:
A tuple containing the updated key and value states.
"""
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]

# Update the cache
if len(self.key_cache) < layer_idx:
raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
Expand Down Expand Up @@ -857,6 +835,9 @@ class QuantizedCache(DynamicCache):

def __init__(self, cache_config: QuantizedCacheConfig) -> None:
super().__init__()

# Used only for QuantCache where the seq-length can't be inferred easily from cache contents
self._seen_tokens = 0
self._quantized_key_cache: list[torch.Tensor] = []
self._quantized_value_cache: list[torch.Tensor] = []

Expand Down Expand Up @@ -1412,6 +1393,19 @@ def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
for layer_idx in range(len(cross_attention_cache.key_cache)):
self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)

def __iter__(self):
"""
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
keys and values
"""
for layer_idx in range(len(self)):
yield (
self.self_attention_cache.key_cache[layer_idx],
self.self_attention_cache.value_cache[layer_idx],
self.cross_attention_cache.key_cache[layer_idx],
self.cross_attention_cache.value_cache[layer_idx],
)

def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
Expand Down Expand Up @@ -2313,10 +2307,6 @@ def __init__(
self._device_key_cache.append(key_cache)
self._device_value_cache.append(value_cache)

# For backwards compatibility.
# TODO(gante): Remove this.
self._seen_tokens = 0

# Create new CUDA stream for parallel prefetching.
self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None

Expand Down Expand Up @@ -2350,10 +2340,6 @@ def update(
value_states = value_states.to(self.value_cache[layer_idx].dtype)

if layer_idx == 0:
# Update seen tokens.
# TODO(gante): Remove this.
self._seen_tokens += key_states.shape[-2]

# Always there.
k_out = self.key_cache[0]
v_out = self.value_cache[0]
Expand Down Expand Up @@ -2407,10 +2393,14 @@ def update(
return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""

# TODO(gante): Remove this.
return self._seen_tokens
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
is_empty_layer = (
len(self.key_cache) == 0 # no cache in any layer
or len(self.key_cache) <= layer_idx # hasn't run a layer with cache after it
or not self.key_cache[layer_idx].numel() # the layer has no cache
)
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
return layer_seq_length

def get_max_cache_shape(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states."""
Expand All @@ -2420,22 +2410,12 @@ def get_max_cache_shape(self) -> Optional[int]:
def reset(self) -> None:
"""Resets the cache values while preserving the objects."""

# For backwards compatibility.
# TODO(gante): Remove this.
self._seen_tokens = 0

# Zero out cache.
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address.
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

@property
def seen_tokens(self) -> int:
# For backwards compatibility.
# TODO(gante): Remove this.
return self._seen_tokens

def _create_key_value_cache_tensors(
self, shape: tuple[int, ...], device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
Expand Down
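
Since this file drops `seen_tokens` everywhere, here is a hedged sketch of the replacement pattern the old deprecation warning pointed to: read the length from the cache itself via `get_seq_length()` (or use the `cache_position` model input) instead of the removed attribute. The tensor shapes below are illustrative, not taken from the PR.

```py
import torch
from transformers import DynamicCache

cache = DynamicCache()
# (batch, num_heads, seq_len, head_dim) — illustrative shapes only
key = torch.zeros(1, 8, 5, 64)
value = torch.zeros(1, 8, 5, 64)
cache.update(key, value, layer_idx=0)

# Before this PR: num_seen = cache.seen_tokens (deprecated, now removed)
num_seen = cache.get_seq_length()  # 5, inferred from the cached tensors themselves

# cache_position for the next single-token decoding step
next_cache_position = torch.arange(num_seen, num_seen + 1)
```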
7 changes: 4 additions & 3 deletions src/transformers/integrations/executorch.py
Expand Up @@ -15,7 +15,7 @@

import torch

from ..cache_utils import DynamicCache, HybridCache, StaticCache
from ..cache_utils import DynamicCache, EncoderDecoderCache, HybridCache, StaticCache
from ..generation.configuration_utils import GenerationConfig
from ..masking_utils import (
ALL_MASK_ATTENTION_FUNCTIONS,
Expand Down Expand Up @@ -548,14 +548,15 @@ def __init__(self, model, max_static_cache_length, batch_size):
self.lm_head = model.lm_head
self.config = model.config

# Initialize static cache
# Initialize static cache for decoder and DynamicCache for encoder
self.static_cache = StaticCache(
config=self.config,
max_batch_size=batch_size,
max_cache_len=max_static_cache_length,
device="cpu",
dtype=torch.float32,
)
self.cache = EncoderDecoderCache(self.static_cache, DynamicCache())
Member:

This export recipe only uses the decoder part, so it should not need this change, no?

zucchini-nlp (Member Author), Jul 8, 2025:

The model is encoder-decoder and needs a cache for both. Previously, we would wrap the static cache in an EncoderDecoderCache in the model's forward, but as of this PR we don't do any legacy hacks. Instead, we expect users to pass the correct new cache class.

Export doesn't really care about the encoder cache, indeed; it's needed for the generation code to work.

Member:

So the decoder-only part of the model also expects an EncoderDecoderCache?

zucchini-nlp (Member Author):

Yes, because it does cross-attention and still looks for a cache to store encoder_hidden_states. Most of these models actually have no option to be run as a "CausalLM" type without cross-attention as per the code (e.g. T5 had failing tests).


# Register cache buffers to make them exportable
for i in range(len(self.static_cache.key_cache)):
Expand All @@ -567,7 +568,7 @@ def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_hidden_states,
past_key_values=self.static_cache,
past_key_values=self.cache,
use_cache=True,
cache_position=cache_position,
)
Expand Down
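
A minimal sketch of the pattern discussed in the review thread above: wrap a static decoder self-attention cache and a dynamic cross-attention cache in `EncoderDecoderCache` and pass it explicitly, rather than relying on the model's forward to wrap it. The `t5-small` checkpoint, cache length, and `generate` call are illustrative assumptions and presume the checkpoint supports a static self-attention cache.

```py
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DynamicCache, EncoderDecoderCache, StaticCache

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Static cache for decoder self-attention; the cross-attention cache stays dynamic.
self_attn_cache = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=32,
    device="cpu",
    dtype=torch.float32,
)
past_key_values = EncoderDecoderCache(self_attn_cache, DynamicCache())

inputs = tokenizer("translate English to German: Hello", return_tensors="pt")
outputs = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```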