diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md index 4190cefdb8a1..6c31035234bb 100644 --- a/docs/source/en/cache_explanation.md +++ b/docs/source/en/cache_explanation.md @@ -99,8 +99,6 @@ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_stat 2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token. -3. The cache maintains a count of seen tokens through `self._seen_tokens`. This is updated when the first layer processes a new token. - The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token. ```py diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index bfba7e2bbda3..0b7052c63485 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -185,17 +185,6 @@ def reorder_cache(self, beam_idx: torch.LongTensor): device = self.value_cache[layer_idx].device self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - @property - def seen_tokens(self): - logger.warning_once( - "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " - "model input instead." - ) - if hasattr(self, "_seen_tokens"): - return self._seen_tokens - else: - return None - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: """ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for @@ -472,7 +461,6 @@ class DynamicCache(Cache): def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None: super().__init__() - self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen self.key_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = [] @@ -535,10 +523,6 @@ def update( Return: A tuple containing the updated key and value states. """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - # Update the cache if key_states is not None: if len(self.key_cache) <= layer_idx: @@ -605,7 +589,6 @@ def crop(self, max_length: int): if self.get_seq_length() <= max_length: return - self._seen_tokens = max_length for idx in range(len(self.key_cache)): if self.key_cache[idx].numel(): self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] @@ -617,7 +600,6 @@ def batch_split(self, full_batch_size: int, split_size: int) -> list["DynamicCac out = [] for i in range(0, full_batch_size, split_size): current_split = DynamicCache() - current_split._seen_tokens = self._seen_tokens current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] out.append(current_split) @@ -815,10 +797,6 @@ def update( Return: A tuple containing the updated key and value states. """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - # Update the cache if len(self.key_cache) < layer_idx: raise ValueError("OffloadedCache does not support model usage where layers are skipped. 
Use DynamicCache.") @@ -857,6 +835,9 @@ class QuantizedCache(DynamicCache): def __init__(self, cache_config: QuantizedCacheConfig) -> None: super().__init__() + + # Used only for QuantCache where the seq-length can't be inferred easily from cache contents + self._seen_tokens = 0 self._quantized_key_cache: list[torch.Tensor] = [] self._quantized_value_cache: list[torch.Tensor] = [] @@ -1412,6 +1393,19 @@ def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): for layer_idx in range(len(cross_attention_cache.key_cache)): self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield ( + self.self_attention_cache.key_cache[layer_idx], + self.self_attention_cache.value_cache[layer_idx], + self.cross_attention_cache.key_cache[layer_idx], + self.cross_attention_cache.value_cache[layer_idx], + ) + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the @@ -2313,10 +2307,6 @@ def __init__( self._device_key_cache.append(key_cache) self._device_value_cache.append(value_cache) - # For backwards compatibility. - # TODO(gante): Remove this. - self._seen_tokens = 0 - # Create new CUDA stream for parallel prefetching. self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None @@ -2350,10 +2340,6 @@ def update( value_states = value_states.to(self.value_cache[layer_idx].dtype) if layer_idx == 0: - # Update seen tokens. - # TODO(gante): Remove this. - self._seen_tokens += key_states.shape[-2] - # Always there. k_out = self.key_cache[0] v_out = self.value_cache[0] @@ -2407,10 +2393,14 @@ def update( return k_out, v_out def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" - - # TODO(gante): Remove this. - return self._seen_tokens + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # hasn't run a layer with cache after it + or not self.key_cache[layer_idx].numel() # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length def get_max_cache_shape(self) -> Optional[int]: """Returns the maximum sequence length of the cached states.""" @@ -2420,22 +2410,12 @@ def get_max_cache_shape(self) -> Optional[int]: def reset(self) -> None: """Resets the cache values while preserving the objects.""" - # For backwards compatibility. - # TODO(gante): Remove this. - self._seen_tokens = 0 - # Zero out cache. for layer_idx in range(len(self.key_cache)): # In-place ops prevent breaking the static address. self.key_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_() - @property - def seen_tokens(self) -> int: - # For backwards compatibility. - # TODO(gante): Remove this. 
- return self._seen_tokens - def _create_key_value_cache_tensors( self, shape: tuple[int, ...], device: torch.device ) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 0df283c83b71..41a30c1374aa 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -15,7 +15,7 @@ import torch -from ..cache_utils import DynamicCache, HybridCache, StaticCache +from ..cache_utils import DynamicCache, EncoderDecoderCache, HybridCache, StaticCache from ..generation.configuration_utils import GenerationConfig from ..masking_utils import ( ALL_MASK_ATTENTION_FUNCTIONS, @@ -548,7 +548,7 @@ def __init__(self, model, max_static_cache_length, batch_size): self.lm_head = model.lm_head self.config = model.config - # Initialize static cache + # Initialize static cache for decoder and DynamicCache for encoder self.static_cache = StaticCache( config=self.config, max_batch_size=batch_size, @@ -556,6 +556,7 @@ def __init__(self, model, max_static_cache_length, batch_size): device="cpu", dtype=torch.float32, ) + self.cache = EncoderDecoderCache(self.static_cache, DynamicCache()) # Register cache buffers to make them exportable for i in range(len(self.static_cache.key_cache)): @@ -567,7 +568,7 @@ def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): outputs = self.decoder( input_ids=decoder_input_ids, encoder_hidden_states=encoder_hidden_states, - past_key_values=self.static_cache, + past_key_values=self.cache, use_cache=True, cache_position=cache_position, ) diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index 4786cce27356..59989aa5927c 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -93,7 +93,6 @@ def _compute_default_rope_parameters( config: Optional[PretrainedConfig] = None, device: Optional["torch.device"] = None, seq_len: Optional[int] = None, - **rope_kwargs, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies according to the original RoPE implementation @@ -104,25 +103,14 @@ def _compute_default_rope_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
""" - if config is not None and len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " - f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" - ) - if len(rope_kwargs) > 0: - base = rope_kwargs["base"] - dim = rope_kwargs["dim"] - elif config is not None: - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - dim = int(head_dim * partial_rotary_factor) + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) attention_factor = 1.0 # Unused in this type of RoPE @@ -135,7 +123,6 @@ def _compute_linear_scaling_rope_parameters( config: Optional[PretrainedConfig] = None, device: Optional["torch.device"] = None, seq_len: Optional[int] = None, - **rope_kwargs, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev @@ -146,24 +133,14 @@ def _compute_linear_scaling_rope_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). """ - if config is not None and len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " - f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" - ) - if len(rope_kwargs) > 0: - factor = rope_kwargs["factor"] - elif config is not None: - factor = config.rope_scaling["factor"] + factor = config.rope_scaling["factor"] # Gets the default RoPE parameters - inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len) # Then applies linear scaling to the frequencies. # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so @@ -176,7 +153,6 @@ def _compute_dynamic_ntk_parameters( config: Optional[PretrainedConfig] = None, device: Optional["torch.device"] = None, seq_len: Optional[int] = None, - **rope_kwargs, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla @@ -187,30 +163,17 @@ def _compute_dynamic_ntk_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length, used to update the dynamic RoPE at inference time. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. 
Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). """ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling - if config is not None and len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " - f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" - ) - if len(rope_kwargs) > 0: - base = rope_kwargs["base"] - dim = rope_kwargs["dim"] - max_position_embeddings = rope_kwargs["max_position_embeddings"] - factor = rope_kwargs["factor"] - elif config is not None: - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - max_position_embeddings = config.max_position_embeddings - factor = config.rope_scaling["factor"] + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + max_position_embeddings = config.max_position_embeddings + factor = config.rope_scaling["factor"] attention_factor = 1.0 # Unused in this type of RoPE @@ -232,7 +195,7 @@ def _compute_dynamic_ntk_parameters( def _compute_yarn_parameters( - config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies with NTK scaling. Please refer to the @@ -244,17 +207,10 @@ def _compute_yarn_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin. """ - # No need to keep BC with yarn, unreleased when this new pattern was created. - if len(rope_kwargs) > 0: - raise ValueError( - f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" - ) base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 @@ -328,7 +284,7 @@ def linear_ramp_factor(min, max, dim): def _compute_longrope_parameters( - config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies with LongRoPE scaling. Please refer to the @@ -340,20 +296,11 @@ def _compute_longrope_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. 
Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin. """ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling - # No need to keep BC with longrope, unreleased when this new pattern was created. - if len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got " - f"{rope_kwargs}" - ) - base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) @@ -391,7 +338,7 @@ def _compute_longrope_parameters( def _compute_llama3_parameters( - config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies for llama 3.1. @@ -403,14 +350,12 @@ def _compute_llama3_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin. """ # Gets the default RoPE parameters - inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len) factor = config.rope_scaling["factor"] # `8` in the original implementation low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation diff --git a/src/transformers/models/align/processing_align.py b/src/transformers/models/align/processing_align.py index 33aaca273262..dfc98c4405b4 100644 --- a/src/transformers/models/align/processing_align.py +++ b/src/transformers/models/align/processing_align.py @@ -19,7 +19,7 @@ from typing import Union from ...image_utils import ImageInput -from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput @@ -110,8 +110,6 @@ def __call__( """ if text is None and images is None: raise ValueError("You must specify either text or images.") - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) output_kwargs = self._merge_kwargs( AlignProcessorKwargs, diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index c3211beb6f5e..386346ae2a63 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -920,7 +920,7 @@ class AriaCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -947,7 +947,7 @@ class AriaCausalLMOutputWithPast(ModelOutput): ) class AriaModelOutputWithPast(BaseModelOutputWithPast): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -1034,7 +1034,7 @@ def forward( pixel_mask: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1192,7 +1192,7 @@ def forward( pixel_mask: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index e2f1c9742150..aeebd1dd09d1 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -18,6 +18,7 @@ import numpy as np from ...activations import ACT2FN +from ...cache_utils import Cache from ...configuration_utils import PretrainedConfig from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format @@ -1431,7 +1432,7 @@ def forward( pixel_mask: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1528,7 +1529,7 @@ def forward( pixel_mask: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 7b6e8efb1f26..599bc63b1156 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -26,6 +26,7 @@ from torch import nn from 
...activations import ACT2FN +from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput @@ -129,7 +130,7 @@ class AyaVisionCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -156,7 +157,7 @@ class AyaVisionCausalLMOutputWithPast(ModelOutput): ) class AyaVisionModelOutputWithPast(BaseModelOutputWithPast): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -261,7 +262,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, @@ -413,7 +414,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, diff --git a/src/transformers/models/aya_vision/modular_aya_vision.py b/src/transformers/models/aya_vision/modular_aya_vision.py index 247b6af446d3..a533ca32daa4 100644 --- a/src/transformers/models/aya_vision/modular_aya_vision.py +++ b/src/transformers/models/aya_vision/modular_aya_vision.py @@ -29,6 +29,7 @@ ) from ...activations import ACT2FN +from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging @@ -181,7 +182,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, @@ -267,7 +268,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, 
attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index ba7e8603f214..7d768a734827 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -949,7 +949,7 @@ def forward( encoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -998,7 +998,7 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. @@ -1230,7 +1230,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[list[torch.FloatTensor]] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -1402,7 +1402,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[list[torch.FloatTensor]] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -1901,7 +1901,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 465b94e13bee..1e126fbcaff8 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2107,7 +2107,7 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: 
Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -2156,7 +2156,7 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. @@ -2381,7 +2381,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[list[torch.FloatTensor]] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -2543,7 +2543,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[list[torch.FloatTensor]] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 9d31049a1715..659b856a77c1 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -953,7 +953,7 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
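The hunks above and below retype `past_key_values` from a legacy `list[torch.FloatTensor]` / tuple-of-tuples to a `Cache` object. Below is a minimal sketch of the calling pattern the new annotations point to, not part of this patch; the checkpoint name and the assumption that `generate` accepts a user-provided `EncoderDecoderCache` for this encoder-decoder model are illustrative assumptions.

```py
# Illustrative only: pass an explicit `Cache` object instead of a legacy tuple-of-tuples.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DynamicCache, EncoderDecoderCache

model_name = "facebook/blenderbot-400M-distill"  # assumed checkpoint for the sketch
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

inputs = tokenizer("Hello, how are you?", return_tensors="pt")
# Encoder-decoder models keep self-attention and cross-attention KV states in one wrapper.
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
outputs = model.generate(**inputs, past_key_values=cache, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```

As the bloom/codegen/dbrx/falcon hunks in this patch show, passing a legacy tuple to the converted models now raises a `ValueError` instead of being silently wrapped in a `DynamicCache`.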
@@ -1186,7 +1186,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Union[tuple, BaseModelOutput]] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.Tensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -1361,7 +1361,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Union[tuple, BaseModelOutput]] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.Tensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -1551,7 +1551,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 930233413769..f82a0d322283 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -934,7 +934,7 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
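With `_seen_tokens` and the `seen_tokens` property removed throughout this patch, the number of cached tokens is read from the cache contents via `get_seq_length()`, or tracked through the `cache_position` model input. A minimal sketch using the `DynamicCache.update` / `get_seq_length` API shown in the `cache_utils.py` hunks; the tensor shapes are chosen only for illustration.

```py
# Illustrative only: query the cache length instead of the removed `seen_tokens` attribute.
import torch
from transformers import DynamicCache

cache = DynamicCache()
# (batch, num_heads, seq_len, head_dim) -- shapes picked for the example
key_states = torch.zeros(1, 8, 5, 64)
value_states = torch.zeros(1, 8, 5, 64)
cache.update(key_states, value_states, layer_idx=0)

print(cache.get_seq_length())  # 5, inferred from the stored keys rather than a separate counter
# The position of the next token comes from `cache_position`, e.g. for single-token decoding:
cache_position = torch.arange(cache.get_seq_length(), cache.get_seq_length() + 1)
```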
@@ -1153,7 +1153,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Union[tuple, BaseModelOutput]] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.Tensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -1315,7 +1315,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Union[tuple, BaseModelOutput]] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.Tensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -1505,7 +1505,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 80615fd9b940..6a1866ffcbb3 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -774,7 +775,7 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: self_outputs = self.attention( @@ -1146,7 +1147,7 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 414426844513..0360bdf6f724 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -328,12 +328,7 @@ def forward( output_tensor = self.dense(context_layer) output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) - - outputs = (output_tensor, layer_past) - if output_attentions: - outputs += (attention_probs,) - - return outputs + return output_tensor, attention_probs class BloomMLP(nn.Module): @@ -405,7 +400,7 @@ def forward( residual = hidden_states # Self attention. 
- attn_outputs = self.self_attention( + attention_output, attn_weights = self.self_attention( layernorm_output, residual, layer_past=layer_past, @@ -417,10 +412,6 @@ def forward( cache_position=cache_position, ) - attention_output = attn_outputs[0] - - outputs = attn_outputs[1:] - layernorm_output = self.post_attention_layernorm(attention_output) # Get residual @@ -432,12 +423,7 @@ def forward( # MLP. output = self.mlp(layernorm_output, residual) - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] - - return outputs # hidden_states, past_kv, attentions + return output, attn_weights # hidden_states, attentions @auto_docstring @@ -560,19 +546,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() batch_size, seq_length, _ = inputs_embeds.shape past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -587,7 +566,6 @@ def forward( head_mask = self.get_head_mask(head_mask, self.config.n_layer) hidden_states = self.word_embeddings_layernorm(inputs_embeds) - next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -618,11 +596,8 @@ def forward( ) hidden_states = outputs[0] - if use_cache: - next_decoder_cache = outputs[1] - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) # Add last hidden state hidden_states = self.ln_f(hidden_states) @@ -630,18 +605,14 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: return tuple( - v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, ) diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py index f8dfcff33cf3..a7864e28bad8 100644 --- a/src/transformers/models/chameleon/processing_chameleon.py +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -28,7 +28,6 @@ ProcessorMixin, TextKwargs, Unpack, - 
_validate_images_text_input_order, ) from ...tokenization_utils_base import PreTokenizedInput, TextInput @@ -129,8 +128,7 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) + if isinstance(text, str): text = [text] elif not isinstance(text, list) and not isinstance(text[0], str): diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index d00528b14069..f39364180171 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -217,12 +217,7 @@ def forward( attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, layer_past) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) + return attn_output, attn_weights # Copied from transformers.models.gptj.modeling_gptj.GPTJMLP with GPTJ->CodeGen @@ -268,7 +263,7 @@ def forward( ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( + attn_outputs, attn_weights = self.attn( hidden_states=hidden_states, layer_past=layer_past, attention_mask=attention_mask, @@ -278,18 +273,10 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] - feed_forward_hidden_states = self.mlp(hidden_states) - hidden_states = attn_output + feed_forward_hidden_states + residual - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] + hidden_states = attn_outputs + feed_forward_hidden_states + residual - return outputs # hidden_states, present, (attentions) + return hidden_states, attn_weights @auto_docstring @@ -390,19 +377,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.wte(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() seq_length = inputs_embeds.shape[1] if cache_position is None: @@ -431,7 +411,6 @@ def forward( hidden_states = self.drop(hidden_states) output_shape = (-1, seq_length, hidden_states.size(-1)) - next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for i, block in enumerate(self.h): @@ -450,11 +429,8 @@ def forward( ) hidden_states = outputs[0] - if use_cache is True: - next_decoder_cache = outputs[1] - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) hidden_states = self.ln_f(hidden_states) @@ -463,18 +439,14 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: return tuple( - v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, ) diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py index a5c8f8dd6daf..0ac554ad2eae 100644 --- a/src/transformers/models/colpali/modeling_colpali.py +++ b/src/transformers/models/colpali/modeling_colpali.py @@ -63,7 +63,7 @@ class ColPaliForRetrievalOutput(ModelOutput): Language modeling loss (for next-token prediction). embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): The embeddings of the model. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) diff --git a/src/transformers/models/colqwen2/modeling_colqwen2.py b/src/transformers/models/colqwen2/modeling_colqwen2.py index 3d259212c509..65bfa7d9214f 100644 --- a/src/transformers/models/colqwen2/modeling_colqwen2.py +++ b/src/transformers/models/colqwen2/modeling_colqwen2.py @@ -75,7 +75,7 @@ class ColQwen2ForRetrievalOutput(ModelOutput): Language modeling loss (for next-token prediction). embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): The embeddings of the model. 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -130,7 +130,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, labels: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/colqwen2/modular_colqwen2.py b/src/transformers/models/colqwen2/modular_colqwen2.py index 881cb8af662e..08b79e247e68 100644 --- a/src/transformers/models/colqwen2/modular_colqwen2.py +++ b/src/transformers/models/colqwen2/modular_colqwen2.py @@ -243,7 +243,7 @@ class ColQwen2ForRetrievalOutput(ModelOutput): Language modeling loss (for next-token prediction). embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): The embeddings of the model. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -280,7 +280,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, labels: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 3d2690568058..7c18216eeb34 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -58,7 +58,7 @@ class CsmOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index 15e7a4868011..f6098993c387 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -58,7 +58,7 @@ class CsmOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). 
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 1dcf34a91c3e..84c72d9ac2c4 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -305,7 +305,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class DbrxFlashAttention2(DbrxAttention): @@ -430,7 +430,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class DbrxSdpaAttention(DbrxAttention): @@ -525,7 +525,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None DBRX_ATTENTION_CLASSES = { @@ -561,7 +561,7 @@ def forward( residual_states = hidden_states hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype) - hidden_states, attn_weights, past_key_value = self.attn( + hidden_states, attn_weights = self.attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -578,7 +578,7 @@ def forward( residual_states = hidden_states hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype) - return residual_states, hidden_states, attn_weights, past_key_value + return residual_states, hidden_states, attn_weights class DbrxRouter(nn.Module): @@ -775,7 +775,7 @@ def forward( """ # Norm + Attention + Norm - resid_states, hidden_states, self_attn_weights, present_key_value = self.norm_attn_norm( + resid_states, hidden_states, self_attn_weights = self.norm_attn_norm( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -796,9 +796,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - if output_router_logits: outputs += (router_logits,) @@ -909,19 +906,12 @@ def forward( inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -943,7 +933,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None for block in self.blocks: if output_hidden_states: @@ -962,9 +951,6 @@ def forward( hidden_states = block_outputs[0] - if use_cache: - next_decoder_cache = block_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (block_outputs[1],) @@ -977,19 +963,15 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_router_logits] if v is not None ) return MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 78792318a16a..ebcdf42086a3 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -226,10 +226,6 @@ def __init__(self, config: FalconConfig, layer_idx=None): self.attention_dropout = nn.Dropout(config.attention_dropout) self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 - # TODO (raushan): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) - if config.rotary: - self.rotary_emb = FalconRotaryEmbedding(config=self.config) - def _split_heads(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv` @@ -362,10 +358,7 @@ def forward( attn_output = self.dense(attn_output) - if output_attentions: - return attn_output, layer_past, attention_scores - else: - return attn_output, layer_past + return attn_output, attention_scores else: if self._use_sdpa and not output_attentions and head_mask is None: @@ -380,6 +373,7 @@ def forward( dropout_p=self.attention_dropout.p if self.training else 0.0, is_causal=is_causal, ) + attention_probs = None attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) @@ -416,10 +410,7 @@ def forward( attn_output = self.dense(attn_output) - if output_attentions: - return attn_output, layer_past, attention_probs - else: - return attn_output, layer_past + return attn_output, attention_probs class FalconFlashAttention2(FalconAttention): @@ -528,7 +519,7 @@ def 
forward( if not output_attentions: attn_weights = None - return attn_output, layer_past, attn_weights + return attn_output, attn_weights class FalconMLP(nn.Module): @@ -603,7 +594,7 @@ def forward( attention_layernorm_out = self.input_layernorm(hidden_states) # Self attention. - attn_outputs = self.self_attention( + attention_output, attn_weights = self.self_attention( attention_layernorm_out, layer_past=layer_past, attention_mask=attention_mask, @@ -616,8 +607,6 @@ def forward( position_embeddings=position_embeddings, ) - attention_output = attn_outputs[0] - if not self.config.new_decoder_architecture: if self.config.parallel_attn: mlp_layernorm_out = attention_layernorm_out @@ -634,8 +623,6 @@ def forward( ): mlp_layernorm_out = attention_layernorm_out - outputs = attn_outputs[1:] - # MLP. mlp_output = self.mlp(mlp_layernorm_out) @@ -644,12 +631,7 @@ def forward( output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training) - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] - - return outputs # hidden_states, past_kv, attentions + return output, attn_weights @auto_docstring @@ -777,19 +759,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() # Compute alibi tensor: check build_alibi_tensor documentation alibi = None @@ -827,7 +802,6 @@ def forward( # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) - next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -849,11 +823,8 @@ def forward( ) hidden_states = outputs[0] - if use_cache is True: - next_decoder_cache = outputs[1] - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) # Add last hidden state hidden_states = self.ln_f(hidden_states) @@ -861,18 +832,14 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: return tuple( - v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + 
past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, ) diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 71c06ce6b356..c75644036b5c 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -91,8 +91,6 @@ def __init__( self.has_previous_state = False self.conv_kernel_size = config.mamba_d_conv - self._seen_tokens = 0 - self.intermediate_size = ( config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) ) @@ -149,10 +147,6 @@ def update( Return: A tuple containing the updated key and value states. """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - # Update the cache if len(self.key_cache) <= layer_idx: # There may be skipped layers, fill them with empty lists diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 8bfbe328a262..13146e7bd1f9 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -99,8 +99,6 @@ def __init__( self.has_previous_state = False self.conv_kernel_size = config.mamba_d_conv - self._seen_tokens = 0 - self.intermediate_size = ( config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) ) @@ -157,10 +155,6 @@ def update( Return: A tuple containing the updated key and value states. """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - # Update the cache if len(self.key_cache) <= layer_idx: # There may be skipped layers, fill them with empty lists diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 8c445571ac34..589620ff804c 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -20,6 +20,7 @@ import torch.utils.checkpoint from torch import nn +from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel @@ -41,6 +42,7 @@ class FuyuPreTrainedModel(PreTrainedModel): _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True + _supports_cache_class = True _no_split_modules = [] _skip_keys_device_placement = "past_key_values" @@ -155,7 +157,7 @@ def forward( image_patches_indices: torch.Tensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -187,15 +189,9 @@ def forward( else: raise ValueError("You have to specify either input_is or inputs_embeds") - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 position_ids = 
torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) @@ -281,7 +277,7 @@ def forward( image_patches_indices: torch.Tensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, labels: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/fuyu/processing_fuyu.py b/src/transformers/models/fuyu/processing_fuyu.py index ee4b411c3ec1..f13869a146db 100644 --- a/src/transformers/models/fuyu/processing_fuyu.py +++ b/src/transformers/models/fuyu/processing_fuyu.py @@ -27,7 +27,6 @@ ProcessingKwargs, ProcessorMixin, Unpack, - _validate_images_text_input_order, ) from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import is_torch_available, logging, requires_backends @@ -522,8 +521,6 @@ def __call__( # --- Check input validity --- if text is None and images is None: raise ValueError("You have to specify either text or images. Both cannot be None.") - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) output_kwargs = self._merge_kwargs( FuyuProcessorKwargs, diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index b8f4a88da3a6..d65aed200c74 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -63,7 +63,7 @@ ) class Gemma3ModelOutputWithPast(BaseModelOutputWithPast): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -89,7 +89,7 @@ class Gemma3CausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index a501d03a7c1a..ad52c63d3d38 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -227,10 +227,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSelfOutput @@ -290,7 +287,7 @@ def forward( output_attentions: Optional[bool] = False, pixel_values_present: Optional[bool] = False, ) -> tuple[torch.Tensor]: - self_outputs = self.self( + attn_output, self_attn_weights = self.self( hidden_states, attention_mask, head_mask, @@ -298,9 +295,8 @@ def forward( output_attentions, pixel_values_present, ) - attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs + attention_output = self.output(attn_output, hidden_states) + return attention_output, self_attn_weights # Copied from transformers.models.bert.modeling_bert.BertIntermediate @@ -353,7 +349,7 @@ def forward( pixel_values_present: Optional[bool] = False, ) -> tuple[torch.Tensor]: # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attention_outputs = self.attention( + attention_output, self_attention_weights = self.attention( hidden_states, attention_mask, head_mask, @@ -361,21 +357,11 @@ def forward( past_key_value=past_key_value, pixel_values_present=pixel_values_present, ) - attention_output = self_attention_outputs[0] - - # if decoder, the last output is tuple of self-attn cache - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) - outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - outputs = outputs + (present_key_value,) - - return outputs + return layer_output, self_attention_weights def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) @@ -409,23 +395,15 @@ def forward( ) use_cache = False - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None - next_decoder_cache = None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -442,24 +420,18 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_self_attentions, ] @@ -467,7 +439,7 @@ def forward( ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, ) diff --git a/src/transformers/models/git/processing_git.py b/src/transformers/models/git/processing_git.py index 51ac8c2de588..0980b81b55ad 100644 --- a/src/transformers/models/git/processing_git.py +++ b/src/transformers/models/git/processing_git.py @@ -20,7 +20,7 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import logging @@ -98,9 +98,6 @@ def __call__( if text is None and images is None: raise ValueError("You have to specify either text or images. Both cannot be none.") - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) - output_kwargs = self._merge_kwargs( GitProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index fcf36a826099..86bf2b56ffea 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -801,7 +801,7 @@ def forward( ) class Glm4vModelOutputWithPast(ModelOutput): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -1364,7 +1364,7 @@ class Glm4vCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). 
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index d5cdeff2f33b..2aec3cf284a8 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -31,6 +31,7 @@ from transformers.utils.generic import check_model_inputs from ...activations import ACT2FN +from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -498,7 +499,7 @@ class GotOcr2CausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -525,7 +526,7 @@ class GotOcr2CausalLMOutputWithPast(ModelOutput): ) class GotOcr2ModelOutputWithPast(BaseModelOutputWithPast): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -590,7 +591,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -727,7 +728,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index 92cf6aab4444..1ae6880753cf 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ 
b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -35,6 +35,7 @@ SamVisionLayer, ) +from ...cache_utils import Cache from ...configuration_utils import PretrainedConfig from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...processing_utils import Unpack @@ -330,7 +331,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -403,7 +404,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 688f368015e9..9d606297e829 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -258,11 +258,7 @@ def forward( attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, layer_past) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, past_kv, (attentions) + return attn_output, attn_weights class GPTNeoFlashAttention2(GPTNeoSelfAttention): @@ -364,11 +360,7 @@ def forward( attn_output = self.out_proj(attn_weights_reshaped) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, layer_past) - if output_attentions: - outputs += (attn_weights_reshaped,) - - return outputs + return attn_output, attn_weights_reshaped GPT_NEO_ATTENTION_CLASSES = { @@ -454,7 +446,7 @@ def forward( ): residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( + attn_output, attn_weights = self.attn( hidden_states, layer_past=layer_past, attention_mask=attention_mask, @@ -463,8 +455,7 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] + # residual connection hidden_states = attn_output + residual @@ -474,12 +465,7 @@ def forward( # residual connection hidden_states = residual + feed_forward_hidden_states - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - - return outputs # hidden_states, past_kv, attentions + return hidden_states, attn_weights @auto_docstring @@ -588,19 +574,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.wte(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() seq_length = inputs_embeds.shape[1] if cache_position is None: @@ -630,7 +609,6 @@ def forward( hidden_states = self.drop(hidden_states) output_shape = (-1, seq_length, hidden_states.size(-1)) - next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for i, block in enumerate(self.h): @@ -648,11 +626,8 @@ def forward( ) hidden_states = outputs[0] - if use_cache: - next_decoder_cache = outputs[1] - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) hidden_states = self.ln_f(hidden_states) @@ -661,18 +636,14 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: return tuple( - v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, ) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 35dc1963d86f..599e7ee76f1c 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -446,6 +446,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_in(input_ids) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + if use_cache and past_key_values is None: past_key_values = DynamicCache() diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 0ec7e9db6249..704b69aa5dcf 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -320,6 +320,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_in(input_ids) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + if use_cache and past_key_values is None: past_key_values = DynamicCache() diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index f1866b497778..30a2ce2fbc5f 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ 
b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -150,11 +150,7 @@ def forward( attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) attn_output = self.dense(attn_output) - outputs = (attn_output, layer_past) - if output_attentions: - outputs += (attn_weights,) - - return outputs, self.dense_bias + return attn_output, attn_weights, self.dense_bias @classmethod def _split_heads(cls, tensor, num_attention_heads, attn_head_size): @@ -357,7 +353,7 @@ def forward( ): residual = hidden_states ln_out = self.input_layernorm(hidden_states) - attention_layer_outputs, attn_bias = self.attention( + attn_output, attn_weights, attn_bias = self.attention( ln_out, attention_mask=attention_mask, layer_past=layer_past, @@ -368,8 +364,6 @@ def forward( cache_position=cache_position, position_embeddings=position_embeddings, ) - attn_output = attention_layer_outputs[0] # output_attn: a, present, (attentions) - outputs = attention_layer_outputs[1:] # attn_output = (atten_output + bias) + residual attn_output = bias_dropout_add( @@ -386,12 +380,7 @@ def forward( mlp_output, bias=None, residual=attn_output, prob=self.hidden_dropout, training=self.training ) - if use_cache: - outputs = (attn_output,) + outputs - else: - outputs = (attn_output,) + outputs[1:] - - return outputs # hidden_states, present, (attentions) + return attn_output, attn_weights @auto_docstring @@ -460,19 +449,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_in(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() seq_length = inputs_embeds.shape[1] if cache_position is None: @@ -497,7 +479,6 @@ def forward( # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) - next_decoder_cache = None all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for i, layer in enumerate(self.layers): @@ -516,26 +497,22 @@ def forward( position_embeddings=position_embeddings, ) hidden_states = outputs[0] - if use_cache is True: - next_decoder_cache = outputs[1] if output_attentions: - all_attentions = all_attentions + (outputs[2 if use_cache else 1],) + all_attentions = all_attentions + (outputs[1],) hidden_states = self.final_layer_norm(hidden_states) # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, ) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index a8faccb5d80c..0aa03b559b95 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -254,11 +254,7 @@ def forward( attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, layer_past) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) + return attn_output, attn_weights class GPTJFlashAttention2(GPTJAttention): @@ -402,12 +398,7 @@ def forward( ) attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, layer_past) - if output_attentions: - outputs += (attn_weights,) - - return outputs + return attn_output, attn_weights GPTJ_ATTENTION_CLASSES = { @@ -456,7 +447,7 @@ def forward( ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( + attn_outputs, attn_weights = self.attn( hidden_states=hidden_states, layer_past=layer_past, attention_mask=attention_mask, @@ -466,18 +457,10 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] - feed_forward_hidden_states = self.mlp(hidden_states) - hidden_states = attn_output + feed_forward_hidden_states + residual - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = 
(hidden_states,) + outputs[1:] + hidden_states = attn_outputs + feed_forward_hidden_states + residual - return outputs # hidden_states, present, (attentions) + return hidden_states, attn_weights @auto_docstring @@ -676,19 +659,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.wte(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() seq_length = inputs_embeds.shape[1] if cache_position is None: @@ -719,7 +695,6 @@ def forward( hidden_states = self.drop(hidden_states) output_shape = (-1, seq_length, hidden_states.size(-1)) - next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for i, block in enumerate(self.h): @@ -752,11 +727,8 @@ def forward( ) hidden_states = outputs[0] - if use_cache is True: - next_decoder_cache = outputs[1] - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: @@ -771,18 +743,14 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: return tuple( - v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, ) diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index 7a4ea8aa23d1..2b0efc519e66 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from torch import nn +from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel @@ -44,7 +45,7 @@ class GraniteSpeechCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). 
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -361,7 +362,7 @@ def forward( input_features_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index c78c97ade3b9..065dfca74b6f 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -464,7 +464,7 @@ def forward( attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights def eager_attention_forward( @@ -550,7 +550,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -576,9 +576,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - if output_router_logits: outputs += (router_logits,) @@ -683,14 +680,12 @@ def forward( inputs_embeds = inputs_embeds * self.embedding_multiplier - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" - "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -716,7 +711,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -736,9 +730,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -751,15 +742,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 4160dd82e7e4..8139103d210c 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -219,7 +219,7 @@ def forward( attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer @@ -1118,7 +1118,7 @@ def forward( # No attention weights for state space layers self_attn_weights = None else: - hidden_states, self_attn_weights, _ = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, past_key_value=past_key_value, @@ -1149,9 +1149,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (past_key_value,) - if output_router_logits: outputs += (router_logits,) @@ -1335,7 +1332,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None for decoder_layer in self.layers: # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) @@ -1357,9 +1353,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: if layer_outputs[1] is not None: # append attentions only of 
attention layers. Mamba layers return `None` as the attention weights @@ -1379,11 +1372,9 @@ def forward( if past_key_values and not past_key_values.has_previous_state: past_key_values.has_previous_state = True - next_cache = next_decoder_cache if use_cache else None - return MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, @@ -1786,7 +1777,7 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] - else: + elif use_cache: past_key_values = HybridMambaAttentionDynamicCache( self.config, input_ids.shape[0], self.dtype, device=self.device ) diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index fb49cf29b37a..f894d1a8f8a7 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -124,7 +124,7 @@ def forward( # No attention weights for state space layers self_attn_weights = None else: - hidden_states, self_attn_weights, _ = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, past_key_value=past_key_value, @@ -155,9 +155,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (past_key_value,) - if output_router_logits: outputs += (router_logits,) @@ -260,7 +257,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None for decoder_layer in self.layers: # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) @@ -282,9 +278,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: if layer_outputs[1] is not None: # append attentions only of attention layers. 
Mamba layers return `None` as the attention weights @@ -304,11 +297,9 @@ def forward( if past_key_values and not past_key_values.has_previous_state: past_key_values.has_previous_state = True - next_cache = next_decoder_cache if use_cache else None - return MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, @@ -363,7 +354,7 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] - else: + elif use_cache: past_key_values = HybridMambaAttentionDynamicCache( self.config, input_ids.shape[0], self.dtype, device=self.device ) diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 5e10ed2552f7..475e97dc84d7 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -403,7 +403,7 @@ def forward( attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer): @@ -463,7 +463,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -494,9 +494,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - if output_router_logits: outputs += (router_logits,) @@ -635,14 +632,12 @@ def forward( inputs_embeds = inputs_embeds * self.embedding_multiplier - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" - "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -668,7 +663,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -688,9 +682,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -703,15 +694,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, diff --git a/src/transformers/models/granitemoeshared/modular_granitemoeshared.py b/src/transformers/models/granitemoeshared/modular_granitemoeshared.py index d8e644b7b982..ee8fc48d38e2 100644 --- a/src/transformers/models/granitemoeshared/modular_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modular_granitemoeshared.py @@ -107,7 +107,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -140,9 +140,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - if output_router_logits: outputs += (router_logits,) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 706584ccf3ba..dd05709e0805 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -64,7 +64,7 @@ class IdeficsBaseModelOutputWithPast(ModelOutput): If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, @@ -99,7 +99,7 @@ class IdeficsCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -647,7 +647,7 @@ def forward( if output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # this was adapted from LlamaDecoderLayer @@ -701,7 +701,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -726,9 +726,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -847,7 +844,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.cross_attn( + hidden_states, self_attn_weights = self.cross_attn( hidden_states=hidden_states, key_value_states=image_hidden_states, attention_mask=image_attention_mask, @@ -871,9 +868,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -1011,7 +1005,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, image_encoder_embeddings: Optional[torch.FloatTensor] = None, @@ -1055,19 +1049,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() batch_size, seq_length, _ = inputs_embeds.shape past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1160,7 +1147,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: @@ -1194,9 +1180,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1206,14 +1189,11 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() image_hidden_states = image_hidden_states.view(batch_size, num_images, image_seq_len, image_hidden_size) return IdeficsBaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, image_hidden_states=image_hidden_states, @@ -1409,7 +1389,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, image_encoder_embeddings: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/idefics/processing_idefics.py b/src/transformers/models/idefics/processing_idefics.py index 2ff60ac4e326..3c59105c2360 100644 --- a/src/transformers/models/idefics/processing_idefics.py +++ b/src/transformers/models/idefics/processing_idefics.py @@ -27,7 +27,6 @@ ProcessorMixin, TextKwargs, Unpack, - _validate_images_text_input_order, ) from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import is_tf_available, is_torch_available @@ -340,8 +339,6 @@ def __call__( """ if images is None and text is None: raise ValueError("You need to specify either `text` or `images` and `text`.") - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) if images is None: # assuming the user wants to use the old behavior with prompts as the only argument diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 50d1489c2d77..06fccf9614c1 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -50,7 +50,7 @@ class Idefics2BaseModelOutputWithPast(ModelOutput): Sequence of hidden-states at the output of the last layer of the model. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, @@ -84,7 +84,7 @@ class Idefics2CausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see @@ -639,7 +639,7 @@ def forward( context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -728,7 +728,7 @@ def forward( context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, **kwargs, @@ -1012,7 +1012,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, @@ -1051,7 +1051,11 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - if use_cache and not isinstance(past_key_values, Cache): + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: past_key_values = DynamicCache() if inputs_embeds is None: @@ -1155,7 +1159,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, diff --git 
a/src/transformers/models/idefics2/processing_idefics2.py b/src/transformers/models/idefics2/processing_idefics2.py index 9998a4b0108a..fedc24eb0e61 100644 --- a/src/transformers/models/idefics2/processing_idefics2.py +++ b/src/transformers/models/idefics2/processing_idefics2.py @@ -26,7 +26,6 @@ ProcessingKwargs, ProcessorMixin, Unpack, - _validate_images_text_input_order, ) from ...tokenization_utils_base import AddedToken, TextInput from ...utils import logging @@ -181,8 +180,6 @@ def __call__( """ if text is None and images is None: raise ValueError("You must provide either `text` or `images`.") - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) output_kwargs = self._merge_kwargs( Idefics2ProcessorKwargs, diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 14381c6d68b8..c018963943da 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -22,7 +22,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import DynamicCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -50,7 +50,7 @@ class Idefics3BaseModelOutputWithPast(ModelOutput): Sequence of hidden-states at the output of the last layer of the model. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, @@ -83,7 +83,7 @@ class Idefics3CausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see @@ -739,7 +739,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, @@ -889,7 +889,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index d076a193162c..d0372118b958 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -16,7 +16,6 @@ import math import os -import warnings from typing import Any, Optional, Union import torch @@ -597,19 +596,6 @@ def forward( >>> last_hidden_states = outputs.last_hidden_state ```""" - if "pixel_values" in kwargs: - warnings.warn( - "The `pixel_values` argument is deprecated and will be removed in v4.47, use `input_ids` instead.", - FutureWarning, - ) - - if input_ids is not None: - raise ValueError( - "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`." - ) - - input_ids = kwargs.pop("pixel_values") - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -865,19 +851,6 @@ def forward( ... ax.imshow(img) ```""" - if "pixel_values" in kwargs: - warnings.warn( - "The `pixel_values` argument is deprecated and will be removed in v4.47, use `input_ids` instead.", - FutureWarning, - ) - - if input_ids is not None: - raise ValueError( - "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`." - ) - - input_ids = kwargs.pop("pixel_values") - return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( @@ -1002,19 +975,6 @@ def forward( >>> logits = outputs.logits ```""" - if "pixel_values" in kwargs: - warnings.warn( - "The `pixel_values` argument is deprecated and will be removed in v4.47, use `input_ids` instead.", - FutureWarning, - ) - - if input_ids is not None: - raise ValueError( - "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`." 
- ) - - input_ids = kwargs.pop("pixel_values") - return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index afcc4c9bf149..67500d82eff0 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -23,6 +23,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -693,7 +694,7 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: self_outputs = self.attention( @@ -1069,7 +1070,7 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -1084,7 +1085,7 @@ def forward( the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of: + past_key_values (`Cache` of length `config.n_layers` with each tuple having 4 tensors of: shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
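The ImageGPT hunks above drop the deprecated `pixel_values` keyword, so the model only accepts the token ids produced by its image processor. A minimal sketch of the expected call, assuming the `openai/imagegpt-small` checkpoint purely as an example:

```py
# Illustrative sketch, not part of the patch: ImageGPT consumes colour-cluster token ids,
# not raw pixel values; the deprecated `pixel_values` alias removed above no longer exists.
from PIL import Image
from transformers import AutoImageProcessor, ImageGPTForCausalImageModeling

image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")  # example checkpoint
model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-small")

image = Image.new("RGB", (32, 32))                             # dummy image, stands in for real data
encoding = image_processor(images=image, return_tensors="pt")
outputs = model(**encoding)                                    # the processor output feeds `input_ids` directly
```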
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index f43473370a20..cdd9824690da 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -27,6 +27,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -554,7 +555,7 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: self_outputs = self.attention( @@ -1030,7 +1031,7 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -1045,7 +1046,7 @@ def forward( the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of: + past_key_values (`Cache` of length `config.n_layers` with each tuple having 4 tensors of: shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key @@ -1603,7 +1604,7 @@ def forward( logger.warning_once( "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + "Using processors without these attributes in the config is deprecated and will throw an error in v4.54." ) inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) attention_mask = torch.cat( @@ -1732,7 +1733,7 @@ def generate( logger.warning_once( "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + "Using processors without these attributes in the config is deprecated and will throw an error in v4.54." 
) inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) attention_mask = torch.cat( diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py index 02b6bd5d0689..e6f32896084e 100644 --- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py @@ -492,7 +492,7 @@ def forward( logger.warning_once( "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + "Using processors without these attributes in the config is deprecated and will throw an error in v4.54." ) inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) attention_mask = torch.cat( @@ -621,7 +621,7 @@ def generate( logger.warning_once( "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + "Using processors without these attributes in the config is deprecated and will throw an error in v4.54." ) inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) attention_mask = torch.cat( diff --git a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py index 044b78996c8b..e2174f248aaf 100644 --- a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py @@ -147,7 +147,7 @@ def __call__( logger.warning_once( "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " "Please follow instruction here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + "Using processors without these attributes in the config is deprecated and will throw an error in v4.54." 
) # cast to desired return tensors type after concatenating diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 9d7de418f920..68074964c415 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -28,6 +28,7 @@ import torch.nn as nn from ...activations import ACT2FN +from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -567,7 +568,7 @@ def forward(self, image_features): ) class InternVLModelOutputWithPast(BaseModelOutputWithPast): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -670,7 +671,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, @@ -796,7 +797,7 @@ class InternVLCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -888,7 +889,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index 84edfe9e2ed8..1c29907bb16e 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -23,6 +23,7 @@ import torch.utils.checkpoint from ...activations import ACT2FN +from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling @@ -600,7 +601,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index ed8ea1787443..9c31d8530d2e 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -118,7 +118,7 @@ class JanusBaseModelOutputWithPast(ModelOutput): If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, @@ -153,7 +153,7 @@ class JanusCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index acfe7dbd48ae..f9bd85898ce3 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -548,7 +548,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value, router_logits + return attn_output, attn_weights, router_logits class JetMoeSdpaAttention(JetMoeAttention): @@ -636,7 +636,7 @@ def forward( attn_output = self.experts.reduce(attn_output, topo_info) attn_output = attn_output.view(bsz, q_len, -1) - return attn_output, None, past_key_value, router_logits + return attn_output, None, router_logits class JetMoeFlashAttention2(JetMoeAttention): @@ -756,7 +756,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value, router_logits + return attn_output, attn_weights, router_logits JETMOE_ATTENTION_CLASSES = { @@ -794,7 +794,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]: # Self Attention - attn_output, self_attn_weights, present_key_value, attn_router_logits = self.self_attention( + attn_output, self_attn_weights, attn_router_logits = self.self_attention( hidden_states=self.input_layernorm(hidden_states), attention_mask=attention_mask, position_ids=position_ids, @@ -813,9 +813,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - if output_router_logits: outputs += attn_router_logits, mlp_router_logits @@ -924,19 +921,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -965,7 +955,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -983,9 +972,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -998,13 +984,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - return MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, @@ -1181,7 +1163,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py index 82fe496ea257..a067580cf08d 100644 --- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -490,7 +490,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->KyutaiSpeechToText @@ -612,7 +612,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->KyutaiSpeechToText @@ -707,7 +707,7 @@ def forward( attn_output = self.o_proj(attn_output, cache_position) # Ignore copy - return attn_output, None, past_key_value + return attn_output, None KYUTAI_SPEECH_TO_TEXT_ATTENTION_CLASSES = { @@ -769,7 +769,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -794,9 +794,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += 
(present_key_value,) - return outputs @@ -855,13 +852,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - return_legacy_cache = False # noqa: F841 - if ( - use_cache and not isinstance(past_key_values, Cache) and not self.training - ): # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = True # noqa: F841 - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( @@ -880,20 +870,16 @@ def forward( # embed positions hidden_states = inputs_embeds - if ( - use_cache and not isinstance(past_key_values, Cache) and not self.training - ): # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " - "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -911,9 +897,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -923,15 +906,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 2e6008e46453..653006d631dc 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -728,7 +728,7 @@ class Llama4CausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
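The JetMoe and KyutaiSpeechToText hunks above stop threading `present_key_value` through per-layer outputs; the `Cache` object handed to the model is updated in place, and that object is where downstream code now reads the cached length from. A short self-contained sketch of that contract, using made-up tensor shapes:

```py
# Illustrative sketch, not part of the patch: the cache is mutated in place by each
# attention layer, so its state is read from the Cache object, not from layer outputs.
import torch
from transformers import DynamicCache

cache = DynamicCache()
key = torch.randn(1, 8, 4, 64)    # (batch, num_heads, seq_len, head_dim) -- made-up shapes
value = torch.randn(1, 8, 4, 64)

cache.update(key, value, layer_idx=0)   # what each attention layer does internally
print(cache.get_seq_length())           # 4 -- no `present_key_value` tuple needed
```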
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -1280,7 +1280,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 62d80fa5f08f..b7b67a8a776a 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -22,6 +22,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput @@ -43,7 +44,7 @@ ) class LlavaModelOutputWithPast(BaseModelOutputWithPast): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -69,7 +70,7 @@ class LlavaCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -240,7 +241,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, @@ -395,7 +396,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 50d67d9cbc43..3d1cfc61e6ff 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -27,7 +27,6 @@ ProcessingKwargs, ProcessorMixin, Unpack, - _validate_images_text_input_order, ) from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import logging @@ -133,9 +132,6 @@ def __call__( if images is None and text is None: raise ValueError("You have to specify at least one of `images` or `text`.") - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) - output_kwargs = self._merge_kwargs( LlavaProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 5d74255fe65b..1c33fb8a047a 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -23,6 +23,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -152,7 +153,7 @@ def unpad_image(tensor, original_size): ) class LlavaNextModelOutputWithPast(BaseModelOutputWithPast): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -178,7 +179,7 @@ class LlavaNextCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). 
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -435,7 +436,7 @@ def forward( image_sizes: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, @@ -605,7 +606,7 @@ def forward( image_sizes: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py index 8de438cb8a0c..05efc60fcd2e 100644 --- a/src/transformers/models/llava_next/processing_llava_next.py +++ b/src/transformers/models/llava_next/processing_llava_next.py @@ -28,7 +28,6 @@ ProcessingKwargs, ProcessorMixin, Unpack, - _validate_images_text_input_order, ) from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import logging @@ -136,8 +135,6 @@ def __call__( """ if images is None and text is None: raise ValueError("You have to specify at least images or text.") - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) output_kwargs = self._merge_kwargs( LlavaNextProcessorKwargs, diff --git a/src/transformers/models/llava_next_video/configuration_llava_next_video.py b/src/transformers/models/llava_next_video/configuration_llava_next_video.py index 94d2cac4c89e..32fa77f97c06 100644 --- a/src/transformers/models/llava_next_video/configuration_llava_next_video.py +++ b/src/transformers/models/llava_next_video/configuration_llava_next_video.py @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
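The processor hunks in this diff (Idefics, Idefics2, LLaVA, LLaVA-NeXT and the video variants) remove `_validate_images_text_input_order`, which used to silently swap reversed positional `images`/`text` arguments. Passing both as keywords keeps the call unambiguous; a minimal sketch, assuming the `llava-hf/llava-1.5-7b-hf` checkpoint only as an example:

```py
# Illustrative sketch, not part of the patch: with the argument-order shim gone, pass
# `images` and `text` as keyword arguments so they can never be interpreted in the wrong order.
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # example checkpoint
image = Image.new("RGB", (336, 336))                                   # dummy image, stands in for real data
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt")
```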
- from ...configuration_utils import PretrainedConfig from ..auto import CONFIG_MAPPING, AutoConfig diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 37e11e262e8a..cf6435e1c7c9 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -28,6 +28,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -50,7 +51,7 @@ ) class LlavaNextVideoModelOutputWithPast(BaseModelOutputWithPast): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -81,7 +82,7 @@ class LlavaNextVideoCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -486,7 +487,7 @@ def forward( image_sizes: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, @@ -743,7 +744,7 @@ def forward( image_sizes: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index e8d335ce5e85..713b7f979bb2 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -29,6 +29,7 @@ image_size_to_num_patches, ) +from ...cache_utils import Cache from ...configuration_utils import PretrainedConfig from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...processing_utils import Unpack @@ -183,7 +184,7 @@ def __init__( class 
LlavaNextVideoModelOutputWithPast(LlavaNextModelOutputWithPast): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -206,7 +207,7 @@ class LlavaNextVideoCausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -406,7 +407,7 @@ def forward( image_sizes: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, @@ -545,7 +546,7 @@ def forward( image_sizes: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py index c19936d00f43..dd039ff02483 100644 --- a/src/transformers/models/llava_next_video/processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py @@ -23,7 +23,7 @@ from ...feature_extraction_utils import BatchFeature from ...image_processing_utils import select_best_resolution from ...image_utils import ImageInput, get_image_size, to_numpy_array -from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import logging from ...video_utils import VideoInput @@ -157,8 +157,6 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 
""" - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) output_kwargs = self._merge_kwargs( LlavaNextVideoProcessorKwargs, diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index d07e7c28b6d8..6552f840a7b6 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -28,6 +28,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -56,7 +57,7 @@ ) class LlavaOnevisionModelOutputWithPast(BaseModelOutputWithPast): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -87,7 +88,7 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -502,7 +503,7 @@ def forward( image_sizes_videos: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, @@ -785,7 +786,7 @@ def forward( image_sizes_videos: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, diff --git a/src/transformers/models/llava_onevision/modular_llava_onevision.py b/src/transformers/models/llava_onevision/modular_llava_onevision.py index af9485b31537..196cbf8e103e 100644 --- a/src/transformers/models/llava_onevision/modular_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modular_llava_onevision.py @@ -32,6 +32,7 @@ unpad_image, ) +from ...cache_utils import Cache from ...image_processing_utils import BatchFeature from ...image_processing_utils_fast import DefaultFastImageProcessorKwargs, group_images_by_shape, reorder_images from 
...image_utils import ( @@ -485,7 +486,7 @@ def forward( image_sizes_videos: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, @@ -635,7 +636,7 @@ def forward( image_sizes_videos: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 081869ec8fc5..537b816b79f8 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -467,13 +467,16 @@ def forward( query_states = self.q(hidden_states) query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` + if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: @@ -542,7 +545,7 @@ def forward( attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - outputs = (attn_output, past_key_value, position_bias) + outputs = (attn_output, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -733,8 +736,10 @@ def unshape(states): attn_output = attn_output[:, :seq_length, :] attn_output = self.o(attn_output) - present_key_value_state = None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = ( + attn_output, + position_bias, + ) if output_attentions: outputs = outputs + (attn_weights,) @@ -996,8 +1001,7 @@ def unshape(states): attn_output = attn_output[:, :seq_length, :] attn_output = self.o(attn_output) - present_key_value_state = None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -1194,8 +1198,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, past_key_value = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 inference - check 
https://github.com/huggingface/transformers/pull/19229/ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): @@ -1216,7 +1220,7 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, past_key_value = cross_attention_outputs[:2] + hidden_states = cross_attention_outputs[0] # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): @@ -1224,7 +1228,7 @@ def forward( hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] + attention_outputs = attention_outputs + cross_attention_outputs[1:] # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states) @@ -1234,14 +1238,9 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if use_cache: - outputs = outputs + (past_key_value,) + attention_outputs - else: - outputs = outputs + attention_outputs - - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return ( + (hidden_states,) + attention_outputs + ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) @auto_docstring @@ -1426,23 +1425,12 @@ def forward( batch_size, seq_length = input_shape - # initialize past_key_values - return_legacy_cache = False - return_self_attention_cache = False - if self.is_decoder and (use_cache or past_key_values is not None): - if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): - return_self_attention_cache = True - past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) - elif not isinstance(past_key_values, EncoderDecoderCache): - return_legacy_cache = True - logger.warning_once( - "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " - "You should pass an instance of `EncoderDecoderCache` instead, e.g. " - "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
- ) - past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - elif past_key_values is None: - past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if self.is_decoder: + if use_cache and past_key_values is None: + if self.config.is_encoder_decoder: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + else: + past_key_values = DynamicCache() elif not self.is_decoder: # do not pass cache object down the line for encoder stack # it messes indexing later in decoder-stack because cache object is modified in-place @@ -1464,7 +1452,9 @@ def forward( attention_mask, inputs_embeds, cache_position, - past_key_values.self_attention_cache if past_key_values is not None else None, + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values, output_attentions, ) # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used @@ -1519,23 +1509,21 @@ def forward( ) # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - if use_cache is False: - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - hidden_states, next_decoder_cache = layer_outputs[:2] + hidden_states = layer_outputs[0] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[2] + position_bias = layer_outputs[1] if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) + all_attentions = all_attentions + (layer_outputs[2],) if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -1544,18 +1532,12 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_self_attention_cache: - next_cache = past_key_values.self_attention_cache - if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_attentions, all_cross_attentions, @@ -1564,7 +1546,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, @@ -1719,12 +1701,12 @@ def __init__(self, config: LongT5Config): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = LongT5Stack(encoder_config, self.shared) decoder_config = 
copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers self.decoder = LongT5Stack(decoder_config, self.shared) @@ -1769,7 +1751,7 @@ def forward( decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.Tensor] = None, decoder_inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, @@ -1920,12 +1902,12 @@ def __init__(self, config: LongT5Config): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = LongT5Stack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers self.decoder = LongT5Stack(decoder_config, self.shared) @@ -1970,7 +1952,7 @@ def forward( decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -2167,7 +2149,7 @@ def __init__(self, config: LongT5Config): encoder_config = copy.deepcopy(config) encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = LongT5Stack(encoder_config, self.shared) # Initialize weights and apply final processing diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 8b8ce850339a..df2f6cdaa96c 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -956,7 +956,7 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1005,7 +1005,7 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
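The hunks above switch `past_key_values` from the legacy tuple-of-tuples to a `Cache` object and let the decoder stack build an `EncoderDecoderCache(DynamicCache(), DynamicCache())` itself when `use_cache` is set. A minimal sketch of what a caller can now pass explicitly; the `google/long-t5-local-base` checkpoint and the exact `generate` arguments are illustrative assumptions, not part of this patch:

```py
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DynamicCache, EncoderDecoderCache

tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-local-base")

inputs = tokenizer("summarize: the decoder now receives a Cache object.", return_tensors="pt")

# Seq2seq decoders pair a self-attention cache with a cross-attention cache.
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
out = model.generate(**inputs, past_key_values=cache, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```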
@@ -1234,7 +1234,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -1362,7 +1362,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 0d7fa3a7b9cc..c4393a3948bd 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -900,7 +900,7 @@ def forward( encoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -949,7 +949,7 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
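Because the returned `past_key_values` is also a `Cache`, downstream code that used to index the legacy tuples (for example `past_key_values[0][0].shape[2]`) can query the cache object instead. A hedged sketch, assuming a version that includes these changes and the public `Helsinki-NLP/opus-mt-en-de` Marian checkpoint:

```py
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")

inputs = tokenizer("The cache is an object now.", return_tensors="pt")
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

with torch.no_grad():
    out = model(**inputs, decoder_input_ids=decoder_input_ids, use_cache=True)

cache = out.past_key_values      # a Cache instance, no longer a tuple of tuples
print(type(cache).__name__)
print(cache.get_seq_length())    # number of cached decoder tokens (expected 1 here)
```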
@@ -1222,7 +1222,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Union[tuple[torch.Tensor], BaseModelOutput]] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -1479,7 +1479,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Union[tuple[torch.Tensor], BaseModelOutput]] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -1660,7 +1660,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 361e1fabda9d..861eeaf68ec0 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -944,7 +944,7 @@ def forward( encoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -993,7 +993,7 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
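For callers that still hold the old tuple-of-tuples layout (for example from serialized state), the warnings removed above pointed at the `from_legacy_cache` helpers as the migration path. A self-contained sketch with toy tensors, assuming those conversion helpers are still available in the installed version:

```py
import torch
from transformers import DynamicCache

# Toy legacy layout: one layer, (key, value) of shape (batch, heads, seq_len, head_dim).
key = torch.zeros(1, 2, 3, 4)
value = torch.zeros(1, 2, 3, 4)
legacy_past = ((key, value),)

cache = DynamicCache.from_legacy_cache(legacy_past)   # wrap before calling the model
print(cache.get_seq_length())                         # 3 cached tokens
print(len(cache.to_legacy_cache()))                   # 1 layer, convertible back if needed
```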
@@ -1218,7 +1218,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -1374,7 +1374,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -1863,7 +1863,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index af303825d2ef..23208f3006ed 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -697,7 +697,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi @@ -814,7 +814,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi @@ -904,7 +904,7 @@ def forward( attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None MIMI_ATTENTION_CLASSES = { @@ -962,7 +962,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -985,9 +985,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -1094,16 +1091,12 @@ def forward( ) use_cache = False - if use_cache and not isinstance(past_key_values, Cache): - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1126,7 +1119,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -1144,9 +1136,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1154,14 +1143,14 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 8379e54e9520..2157315238ae 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -872,7 +872,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 9c9f10fde59a..d319e973ac00 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -26,6 +26,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -134,7 +135,7 @@ class Mistral3CausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -161,7 +162,7 @@ class Mistral3CausalLMOutputWithPast(ModelOutput): ) class Mistral3ModelOutputWithPast(BaseModelOutputWithPast): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -286,7 +287,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, use_cache: Optional[bool] = None, @@ -434,7 +435,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/mistral3/modular_mistral3.py b/src/transformers/models/mistral3/modular_mistral3.py index 9063e3a52e95..5b5f27579cc9 100644 --- a/src/transformers/models/mistral3/modular_mistral3.py +++ b/src/transformers/models/mistral3/modular_mistral3.py @@ -19,6 +19,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...processing_utils import Unpack from ...utils import is_torchdynamo_compiling, logging @@ -181,7 +182,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, use_cache: Optional[bool] = None, @@ -277,7 +278,7 @@ def forward( pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index c9e953b8c936..32f8f9a84b4b 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -627,7 +627,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: 
Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index 04b95de41900..de02a2a833bc 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -364,7 +364,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index e2d70b6d9b1d..0f68c2d03d7e 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -532,7 +532,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -647,7 +647,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText @@ -724,7 +724,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -748,9 +748,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -787,7 +784,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, attn_weights, past_key_value = self.cross_attn( + hidden_states, attn_weights = self.cross_attn( hidden_states=hidden_states, attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, @@ -810,9 +807,6 @@ def forward( if output_attentions: outputs += (attn_weights,) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1386,7 +1380,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: @@ -1436,9 +1429,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1448,13 +1438,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, 
all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -1657,7 +1647,7 @@ def forward( cross_attention_mask: Optional[torch.Tensor] = None, cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 1835f55aaec4..3b0da20ad203 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -129,6 +129,7 @@ class ModernBertConfig(PretrainedConfig): ```""" model_type = "modernbert" + attribute_map = {"rope_theta": "global_rope_theta"} keys_to_ignore_at_inference = ["past_key_values"] def __init__( diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 19a8e68d0e65..f4fd7cf37b41 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -241,7 +241,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class ModernBertRotaryEmbedding(nn.Module): - def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Optional[torch.device] = None): + def __init__(self, config: ModernBertConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): @@ -253,7 +253,8 @@ def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Opti self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(None, device, dim=dim, base=base) + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -461,11 +462,9 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): else: self.local_attention = (-1, -1) - rope_theta = config.global_rope_theta max_position_embeddings = config.max_position_embeddings if self.local_attention != (-1, -1): - if config.local_rope_theta is not None: - rope_theta = config.local_rope_theta + rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta max_position_embeddings = config.local_attention if config._attn_implementation == "flash_attention_2": @@ -473,7 +472,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) else: - self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta) + self.rotary_emb = ModernBertRotaryEmbedding(config=config) self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() diff --git a/src/transformers/models/modernbert/modular_modernbert.py 
b/src/transformers/models/modernbert/modular_modernbert.py index 8cd853aa5cc3..94e45fcc5a6d 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -156,6 +156,7 @@ class ModernBertConfig(PretrainedConfig): ```""" model_type = "modernbert" + attribute_map = {"rope_theta": "global_rope_theta"} keys_to_ignore_at_inference = ["past_key_values"] def __init__( @@ -504,9 +505,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class ModernBertRotaryEmbedding(GemmaRotaryEmbedding): - def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Optional[torch.device] = None): - super().__init__(self, config=config, device=device) - inv_freq, self.attention_scaling = self.rope_init_fn(None, device, dim=dim, base=base) + pass def eager_attention_forward( @@ -663,11 +662,9 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): else: self.local_attention = (-1, -1) - rope_theta = config.global_rope_theta max_position_embeddings = config.max_position_embeddings if self.local_attention != (-1, -1): - if config.local_rope_theta is not None: - rope_theta = config.local_rope_theta + rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta max_position_embeddings = config.local_attention if config._attn_implementation == "flash_attention_2": @@ -675,7 +672,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) else: - self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta) + self.rotary_emb = ModernBertRotaryEmbedding(config=config) self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index ea888df2aab8..9b55fa8c961b 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -414,7 +414,7 @@ def forward( past_key_value: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, encoder_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index f2ac060dc22f..7180d35e8e6e 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -449,7 +449,7 @@ def forward( past_key_value: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, encoder_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ) -> 
tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index a0cf3412900a..45bea58cdf52 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -32,7 +32,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput, Seq2SeqLMOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from ..auto.modeling_auto import AutoModel from .configuration_moshi import MoshiConfig, MoshiDepthConfig @@ -113,7 +113,7 @@ class MoshiCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -141,7 +141,7 @@ class MoshiConditionalGenerationOutputWithPast(ModelOutput): Text language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the text language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -492,7 +492,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi @@ -614,7 +614,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi @@ -709,7 +709,7 @@ def forward( attn_output = self.o_proj(attn_output, cache_position) # Ignore copy - return attn_output, None, past_key_value + return attn_output, None MOSHI_ATTENTION_CLASSES = { @@ -771,7 +771,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -796,9 +796,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -870,7 +867,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, last_hidden_state: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1002,7 +999,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None hidden_states = inputs_embeds for decoder_layer in self.layers: if output_hidden_states: @@ -1020,9 +1016,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1030,7 +1023,6 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None logits = self.lm_heads(hidden_states, cache_position) loss = None @@ -1045,13 +1037,15 @@ def forward( loss = loss_fct(logits.reshape(-1, self.config.audio_vocab_size), labels) if not return_dict: - return tuple(v for v in [loss, logits, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [loss, logits, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return CausalLMOutputWithPast( loss=loss, logits=logits, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -1269,13 +1263,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - return_legacy_cache = False # noqa: F841 - if ( - use_cache and not
isinstance(past_key_values, Cache) and not self.training - ): # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = True # noqa: F841 - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( @@ -1294,20 +1281,16 @@ def forward( # embed positions hidden_states = inputs_embeds - if ( - use_cache and not isinstance(past_key_values, Cache) and not self.training - ): # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " - "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -1325,9 +1308,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1337,15 +1317,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -1604,10 +1582,6 @@ def forward( ) hidden_states = outputs[0] - if labels is None and not is_torchdynamo_compiling(): - logger.warning_once( - "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" - ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1694,7 +1668,7 @@ def forward( user_audio_codes: Optional[torch.Tensor] = None, moshi_input_values: Optional[torch.FloatTensor] = None, moshi_audio_codes: Optional[torch.Tensor] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, text_labels: Optional[torch.LongTensor] = None, audio_labels: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 5584b2ee8255..0341c3afcfde 100644 --- 
a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -365,13 +365,16 @@ def forward( query_states = self.q(hidden_states) query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + # Check if an encoder-decoder model is being used. Otherwise we'll get a plain `DynamicCache` + if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: @@ -440,7 +443,7 @@ def forward( attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - outputs = (attn_output, past_key_value, position_bias) + outputs = (attn_output, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -563,8 +566,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, past_key_value = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -588,7 +591,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states, past_key_value = cross_attention_outputs[:2] + hidden_states = cross_attention_outputs[0] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -600,7 +603,7 @@ def forward( hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] + attention_outputs = attention_outputs + cross_attention_outputs[1:] # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states) @@ -616,12 +619,9 @@ def forward( outputs = (hidden_states,) - if use_cache: - outputs = outputs + (past_key_value,) + attention_outputs - else: - outputs = outputs + attention_outputs - - return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return ( + outputs + attention_outputs + ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) def load_tf_weights_in_mt5(model, config, tf_checkpoint_path): @@ -993,23 +993,12 @@ def forward( if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - # initialize past_key_values - return_legacy_cache = False - return_self_attention_cache = False - if self.is_decoder and (use_cache or past_key_values is not None): - if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): - return_self_attention_cache = True - past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) - elif
not isinstance(past_key_values, EncoderDecoderCache): - return_legacy_cache = True - logger.warning_once( - "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " - "You should pass an instance of `EncoderDecoderCache` instead, e.g. " - "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." - ) - past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - elif past_key_values is None: - past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if self.is_decoder: + if use_cache and past_key_values is None: + if self.config.is_encoder_decoder: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + else: + past_key_values = DynamicCache() elif not self.is_decoder: # do not pass cache object down the line for encoder stack # it messes indexing later in decoder-stack because cache object is modified in-place @@ -1031,7 +1020,9 @@ def forward( attention_mask, inputs_embeds, cache_position, - past_key_values.self_attention_cache if past_key_values is not None else None, + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values, output_attentions, ) elif attention_mask is not None: @@ -1105,24 +1096,19 @@ def forward( cache_position=cache_position, ) - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - if use_cache is False: - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - - hidden_states, next_decoder_cache = layer_outputs[:2] + hidden_states = layer_outputs[0] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[2] + position_bias = layer_outputs[1] if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) + all_attentions = all_attentions + (layer_outputs[2],) if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: @@ -1137,18 +1123,12 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_self_attention_cache: - next_cache = past_key_values.self_attention_cache - if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_attentions, all_cross_attentions, @@ -1157,7 +1137,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, @@ -1330,12 +1310,12 @@ def __init__(self, config: MT5Config): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False 
encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = MT5Stack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers self.decoder = MT5Stack(decoder_config, self.shared) @@ -1420,7 +1400,7 @@ def forward( decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.Tensor] = None, decoder_inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, @@ -1599,12 +1579,12 @@ def __init__(self, config: MT5Config): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = MT5Stack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers self.decoder = MT5Stack(decoder_config, self.shared) @@ -1692,7 +1672,7 @@ def forward( decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -2311,12 +2291,12 @@ def __init__(self, config: MT5Config): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = MT5Stack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers self.decoder = MT5Stack(decoder_config, self.shared) diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index f84096445f45..30b3faa14428 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -275,7 +275,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron @@ -396,7 +396,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron @@ -489,7 +489,7 @@ def forward( attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return 
attn_output, None NEMOTRON_ATTENTION_CLASSES = { @@ -552,7 +552,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -576,9 +576,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -701,7 +698,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -720,9 +716,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -732,11 +725,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 0408ceb62e5a..0fa4d1961abe 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -346,7 +346,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class OlmoeFlashAttention2(OlmoeAttention): @@ -459,7 +459,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class OlmoeSdpaAttention(OlmoeAttention): @@ -554,7 +554,7 @@ def forward( attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None OLMOE_ATTENTION_CLASSES = { @@ -666,7 +666,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -690,9 +690,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - if output_router_logits: outputs += (router_logits,) @@ -788,19 +785,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -824,7 +814,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -844,9 +833,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -859,15 +845,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, @@ -1034,7 +1018,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 33caa5a127b0..de52e7253400 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -198,7 +198,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class OPTDecoderLayer(GradientCheckpointingLayer): @@ -256,7 +256,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, position_ids=position_ids, @@ -299,9 +299,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -517,7 +514,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = 
None, @@ -550,7 +547,7 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of @@ -606,16 +603,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - if past_key_values is None: - logger.warning_once( - "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. " - "You should pass an instance of `DynamicCache` instead, e.g. " - "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: @@ -649,7 +638,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None # check if head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask], ["head_mask"]): @@ -684,9 +672,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -700,13 +685,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) diff --git a/src/transformers/models/owlv2/processing_owlv2.py b/src/transformers/models/owlv2/processing_owlv2.py index eddfcb3f7dee..de8f23d3d6b8 100644 --- a/src/transformers/models/owlv2/processing_owlv2.py +++ b/src/transformers/models/owlv2/processing_owlv2.py @@ -28,7 +28,6 @@ ProcessingKwargs, ProcessorMixin, Unpack, - _validate_images_text_input_order, ) from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available @@ -71,8 +70,6 @@ class Owlv2Processor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "Owlv2ImageProcessor" tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast") - # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. 
- optional_call_args = ["query_images"] def __init__(self, image_processor, tokenizer, **kwargs): super().__init__(image_processor, tokenizer) @@ -82,12 +79,6 @@ def __call__( self, images: Optional[ImageInput] = None, text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, - # The following is to capture `query_images` argument that may be passed as a positional argument. - # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details, - # or this conversation for more context: https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116 - # This behavior is only needed for backward compatibility and will be removed in future versions. - # - *args, audio=None, videos=None, **kwargs: Unpack[Owlv2ProcessorKwargs], @@ -132,7 +123,6 @@ def __call__( Owlv2ProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, - **self.prepare_and_validate_optional_call_args(*args), ) query_images = output_kwargs["images_kwargs"].pop("query_images", None) return_tensors = output_kwargs["common_kwargs"]["return_tensors"] @@ -141,8 +131,6 @@ def __call__( raise ValueError( "You have to specify at least one text or query image or image. All three cannot be none." ) - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) data = {} if text is not None: diff --git a/src/transformers/models/owlvit/processing_owlvit.py b/src/transformers/models/owlvit/processing_owlvit.py index 089f1eb26eb8..375402bb0a54 100644 --- a/src/transformers/models/owlvit/processing_owlvit.py +++ b/src/transformers/models/owlvit/processing_owlvit.py @@ -28,7 +28,6 @@ ProcessingKwargs, ProcessorMixin, Unpack, - _validate_images_text_input_order, ) from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available @@ -71,8 +70,6 @@ class OwlViTProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "OwlViTImageProcessor" tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast") - # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. - optional_call_args = ["query_images"] def __init__(self, image_processor=None, tokenizer=None, **kwargs): feature_extractor = None @@ -96,12 +93,6 @@ def __call__( self, images: Optional[ImageInput] = None, text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, - # The following is to capture `query_images` argument that may be passed as a positional argument. - # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details, - # or this conversation for more context: https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116 - # This behavior is only needed for backward compatibility and will be removed in future versions. 
- # - *args, audio=None, videos=None, **kwargs: Unpack[OwlViTProcessorKwargs], @@ -146,7 +137,6 @@ def __call__( OwlViTProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, - **self.prepare_and_validate_optional_call_args(*args), ) query_images = output_kwargs["images_kwargs"].pop("query_images", None) return_tensors = output_kwargs["common_kwargs"]["return_tensors"] @@ -155,8 +145,6 @@ def __call__( raise ValueError( "You have to specify at least one text or query image or image. All three cannot be none." ) - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) data = {} if text is not None: diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 847f31da71e2..9ee315b8beba 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -50,7 +50,7 @@ ) class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -76,7 +76,7 @@ class PaliGemmaCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index d724a8e0f146..b4c8e555b52d 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -29,7 +29,6 @@ ProcessorMixin, TextKwargs, Unpack, - _validate_images_text_input_order, ) from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput from ...utils import logging @@ -216,8 +215,6 @@ def __call__( - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. - **labels** -- Labels compatible with training if `suffix` is not None """ - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) output_kwargs = self._merge_kwargs( PaliGemmaProcessorKwargs, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 24c61a746605..d8f8a511bc0c 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -998,7 +998,7 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. @@ -1630,7 +1630,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index c3e9d0fa5a96..500150bba4d3 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1254,7 +1254,7 @@ def forward( [What are attention masks?](../glossary#attention-mask) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
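The pattern repeated across these model files is the same: `past_key_values` is typed as `Cache`, the legacy tuple-of-tuples input is either rejected with a `ValueError` or simply no longer converted, attention layers stop returning `present_key_value`, and a `DynamicCache` is created internally when `use_cache=True` and no cache is passed. The sketch below shows the resulting calling convention for a decoder-only model; it is not part of this diff, and the OPT checkpoint name is only illustrative.

```py
# A sketch, not part of the diff: the calling convention after these changes.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("Hello", return_tensors="pt")

# Pass a `Cache` instance (or nothing); the legacy tuple-of-tuples format is no longer accepted.
cache = DynamicCache()
with torch.no_grad():
    outputs = model(**inputs, past_key_values=cache, use_cache=True)

# The layers update the cache in place and the model returns the same object it was given.
print(outputs.past_key_values is cache, cache.get_seq_length())
```

Because the layers now write into the cache in place instead of returning per-layer `present_key_value` tuples, the object passed in is the object handed back, which is what the identity check above relies on.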
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 6b8da44b7424..5e62375845bf 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -298,7 +298,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class PersimmonDecoderLayer(GradientCheckpointingLayer): @@ -352,7 +352,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -378,9 +378,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -453,7 +450,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -477,19 +474,12 @@ def forward( ) use_cache = False - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -514,7 +504,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -534,9 +523,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -546,13 +532,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -727,7 +709,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 67c8129c0fa4..f3ad3bafb822 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1657,7 +1657,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, image_pixel_values: Optional[torch.FloatTensor] = None, image_sizes: Optional[torch.LongTensor] = None, @@ -1786,7 +1786,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, image_pixel_values: Optional[torch.FloatTensor] = None, image_sizes: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index 2477f52f68ac..4ccba03cce43 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -21,7 +21,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import DynamicCache +from ...cache_utils import Cache, DynamicCache from 
...configuration_utils import PretrainedConfig from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_attn_mask_utils import _prepare_4d_attention_mask @@ -1504,7 +1504,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, image_pixel_values: Optional[torch.FloatTensor] = None, image_sizes: Optional[torch.LongTensor] = None, @@ -1610,7 +1610,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, image_pixel_values: Optional[torch.FloatTensor] = None, image_sizes: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 09528a8f8eb6..e8f9e1060026 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -321,7 +321,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class PhimoeFlashAttention2(PhimoeAttention): @@ -419,7 +419,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class PhimoeSdpaAttention(PhimoeAttention): @@ -507,7 +507,7 @@ def forward( attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None PHIMOE_ATTENTION_CLASSES = { @@ -851,7 +851,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -874,9 +874,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - if output_router_logits: outputs += (router_logits,) @@ -950,7 +947,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -979,19 +976,12 @@ def forward( ) use_cache = False - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -1016,7 +1006,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -1036,9 +1025,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1051,13 +1037,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - return MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, @@ -1264,7 +1246,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 6b90ae80d7c7..93c281dc5bab 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -759,13 +759,16 @@ def forward( query_states = self.query(hidden_states) query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + # Check if an encoder-decoder model is being used.
Otherwise we'll get `DynamicCache` + if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value and is_updated: @@ -834,7 +837,7 @@ def forward( attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.output(attn_output) - outputs = (attn_output, past_key_value, position_bias) + outputs = (attn_output, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -960,8 +963,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, past_key_value = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): @@ -981,7 +984,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states, past_key_value = cross_attention_outputs[:2] + hidden_states = cross_attention_outputs[0] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): @@ -989,7 +992,7 @@ def forward( hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] + attention_outputs = attention_outputs + cross_attention_outputs[1:] # Apply Feed Forward layer hidden_states = self.mlp(hidden_states) @@ -1001,12 +1004,7 @@ def forward( outputs = (hidden_states,) - if use_cache: - outputs = outputs + (past_key_value,) + attention_outputs - else: - outputs = outputs + attention_outputs - - return outputs + return outputs + attention_outputs @auto_docstring( @@ -1092,7 +1090,7 @@ def forward( inputs_embeds: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -1162,23 +1160,11 @@ def forward( batch_size, seq_length = input_shape - # initialize past_key_values - return_legacy_cache = False - return_self_attention_cache = False - if use_cache or past_key_values is not None: - if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): - return_self_attention_cache = True - past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) - elif not isinstance(past_key_values, EncoderDecoderCache): - return_legacy_cache = True - logger.warning_once( - "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " - "You should pass an instance of `EncoderDecoderCache` instead, e.g. 
" - "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." - ) - past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - elif past_key_values is None: + if use_cache and past_key_values is None: + if self.config.is_encoder_decoder: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + else: + past_key_values = DynamicCache() past_key_values_length = 0 if cache_position is not None: @@ -1203,7 +1189,9 @@ def forward( attention_mask, inputs_embeds, cache_position, - past_key_values.self_attention_cache if past_key_values is not None else None, + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values, output_attentions, ) else: @@ -1254,24 +1242,19 @@ def forward( cache_position=cache_position, ) - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - if use_cache is False: - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - - hidden_states, next_decoder_cache = layer_outputs[:2] + hidden_states = layer_outputs[0] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[2] + position_bias = layer_outputs[1] if encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) + all_attentions = all_attentions + (layer_outputs[2],) if encoder_hidden_states is not None: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -1290,19 +1273,13 @@ def forward( loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1)) - next_cache = next_decoder_cache if use_cache else None - if return_self_attention_cache: - next_cache = past_key_values.self_attention_cache - if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v for v in [ loss, logits, - next_cache, + past_key_values, all_hidden_states, all_attentions, all_cross_attentions, @@ -1312,7 +1289,7 @@ def forward( return CausalLMOutputWithCrossAttentions( loss=loss, logits=logits, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, @@ -1494,7 +1471,7 @@ def forward( decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, labels: Optional[torch.LongTensor] = None, decoder_inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index 849fa3e011ed..16ebaa0d5524 100644 --- 
a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -27,7 +27,6 @@ ProcessingKwargs, ProcessorMixin, Unpack, - _validate_images_text_input_order, ) from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import is_vision_available, logging @@ -157,8 +156,6 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) output_kwargs = self._merge_kwargs( PixtralProcessorKwargs, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index bcad5f8f7253..3b071a1fe3d9 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -876,7 +876,7 @@ def forward( encoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -925,7 +925,7 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
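The encoder-decoder stacks touched here (Pix2Struct above, PLBart in this file, and the Pop2Piano/T5-style stack that follows) now build `EncoderDecoderCache(DynamicCache(), DynamicCache())` when `use_cache=True` and no cache is supplied, and each attention layer picks the self-attention or cross-attention sub-cache from that object. Below is a hedged sketch of constructing the same cache from the caller's side; it is not part of this diff, and the Pegasus checkpoint name is only an example.

```py
# A sketch, not part of the diff: build the encoder-decoder cache by hand, mirroring what
# the decoder stacks now do internally when `use_cache=True` and no cache is passed.
import torch
from transformers import AutoTokenizer, DynamicCache, EncoderDecoderCache, PegasusForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")

enc = tokenizer("PEGASUS is a sequence-to-sequence model.", return_tensors="pt")
# Start decoding from the model's configured decoder start token.
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

# Decoder self-attention and cross-attention key/values live in separate sub-caches.
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
with torch.no_grad():
    out = model(**enc, decoder_input_ids=decoder_input_ids, past_key_values=cache, use_cache=True)

# The cache was filled in place: one decoding step for self-attention,
# plus the projected encoder states for cross-attention.
print(cache.self_attention_cache.get_seq_length())
```

Passing a plain `DynamicCache` to such a stack in decoder-only usage still works, which is why the attention layers in this diff fall back to `curr_past_key_value = past_key_value` when the object is not an `EncoderDecoderCache`.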
@@ -1169,7 +1169,7 @@ def forward( decoder_head_mask: Optional[torch.LongTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[list[torch.FloatTensor]] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -1326,7 +1326,7 @@ def forward( decoder_head_mask: Optional[torch.LongTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[list[torch.FloatTensor]] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.Tensor] = None, @@ -1689,7 +1689,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index f998a0136c21..22200c0bbf65 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -325,7 +325,7 @@ def forward( decoder_head_mask: Optional[torch.LongTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[list[torch.FloatTensor]] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -482,7 +482,7 @@ def forward( decoder_head_mask: Optional[torch.LongTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[list[torch.FloatTensor]] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 5c4285afe728..dca3191be9e0 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -309,13 +309,16 @@ def forward( query_states = self.q(hidden_states) query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + # Check if an encoder-decoder model is being used.
Otherwise we'll get `DynamicCache` + if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: @@ -384,7 +387,7 @@ def forward( attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - outputs = (attn_output, past_key_value, position_bias) + outputs = (attn_output, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -509,8 +512,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, past_key_value = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -534,7 +537,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states, past_key_value = cross_attention_outputs[:2] + hidden_states = cross_attention_outputs[0] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -546,7 +549,7 @@ def forward( hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] + attention_outputs = attention_outputs + cross_attention_outputs[1:] # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states) @@ -562,12 +565,9 @@ def forward( outputs = (hidden_states,) - if use_cache: - outputs = outputs + (past_key_value,) + attention_outputs - else: - outputs = outputs + attention_outputs - - return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return ( + outputs + attention_outputs + ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) @auto_docstring @@ -741,23 +741,12 @@ def forward( if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - # initialize past_key_values - return_legacy_cache = False - return_self_attention_cache = False - if self.is_decoder and (use_cache or past_key_values is not None): - if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): - return_self_attention_cache = True - past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) - elif not isinstance(past_key_values, EncoderDecoderCache): - return_legacy_cache = True - logger.warning_once( - "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " - "You should pass an instance of `EncoderDecoderCache` instead, e.g. " - "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
- ) - past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - elif past_key_values is None: - past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if self.is_decoder: + if use_cache and past_key_values is None: + if self.config.is_encoder_decoder: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + else: + past_key_values = DynamicCache() elif not self.is_decoder: # do not pass cache object down the line for encoder stack # it messes indexing later in decoder-stack because cache object is modified in-place @@ -779,7 +768,9 @@ def forward( attention_mask, inputs_embeds, cache_position, - past_key_values.self_attention_cache if past_key_values is not None else None, + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values, output_attentions, ) else: @@ -830,24 +821,19 @@ def forward( cache_position=cache_position, ) - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - if use_cache is False: - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - - hidden_states, next_decoder_cache = layer_outputs[:2] + hidden_states = layer_outputs[0] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[2] + position_bias = layer_outputs[1] if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) + all_attentions = all_attentions + (layer_outputs[2],) if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -856,18 +842,12 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_self_attention_cache: - next_cache = past_key_values.self_attention_cache - if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_attentions, all_cross_attentions, @@ -876,7 +856,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, @@ -1042,13 +1022,13 @@ def __init__(self, config: Pop2PianoConfig): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = Pop2PianoStack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers 
self.decoder = Pop2PianoStack(decoder_config, self.shared) @@ -1137,7 +1117,7 @@ def forward( decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, input_features: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index fdc248287a08..99596434270f 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1467,13 +1467,13 @@ def __init__(self, config: ProphetNetConfig): self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) encoder_config = copy.deepcopy(config) - encoder_config.is_encoder_decoder = False encoder_config.use_cache = False + encoder_config.tie_encoder_decoder = False self.encoder = ProphetNetEncoder(encoder_config, self.word_embeddings) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False self.decoder = ProphetNetDecoder(decoder_config, self.word_embeddings) # Initialize weights and apply final processing diff --git a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py index 51897b40681a..552249f0ed2b 100644 --- a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py @@ -19,7 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation from ...utils import logging diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index db2ac33eae56..2ac680b7fec8 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -531,7 +531,7 @@ class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -1438,7 +1438,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class Qwen2MLP(nn.Module): @@ -1511,7 +1511,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -1535,9 +1535,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -1576,7 +1573,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1649,7 +1646,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -1669,9 +1665,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1681,13 +1674,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -1817,7 +1810,7 @@ def forward( feature_attention_mask: Optional[torch.Tensor] = None, audio_feature_lengths: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -2082,7 +2075,7 @@ class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -2138,7 +2131,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -2211,7 +2204,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -2231,9 +2223,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -2243,13 +2232,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -2294,7 +2283,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, thinker_reply_part: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index c64e89a9ef43..bf7f9930f64a 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -42,6 +42,7 @@ from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioEncoderLayer from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding +from ...cache_utils import Cache from ...configuration_utils import PretrainedConfig, layer_type_validation from ...generation import GenerationMixin from ...modeling_flash_attention_utils import is_flash_attn_available @@ -1572,7 +1573,7 @@ class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -2262,7 +2263,7 @@ def forward( feature_attention_mask: Optional[torch.Tensor] = None, audio_feature_lengths: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -2527,7 +2528,7 @@ class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -2597,7 +2598,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, thinker_reply_part: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index aa8e0e448758..d6efd976452c 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -513,7 +513,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) ) class Qwen2_5_VLModelOutputWithPast(ModelOutput): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -716,7 +716,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class Qwen2_5_VLDecoderLayer(GradientCheckpointingLayer): @@ -775,7 +775,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -799,9 +799,6 @@ def 
forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -839,7 +836,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -912,7 +909,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -932,9 +928,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -944,13 +937,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -1206,7 +1199,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1371,7 +1364,7 @@ class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -1445,7 +1438,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 07f41356e800..343cdc2620b2 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -44,6 +44,7 @@ from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLImagesKwargs, Qwen2VLProcessor from ...activations import ACT2FN +from ...cache_utils import Cache from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput @@ -548,7 +549,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -711,7 +712,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index dc9f775ae995..c3cce2f9b61c 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -362,7 +362,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # NO LONGER EXIST Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe @@ -478,7 +478,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # NO LONGER EXIST Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe @@ -569,7 +569,7 @@ def forward( attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None QWEN2MOE_ATTENTION_CLASSES = { @@ -702,7 +702,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, 
position_ids=position_ids, @@ -731,9 +731,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - if output_router_logits: outputs += (router_logits,) @@ -798,7 +795,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -825,19 +822,12 @@ def forward( ) use_cache = False - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -863,7 +853,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -883,9 +872,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -898,13 +884,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - return MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, @@ -1110,7 +1092,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 9c9aaf646763..062561bb523a 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -59,7 +59,7 @@ ) class Qwen2VLModelOutputWithPast(ModelOutput): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + 
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -88,7 +88,7 @@ class Qwen2VLCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -563,7 +563,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class Qwen2VLDecoderLayer(GradientCheckpointingLayer): @@ -622,7 +622,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -646,9 +646,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -814,7 +811,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -887,7 +884,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -907,9 +903,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -919,13 +912,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -1146,7 +1139,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, 
use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1346,7 +1339,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index c1201269ba21..75d0ea87424e 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -650,7 +650,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py index 5618f2d0aff6..c71c1cae2c2f 100644 --- a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py @@ -22,6 +22,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...processing_utils import Unpack @@ -191,7 +192,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 77ddfdaa01ad..299c90e2cfe2 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -21,6 +21,7 @@ import torch from torch import nn +from ...cache_utils import EncoderDecoderCache from ...configuration_utils import PretrainedConfig from ...generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList from ...modeling_outputs import ModelOutput @@ -1201,6 +1202,8 @@ def _reorder_stacked(hidden_states, new_order): reordered_past += ( tuple(_reorder_stacked(past_state, beam_idx.to(past_state.device)) for past_state in layer_past), ) + if isinstance(past_key_values, EncoderDecoderCache): + reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past) return reordered_past @@ -1592,6 +1595,14 @@ def extend_enc_output(tensor, num_beams=None): f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}" ) + # Auxiliary functions for beam search + def _temporary_reorder_cache(self, past_key_values, beam_idx): + # RAG should always use the legacy path even though the LM backbone (T5) uses new cache format + # because RAG expands input for doc-size internally. 
TODO: raushan, remove me when all models support + # new cache format + past_key_values = self._reorder_cache(past_key_values, beam_idx) + return past_key_values + def get_input_embeddings(self): return self.rag.generator.get_input_embeddings() diff --git a/src/transformers/models/sam/processing_sam.py b/src/transformers/models/sam/processing_sam.py index 379ccbe0e217..c1c7f22cc777 100644 --- a/src/transformers/models/sam/processing_sam.py +++ b/src/transformers/models/sam/processing_sam.py @@ -67,13 +67,6 @@ class SamProcessor(ProcessorMixin): attributes = ["image_processor"] image_processor_class = "SamImageProcessor" - # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. - optional_call_args = [ - "segmentation_maps", - "input_points", - "input_labels", - "input_boxes", - ] def __init__(self, image_processor): super().__init__(image_processor) @@ -82,13 +75,6 @@ def __init__(self, image_processor): def __call__( self, images: Optional[ImageInput] = None, - # The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes` - # arguments that may be passed as a positional argument. - # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details, - # or this conversation for more context: - # https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116 - # This behavior is only needed for backward compatibility and will be removed in future versions. - *args, # to be deprecated text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, audio: Optional[AudioInput] = None, video: Optional[VideoInput] = None, @@ -102,7 +88,6 @@ def __call__( SamProcessorKwargs, tokenizer_init_kwargs={}, **kwargs, - **self.prepare_and_validate_optional_call_args(*args), ) input_points = output_kwargs["images_kwargs"].pop("input_points", None) input_labels = output_kwargs["images_kwargs"].pop("input_labels", None) diff --git a/src/transformers/models/sam_hq/processing_samhq.py b/src/transformers/models/sam_hq/processing_samhq.py index bd19784f5fba..97dbcdfab638 100644 --- a/src/transformers/models/sam_hq/processing_samhq.py +++ b/src/transformers/models/sam_hq/processing_samhq.py @@ -65,13 +65,6 @@ class SamHQProcessor(ProcessorMixin): attributes = ["image_processor"] image_processor_class = "SamImageProcessor" - optional_call_args = [ - "segmentation_maps", - "input_points", - "input_labels", - "input_boxes", - ] - def __init__(self, image_processor): super().__init__(image_processor) # Ensure image_processor is properly initialized @@ -84,13 +77,6 @@ def __init__(self, image_processor): def __call__( self, images: Optional[ImageInput] = None, - # The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes` - # arguments that may be passed as a positional argument. - # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details, - # or this conversation for more context: - # https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116 - # This behavior is only needed for backward compatibility and will be removed in future versions. 
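With `optional_call_args` and the `*args` capture removed, `segmentation_maps`, `input_points`, `input_labels` and `input_boxes` can no longer be passed positionally to `SamProcessor.__call__` (the same applies to `SamHQProcessor` just below). A hedged sketch of the keyword-only call, using a dummy image and an illustrative checkpoint name:

```py
# Hedged sketch of the keyword-only calling convention; the image is a dummy and
# the checkpoint name is only an example.
import numpy as np
from PIL import Image
from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
image = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))

# `input_points` (and the other prompt arguments) must now be passed by keyword;
# they are no longer captured positionally after `images`.
inputs = processor(images=image, input_points=[[[128, 128]]], return_tensors="pt")
print(sorted(inputs.keys()))
```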
- *args, # to be deprecated text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, audio: Optional[AudioInput] = None, video: Optional[VideoInput] = None, @@ -104,7 +90,6 @@ def __call__( SamHQProcessorKwargs, tokenizer_init_kwargs={}, **kwargs, - **self.prepare_and_validate_optional_call_args(*args), ) input_points = output_kwargs["images_kwargs"].pop("input_points", None) diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 7053db9b171b..befec29d9ded 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -26,7 +26,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import DynamicCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -463,7 +463,7 @@ class SmolVLMBaseModelOutputWithPast(ModelOutput): Sequence of hidden-states at the output of the last layer of the model. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, @@ -687,7 +687,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, @@ -784,7 +784,7 @@ class SmolVLMCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see @@ -861,7 +861,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, diff --git a/src/transformers/models/smolvlm/modular_smolvlm.py b/src/transformers/models/smolvlm/modular_smolvlm.py index e85074efb88a..dfda67472ca9 100644 --- a/src/transformers/models/smolvlm/modular_smolvlm.py +++ b/src/transformers/models/smolvlm/modular_smolvlm.py @@ -19,7 +19,7 @@ import torch.utils.checkpoint from torch import nn -from ...cache_utils import DynamicCache +from ...cache_utils import Cache, DynamicCache from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, logging @@ -270,7 +270,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index b7e99abe6350..0aef9d3ab7d6 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -301,7 +301,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class StableLmSdpaAttention(StableLmAttention): @@ -409,7 +409,7 @@ def forward( attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None class StableLmFlashAttention2(StableLmAttention): @@ -512,7 +512,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights ATTENTION_CLASSES = { @@ -576,7 +576,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - self_attn_output, self_attn_weights, present_key_value = self.self_attn( + self_attn_output, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -608,9 +608,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -682,7 +679,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: 
Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -705,19 +702,12 @@ def forward( ) use_cache = False - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -742,7 +732,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -761,9 +750,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -773,13 +759,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -956,7 +938,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index b0273c8a4a33..63f1bfadeebe 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -502,13 +502,16 @@ def forward( query_states = self.q(hidden_states) query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + # Check is encoder-decoder model is being used. 
Otherwise we'll get `DynamicCache` + if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: @@ -577,7 +580,7 @@ def forward( attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - outputs = (attn_output, past_key_value, position_bias) + outputs = (attn_output, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -705,8 +708,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, past_key_value = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): @@ -727,7 +730,7 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, past_key_value = cross_attention_outputs[:2] + hidden_states = cross_attention_outputs[0] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): @@ -735,7 +738,7 @@ def forward( hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] + attention_outputs = attention_outputs + cross_attention_outputs[1:] # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states, output_router_logits) @@ -752,12 +755,9 @@ def forward( outputs = (hidden_states,) - if use_cache: - outputs = outputs + (past_key_value,) + attention_outputs + (router_tuple,) - else: - outputs = outputs + attention_outputs + (router_tuple,) - - return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) + return ( + outputs + attention_outputs + (router_tuple,) + ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) @auto_docstring @@ -949,23 +949,12 @@ def forward( if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - # initialize past_key_values - return_legacy_cache = False - return_self_attention_cache = False - if self.is_decoder and (use_cache or past_key_values is not None): - if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): - return_self_attention_cache = True - past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) - elif not isinstance(past_key_values, EncoderDecoderCache): - return_legacy_cache = True - logger.warning_once( - "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. 
" - "You should pass an instance of `EncoderDecoderCache` instead, e.g. " - "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." - ) - past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - elif past_key_values is None: - past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if self.is_decoder: + if use_cache and past_key_values is None: + if self.config.is_encoder_decoder: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + else: + past_key_values = DynamicCache() elif not self.is_decoder: # do not pass cache object down the line for encoder stack # it messes indexing later in decoder-stack because cache object is modified in-place @@ -987,7 +976,9 @@ def forward( attention_mask, inputs_embeds, cache_position, - past_key_values.self_attention_cache if past_key_values is not None else None, + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values, output_attentions, ) else: @@ -1045,24 +1036,19 @@ def forward( router_probs = layer_outputs[-1] layer_outputs = layer_outputs[:-1] - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - if use_cache is False: - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - - hidden_states, next_decoder_cache = layer_outputs[:2] + hidden_states = layer_outputs[0] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[2] + position_bias = layer_outputs[1] if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) + all_attentions = all_attentions + (layer_outputs[2],) if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) if output_router_logits: all_router_probs = all_router_probs + (router_probs,) @@ -1074,18 +1060,12 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_self_attention_cache: - next_cache = past_key_values.self_attention_cache - if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_attentions, all_cross_attentions, @@ -1095,7 +1075,7 @@ def forward( ) return MoEModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, @@ -1248,12 +1228,12 @@ def __init__(self, config: SwitchTransformersConfig): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = SwitchTransformersStack(encoder_config, self.shared) decoder_config = 
copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False self.decoder = SwitchTransformersStack(decoder_config, self.shared) # Initialize weights and apply final processing @@ -1300,7 +1280,7 @@ def forward( decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.Tensor] = None, decoder_inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, @@ -1465,12 +1445,12 @@ def __init__(self, config: SwitchTransformersConfig): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = SwitchTransformersStack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers self.decoder = SwitchTransformersStack(decoder_config, self.shared) @@ -1521,7 +1501,7 @@ def forward( decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 2a1a84b81523..216ea793fd87 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -490,13 +490,16 @@ def forward( query_states = self.q(hidden_states) query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + # Check is encoder-decoder model is being used. 
Otherwise we'll get `DynamicCache` + if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: @@ -565,7 +568,7 @@ def forward( attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - outputs = (attn_output, past_key_value, position_bias) + outputs = (attn_output, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -685,8 +688,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, past_key_value = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -710,7 +713,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states, past_key_value = cross_attention_outputs[:2] + hidden_states = cross_attention_outputs[0] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -722,7 +725,7 @@ def forward( hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] + attention_outputs = attention_outputs + cross_attention_outputs[1:] # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states) @@ -738,12 +741,9 @@ def forward( outputs = (hidden_states,) - if use_cache: - outputs = outputs + (past_key_value,) + attention_outputs - else: - outputs = outputs + attention_outputs - - return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return ( + outputs + attention_outputs + ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) class T5ClassificationHead(nn.Module): @@ -1006,23 +1006,12 @@ def forward( if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - # initialize past_key_values - return_legacy_cache = False - return_self_attention_cache = False - if self.is_decoder and (use_cache or past_key_values is not None): - if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): - return_self_attention_cache = True - past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) - elif not isinstance(past_key_values, EncoderDecoderCache): - return_legacy_cache = True - logger.warning_once( - "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " - "You should pass an instance of `EncoderDecoderCache` instead, e.g. " - "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
- ) - past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - elif past_key_values is None: - past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if self.is_decoder: + if use_cache and past_key_values is None: + if self.config.is_encoder_decoder: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + else: + past_key_values = DynamicCache() elif not self.is_decoder: # do not pass cache object down the line for encoder stack # it messes indexing later in decoder-stack because cache object is modified in-place @@ -1044,7 +1033,9 @@ def forward( attention_mask, inputs_embeds, cache_position, - past_key_values.self_attention_cache if past_key_values is not None else None, + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values, output_attentions, ) elif attention_mask is not None: @@ -1118,24 +1109,19 @@ def forward( cache_position=cache_position, ) - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - if use_cache is False: - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - - hidden_states, next_decoder_cache = layer_outputs[:2] + hidden_states = layer_outputs[0] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[2] + position_bias = layer_outputs[1] if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) + all_attentions = all_attentions + (layer_outputs[2],) if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: @@ -1150,18 +1136,12 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_self_attention_cache: - next_cache = past_key_values.self_attention_cache - if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_attentions, all_cross_attentions, @@ -1170,7 +1150,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, @@ -1325,12 +1305,12 @@ def __init__(self, config: T5Config): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = T5Stack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False 
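The T5-style stacks now build their own cache when `use_cache=True` and none is passed (`EncoderDecoderCache` for encoder-decoder configs, `DynamicCache` otherwise) and no longer convert legacy tuples on the fly. A minimal sketch, assuming a small T5 checkpoint, of passing the new cache type explicitly; the commented `from_legacy_cache` line shows the one-time conversion callers holding old tuple caches now have to do themselves:

```py
# Minimal sketch assuming a small T5 checkpoint; shows the cache objects these
# stacks now expect instead of tuples of tuples.
from transformers import AutoTokenizer, DynamicCache, EncoderDecoderCache, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

# generate() reuses the cache object handed in instead of building tuples internally.
output_ids = model.generate(**inputs, past_key_values=past_key_values, use_cache=True, max_new_tokens=10)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

# A legacy tuple-of-tuples cache must now be converted explicitly before being passed:
# past_key_values = EncoderDecoderCache.from_legacy_cache(legacy_tuple_cache)
```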
decoder_config.num_layers = config.num_decoder_layers self.decoder = T5Stack(decoder_config, self.shared) @@ -1412,7 +1392,7 @@ def forward( decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.Tensor] = None, decoder_inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, @@ -1574,12 +1554,12 @@ def __init__(self, config: T5Config): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = T5Stack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers self.decoder = T5Stack(decoder_config, self.shared) @@ -1663,7 +1643,7 @@ def forward( decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None, - past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -2256,12 +2236,12 @@ def __init__(self, config: T5Config): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = T5Stack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers self.decoder = T5Stack(decoder_config, self.shared) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 8d4e368e945b..23af62e4a1dc 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -75,7 +75,7 @@ class BaseModelOutputWithAttentionMask(ModelOutput): Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. 
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, @@ -98,7 +98,7 @@ class BaseModelOutputWithAttentionMask(ModelOutput): last_hidden_state: Optional[torch.FloatTensor] = None attention_mask: Optional[torch.FloatTensor] = None - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None cross_attentions: Optional[tuple[torch.FloatTensor]] = None @@ -588,13 +588,16 @@ def forward( query_states = self.q(hidden_states) query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` + if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: @@ -663,7 +666,7 @@ def forward( attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - outputs = (attn_output, past_key_value, position_bias) + outputs = (attn_output, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -788,8 +791,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states, past_key_value = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -813,7 +816,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states, past_key_value = cross_attention_outputs[:2] + hidden_states = cross_attention_outputs[0] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -825,7 +828,7 @@ def forward( hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] + attention_outputs = attention_outputs + cross_attention_outputs[1:] # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states) @@ -841,12 +844,9 @@ def forward( outputs = (hidden_states,) - if use_cache: - outputs = outputs + (past_key_value,) + attention_outputs - else: - outputs = outputs + attention_outputs - - return outputs # hidden-states, 
past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return ( + outputs + attention_outputs + ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) class UdopCellEmbeddings(nn.Module): @@ -1225,23 +1225,12 @@ def forward( if use_cache is True: assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" - # initialize past_key_values - return_legacy_cache = False - return_self_attention_cache = False - if self.is_decoder and (use_cache or past_key_values is not None): - if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): - return_self_attention_cache = True - past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) - elif not isinstance(past_key_values, EncoderDecoderCache): - return_legacy_cache = True - logger.warning_once( - "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " - "You should pass an instance of `EncoderDecoderCache` instead, e.g. " - "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." - ) - past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - elif past_key_values is None: - past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if self.is_decoder: + if use_cache and past_key_values is None: + if self.config.is_encoder_decoder: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + else: + past_key_values = DynamicCache() elif not self.is_decoder: # do not pass cache object down the line for encoder stack # it messes indexing later in decoder-stack because cache object is modified in-place @@ -1263,7 +1252,9 @@ def forward( attention_mask, inputs_embeds, cache_position, - past_key_values.self_attention_cache if past_key_values is not None else None, + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values, output_attentions, ) else: @@ -1310,24 +1301,21 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) - if use_cache is False: # MP fixes - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, next_decoder_cache = layer_outputs[:2] + + hidden_states = layer_outputs[0] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention weights), # (self-attention position bias), (cross-attention weights), (cross-attention position bias) - position_bias = layer_outputs[2] + position_bias = layer_outputs[1] if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] if output_attentions: all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -1336,19 +1324,13 @@ def 
forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_self_attention_cache: - next_cache = past_key_values.self_attention_cache - if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v for v in [ hidden_states, attention_mask, - next_cache, + past_key_values, all_hidden_states, all_attentions, all_cross_attentions, @@ -1359,7 +1341,7 @@ def forward( return BaseModelOutputWithAttentionMask( last_hidden_state=hidden_states, attention_mask=attention_mask, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, @@ -1512,12 +1494,12 @@ def __init__(self, config): encoder_config = deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed) decoder_config = deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers self.decoder = UdopStack(decoder_config, self.shared) @@ -1550,7 +1532,7 @@ def forward( decoder_attention_mask: Optional[Tensor] = None, inputs_embeds: Optional[Tensor] = None, encoder_outputs: Optional[Tensor] = None, - past_key_values: Optional[Tensor] = None, + past_key_values: Optional[Cache] = None, head_mask: Optional[Tensor] = None, decoder_inputs_embeds: Optional[Tensor] = None, decoder_head_mask: Optional[Tensor] = None, @@ -1708,12 +1690,12 @@ def __init__(self, config): encoder_config = deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed) decoder_config = deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers self.decoder = UdopStack(decoder_config, self.shared) @@ -1755,7 +1737,7 @@ def forward( decoder_attention_mask: Optional[Tensor] = None, inputs_embeds: Optional[Tensor] = None, encoder_outputs: Optional[Tensor] = None, - past_key_values: Optional[Tensor] = None, + past_key_values: Optional[Cache] = None, head_mask: Optional[Tensor] = None, decoder_inputs_embeds: Optional[Tensor] = None, decoder_head_mask: Optional[Tensor] = None, diff --git a/src/transformers/models/udop/processing_udop.py b/src/transformers/models/udop/processing_udop.py index c29bb25d777a..1c9b7ee2971d 100644 --- a/src/transformers/models/udop/processing_udop.py +++ b/src/transformers/models/udop/processing_udop.py @@ -77,8 +77,6 @@ class UdopProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "LayoutLMv3ImageProcessor" tokenizer_class = ("UdopTokenizer", "UdopTokenizerFast") - # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. 
- optional_call_args = ["text_pair"] def __init__(self, image_processor, tokenizer): super().__init__(image_processor, tokenizer) @@ -87,12 +85,6 @@ def __call__( self, images: Optional[ImageInput] = None, text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, - # The following is to capture `text_pair` argument that may be passed as a positional argument. - # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details, - # or this conversation for more context: https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116 - # This behavior is only needed for backward compatibility and will be removed in future versions. - # - *args, audio=None, videos=None, **kwargs: Unpack[UdopProcessorKwargs], @@ -115,7 +107,6 @@ def __call__( UdopProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, - **self.prepare_and_validate_optional_call_args(*args), ) boxes = output_kwargs["text_kwargs"].pop("boxes", None) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 2b1f650c6789..4b9d96db2f21 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -274,13 +274,16 @@ def forward( query_states = self.q(hidden_states) query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` + if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value current_states = encoder_hidden_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: @@ -346,7 +349,7 @@ def forward( attn_output = attn_output.view(batch_size, seq_length, -1) attn_output = self.o(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class UMT5LayerSelfAttention(nn.Module): @@ -431,7 +434,7 @@ def forward( output_attentions=False, cache_position=None, ): - hidden_states, self_attn_weights, past_key_value = self.layer[0]( + hidden_states, self_attn_weights = self.layer[0]( hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -449,7 +452,7 @@ def forward( cross_attn_weights = None do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - hidden_states, cross_attn_weights, past_key_value = self.layer[1]( + hidden_states, cross_attn_weights = self.layer[1]( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -472,10 +475,7 @@ def forward( clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = ( - hidden_states, - past_key_value, - ) + outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) @@ -692,22 +692,12 @@ def forward( raise ValueError(f"`use_cache` can only be set to `True` if {self} is used 
as a decoder") # initialize past_key_values - return_legacy_cache = False - return_self_attention_cache = False - if self.is_decoder and (use_cache or past_key_values is not None): - if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): - return_self_attention_cache = True - past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) - elif not isinstance(past_key_values, EncoderDecoderCache): - return_legacy_cache = True - logger.warning_once( - "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " - "You should pass an instance of `EncoderDecoderCache` instead, e.g. " - "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." - ) - past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - elif past_key_values is None: - past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + if self.is_decoder: + if use_cache and past_key_values is None: + if self.config.is_encoder_decoder: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + else: + past_key_values = DynamicCache() elif not self.is_decoder: # do not pass cache object down the line for encoder stack # it messes indexing later in decoder-stack because cache object is modified in-place @@ -729,7 +719,9 @@ def forward( attention_mask, inputs_embeds, cache_position, - past_key_values.self_attention_cache if past_key_values is not None else None, + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values, output_attentions, ) elif attention_mask is not None: @@ -781,13 +773,10 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[1] - if output_attentions: - all_attentions += (layer_outputs[2],) + all_attentions += (layer_outputs[1],) if self.is_decoder: - all_cross_attentions += (layer_outputs[3],) + all_cross_attentions += (layer_outputs[2],) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -796,18 +785,12 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_self_attention_cache: - next_cache = past_key_values.self_attention_cache - if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() - if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + past_key_values, all_hidden_states, all_attentions, all_cross_attentions, @@ -816,7 +799,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, @@ -978,12 +961,12 @@ def __init__(self, config): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False + encoder_config.tie_encoder_decoder = False self.encoder = UMT5Stack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False + decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers self.decoder = UMT5Stack(decoder_config, self.shared) @@ -1034,7 +1017,7 @@ def forward( decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: 
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
-        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         decoder_inputs_embeds: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
@@ -1193,12 +1176,12 @@ def __init__(self, config):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.is_encoder_decoder = False
+        encoder_config.tie_encoder_decoder = False
         self.encoder = UMT5Stack(encoder_config, self.shared)

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.is_encoder_decoder = False
+        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = UMT5Stack(decoder_config, self.shared)

@@ -1250,7 +1233,7 @@ def forward(
         decoder_head_mask: Optional[torch.FloatTensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
-        past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -1783,12 +1766,12 @@ def __init__(self, config):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.is_encoder_decoder = False
+        encoder_config.tie_encoder_decoder = False
         self.encoder = UMT5Stack(encoder_config, self.shared)

         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.is_encoder_decoder = False
+        decoder_config.tie_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = UMT5Stack(decoder_config, self.shared)
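Since the UMT5 stack now builds its own `EncoderDecoderCache` (and no longer converts legacy tuples), the cache returned by a decoder pass is a `Cache` object, and any cache passed in should already be one. A small sketch of both directions, assuming the `google/umt5-small` checkpoint is available:

```python
import torch
from transformers import AutoTokenizer, UMT5ForConditionalGeneration
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
model = UMT5ForConditionalGeneration.from_pretrained("google/umt5-small")

enc = tokenizer("A walk in the park.", return_tensors="pt")
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

# The decoder stack creates the cache itself when none is given.
out = model(input_ids=enc.input_ids, decoder_input_ids=decoder_input_ids, use_cache=True)
print(type(out.past_key_values).__name__)  # EncoderDecoderCache

# Passing a cache explicitly: build the Cache object yourself; plain tuples are not converted anymore.
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
out = model(
    input_ids=enc.input_ids,
    decoder_input_ids=decoder_input_ids,
    past_key_values=cache,
    use_cache=True,
)
print(out.past_key_values.self_attention_cache.get_seq_length())  # 1 decoder token cached
```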
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index 3a906c456337..4fa05080ad42 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -22,6 +22,7 @@
 from torch import nn

 from ...activations import ACT2FN
+from ...cache_utils import Cache
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import ModelOutput
@@ -43,7 +44,7 @@
 )
 class VideoLlavaModelOutputWithPast(ModelOutput):
     r"""
-    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
         Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
         shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
@@ -77,7 +78,7 @@ class VideoLlavaCausalLMOutputWithPast(ModelOutput):
         Language modeling loss (for next-token prediction).
     logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
         Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
         Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
         shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
@@ -291,7 +292,7 @@ def forward(
         pixel_values_videos: torch.FloatTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[list[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         vision_feature_layer: Optional[Union[int, list[int]]] = None,
         vision_feature_select_strategy: Optional[str] = None,
@@ -481,7 +482,7 @@ def forward(
         pixel_values_videos: torch.FloatTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[list[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         vision_feature_layer: Optional[Union[int, list[int]]] = None,
         vision_feature_select_strategy: Optional[str] = None,
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index 326fc4a35333..2826e7449ea2 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -26,6 +26,7 @@
 from torch import nn

 from ...activations import ACT2FN
+from ...cache_utils import Cache
 from ...generation import GenerationMixin
 from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from ...modeling_utils import PreTrainedModel
@@ -42,7 +43,7 @@
 )
 class VipLlavaModelOutputWithPast(BaseModelOutputWithPast):
     r"""
-    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
         Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
         shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
@@ -68,7 +69,7 @@ class VipLlavaCausalLMOutputWithPast(ModelOutput):
         Language modeling loss (for next-token prediction).
     logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
         Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
         Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
         shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
@@ -207,7 +208,7 @@ def forward(
         pixel_values: torch.FloatTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[list[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         vision_feature_layers: Optional[Union[int, list[int]]] = None,
         use_cache: Optional[bool] = None,
@@ -348,7 +349,7 @@ def forward(
         pixel_values: torch.FloatTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[list[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         vision_feature_layers: Optional[Union[int, list[int]]] = None,
         labels: Optional[torch.LongTensor] = None,
diff --git a/src/transformers/models/vipllava/modular_vipllava.py b/src/transformers/models/vipllava/modular_vipllava.py
index 97458112ad5d..74cefba7a998 100644
--- a/src/transformers/models/vipllava/modular_vipllava.py
+++ b/src/transformers/models/vipllava/modular_vipllava.py
@@ -27,6 +27,7 @@
 )

 from ...activations import ACT2FN
+from ...cache_utils import Cache
 from ...utils import auto_docstring, is_torchdynamo_compiling, logging
 from .configuration_vipllava import VipLlavaConfig
@@ -109,7 +110,7 @@ def forward(
         pixel_values: torch.FloatTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[list[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         vision_feature_layers: Optional[Union[int, list[int]]] = None,
         use_cache: Optional[bool] = None,
@@ -198,7 +199,7 @@ def forward(
         pixel_values: torch.FloatTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[list[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         vision_feature_layers: Optional[Union[int, list[int]]] = None,
         labels: Optional[torch.LongTensor] = None,
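The VideoLlava/VipLlava hunks above only adjust the declared type: the `past_key_values` these models accept and return is a `Cache`. A quick sketch of what that interface gives you, using a tiny text-only model as a stand-in for the full multimodal checkpoints (assumption: the `hf-internal-testing` repo is reachable):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("hello there", return_tensors="pt")
out = model(**inputs, use_cache=True)

cache = out.past_key_values    # a Cache (DynamicCache here), not a tuple of tuples
print(type(cache).__name__)    # DynamicCache
print(cache.get_seq_length())  # number of cached tokens, instead of cache[0][0].shape[2]
```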
diff --git a/src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py
index 0cdf97329c33..8a5320708f16 100644
--- a/src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py
+++ b/src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py
@@ -20,7 +20,7 @@
 from typing import Optional, Union

 from ...image_utils import ImageInput
-from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
@@ -111,8 +111,6 @@ def __call__(
         if text is None and images is None:
             raise ValueError("You have to specify either text or images. Both cannot be none.")

-        # check if images and text inputs are reversed for BC
-        images, text = _validate_images_text_input_order(images, text)
-
         output_kwargs = self._merge_kwargs(
             VisionTextDualEncoderProcessorKwargs,
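With `_validate_images_text_input_order` gone from this processor (and removed from `processing_utils` further down), swapped `images`/`text` arguments are no longer silently corrected, so the safest call style is fully keyword-based. A minimal sketch with assumed components; any image processor plus fast tokenizer pair behaves the same way:

```python
from PIL import Image
from transformers import BertTokenizerFast, VisionTextDualEncoderProcessor, ViTImageProcessor

processor = VisionTextDualEncoderProcessor(
    image_processor=ViTImageProcessor(),
    tokenizer=BertTokenizerFast.from_pretrained("bert-base-uncased"),
)

image = Image.new("RGB", (224, 224), "white")
# Keyword arguments make the order explicit now that reversed inputs are no longer swapped back.
batch = processor(images=image, text="a photo of a cat", return_tensors="pt")
print(sorted(batch.keys()))  # attention_mask, input_ids, pixel_values, token_type_ids
```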
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index 123376a58a39..f8473cda9fe0 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -316,7 +316,8 @@ def forward(
         query_states = query_states.view(*q_input_shape)
         query_states = query_states.transpose(1, 2).contiguous()

-        if past_key_value is not None:
+        # Check if an encoder-decoder model is being used. Otherwise we'll get a `DynamicCache`
+        if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache):
             is_updated = past_key_value.is_updated.get(self.layer_idx)
             if is_cross_attention:
                 # after the first generated id, we can subsequently re-use all key/value_states from cache
@@ -881,20 +882,11 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        return_legacy_cache = False
-        return_self_attention_cache = False
-        if use_cache or past_key_values is not None:
-            if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
-                return_self_attention_cache = True
-                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
-            elif not isinstance(past_key_values, EncoderDecoderCache):
-                return_legacy_cache = True
-                logger.warning_once(
-                    "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. "
-                    "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
-                    "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
-                )
-                past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+        if use_cache and past_key_values is None:
+            if self.config.is_encoder_decoder:
+                past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
+            else:
+                past_key_values = DynamicCache()

         past_key_values_length = 0
         if cache_position is not None:
@@ -984,10 +976,6 @@ def forward(
                 all_hidden_states += (hidden_states,)

         next_cache = past_key_values if use_cache else None
-        if return_self_attention_cache:
-            next_cache = past_key_values.self_attention_cache
-        if return_legacy_cache:
-            next_cache = past_key_values.to_legacy_cache()
         if not return_dict:
             return tuple(
                 v
@@ -1086,7 +1074,7 @@ def forward(
         decoder_head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
-        past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Cache] = None,
         decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
         decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
         use_cache: Optional[bool] = None,
@@ -1256,7 +1244,7 @@ def forward(
         decoder_head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
-        past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Cache] = None,
         decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
         decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -1453,7 +1441,7 @@ def forward(
         encoder_outputs: Optional[tuple[torch.FloatTensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py
index f7e8f818aab5..2789e508f6fa 100644
--- a/src/transformers/models/zamba2/configuration_zamba2.py
+++ b/src/transformers/models/zamba2/configuration_zamba2.py
@@ -125,6 +125,7 @@ class Zamba2Config(PretrainedConfig):
     ```"""

     model_type = "zamba2"
+    attribute_map = {"head_dim": "attention_head_dim"}
     keys_to_ignore_at_inference = ["past_key_values"]

     def __init__(
diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py
index b7c391887f1a..5731d56643b0 100644
--- a/src/transformers/models/zamba2/modeling_zamba2.py
+++ b/src/transformers/models/zamba2/modeling_zamba2.py
@@ -206,11 +206,7 @@ def reset(self):


 class Zamba2RotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        config: Zamba2Config,
-        device=None,
-    ):
+    def __init__(self, config: Zamba2Config, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
@@ -222,10 +218,8 @@ def __init__(
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        # we cannot use the config here to parameterize because of a factor 2 for the head_dim
-        inv_freq, self.attention_scaling = self.rope_init_fn(
-            device=device, base=config.rope_theta, dim=config.attention_head_dim
-        )
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq
diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py
index fc1d76816101..a980d35828d0 100644
--- a/src/transformers/models/zamba2/modular_zamba2.py
+++ b/src/transformers/models/zamba2/modular_zamba2.py
@@ -167,16 +167,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:


 class Zamba2RotaryEmbedding(LlamaRotaryEmbedding):
-    def __init__(
-        self,
-        config: Zamba2Config,
-        device=None,
-    ):
-        super().__init__(config, device)
-        # we cannot use the config here to parameterize because of a factor 2 for the head_dim
-        inv_freq, self.attention_scaling = self.rope_init_fn(
-            device=device, base=config.rope_theta, dim=config.attention_head_dim
-        )
+    pass


 class Zamba2Attention(ZambaAttention):
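The Zamba2 pieces above work together: `attribute_map` makes `config.head_dim` resolve to `attention_head_dim` (which already carries the doubled head size the deleted comment referred to), so the stock `ROPE_INIT_FUNCTIONS` path can be parameterized from the config alone. A small sketch of that assumption:

```python
from transformers import Zamba2Config
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

config = Zamba2Config()
# The alias added in configuration_zamba2.py: reading `head_dim` returns `attention_head_dim`.
print(config.head_dim == config.attention_head_dim)  # True

# Same call the rotary embedding now makes: everything comes from the config.
inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS["default"](config, device="cpu")
print(inv_freq.shape)  # (attention_head_dim // 2,)
```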
- f"However, got {len(args)} positional arguments instead." - "Please pass all arguments as keyword arguments instead (e.g. `processor(arg_name_1=..., arg_name_2=...))`." - ) - return {arg_name: arg_value for arg_value, arg_name in zip(args, self.optional_call_args)} - @deprecate_kwarg("video_fps", version="4.58", new_name="fps") def apply_chat_template( self, @@ -1721,64 +1663,6 @@ def _check_special_mm_tokens(self, text: list[str], text_inputs: "BatchFeature", ) -def _validate_images_text_input_order(images, text): - """ - For backward compatibility: reverse the order of `images` and `text` inputs if they are swapped. - This method should only be called for processors where `images` and `text` have been swapped for uniformization purposes. - Note that this method assumes that two `None` inputs are valid inputs. If this is not the case, it should be handled - in the processor's `__call__` method before calling this method. - """ - - def is_url(val) -> bool: - return isinstance(val, str) and val.startswith("http") - - def _is_valid_images_input_for_processor(imgs): - # If we have an list of images, make sure every image is valid - if isinstance(imgs, (list, tuple)): - for img in imgs: - if not _is_valid_images_input_for_processor(img): - return False - # If not a list or tuple, we have been given a single image or batched tensor of images - elif not (is_valid_image(imgs) or is_url(imgs)): - return False - return True - - def _is_valid_text_input_for_processor(t): - if isinstance(t, str): - # Strings are fine - return True - elif isinstance(t, (list, tuple)): - # List are fine as long as they are... - if len(t) == 0: - # ... not empty - return False - for t_s in t: - return _is_valid_text_input_for_processor(t_s) - return False - - def _is_valid(input, validator): - return validator(input) or input is None - - images_is_valid = _is_valid(images, _is_valid_images_input_for_processor) - images_is_text = _is_valid_text_input_for_processor(images) - - text_is_valid = _is_valid(text, _is_valid_text_input_for_processor) - text_is_images = _is_valid_images_input_for_processor(text) - # Handle cases where both inputs are valid - if images_is_valid and text_is_valid: - return images, text - - # Handle cases where inputs need to and can be swapped - if (images is None and text_is_images) or (text is None and images_is_text) or (images_is_text and text_is_images): - logger.warning_once( - "You may have used the wrong order for inputs. `images` should be passed before `text`. " - "The `images` and `text` inputs will be swapped. This behavior will be deprecated in transformers v4.47." - ) - return text, images - - raise ValueError("Invalid input type. 
-    raise ValueError("Invalid input type. Check that `images` and/or `text` are valid inputs.")
-
-
 ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
 if ProcessorMixin.push_to_hub.__doc__ is not None:
     ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
diff --git a/tests/models/owlv2/test_processor_owlv2.py b/tests/models/owlv2/test_processor_owlv2.py
index 91043775069f..55dbe51e2a5c 100644
--- a/tests/models/owlv2/test_processor_owlv2.py
+++ b/tests/models/owlv2/test_processor_owlv2.py
@@ -2,8 +2,6 @@
 import tempfile
 import unittest

-import pytest
-
 from transformers import Owlv2Processor
 from transformers.testing_utils import require_scipy
@@ -23,18 +21,3 @@ def setUpClass(cls):
     @classmethod
     def tearDownClass(cls):
         shutil.rmtree(cls.tmpdirname, ignore_errors=True)
-
-    def test_processor_query_images_positional(self):
-        processor_components = self.prepare_components()
-        processor = Owlv2Processor(**processor_components)
-
-        image_input = self.prepare_image_inputs()
-        query_images = self.prepare_image_inputs()
-
-        inputs = processor(None, image_input, query_images)
-
-        self.assertListEqual(list(inputs.keys()), ["query_pixel_values", "pixel_values"])
-
-        # test if it raises when no input is passed
-        with pytest.raises(ValueError):
-            processor()
diff --git a/tests/models/owlvit/test_processor_owlvit.py b/tests/models/owlvit/test_processor_owlvit.py
index 5f99a2275f1a..f31dbaf9fbcc 100644
--- a/tests/models/owlvit/test_processor_owlvit.py
+++ b/tests/models/owlvit/test_processor_owlvit.py
@@ -232,21 +232,6 @@ def test_processor_case2(self):
         with pytest.raises(ValueError):
             processor()

-    def test_processor_query_images_positional(self):
-        processor_components = self.prepare_components()
-        processor = OwlViTProcessor(**processor_components)
-
-        image_input = self.prepare_image_inputs()
-        query_images = self.prepare_image_inputs()
-
-        inputs = processor(None, image_input, query_images)
-
-        self.assertListEqual(list(inputs.keys()), ["query_pixel_values", "pixel_values"])
-
-        # test if it raises when no input is passed
-        with pytest.raises(ValueError):
-            processor()
-
     def test_tokenizer_decode(self):
         image_processor = self.get_image_processor()
         tokenizer = self.get_tokenizer()
diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py
index 761a785f369f..9427b3771b4a 100644
--- a/tests/utils/test_modeling_rope_utils.py
+++ b/tests/utils/test_modeling_rope_utils.py
@@ -77,58 +77,6 @@ def test_rope_validation(self):
                 self.assertEqual(len(logs.output), 1)
                 self.assertIn(model_specific_kwarg, logs.output[0])

-    def test_default_rope_function_bc(self):
-        config = LlamaConfig()
-        device = torch_device
-
-        rope_kwargs = {
-            "rope_type": "default",
-            "dim": config.hidden_size // config.num_attention_heads,
-            "max_position_embeddings": config.max_position_embeddings,
-            "base": config.rope_theta,
-        }
-
-        rope_fn = ROPE_INIT_FUNCTIONS["default"]
-        config_freqs = rope_fn(config=config, device=device)[0]
-        kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
-        torch.testing.assert_close(config_freqs, kwargs_freqs)
-
-    def test_linear_rope_function_bc(self):
-        config = LlamaConfig()
-        config.rope_scaling = {"rope_type": "linear", "factor": 10.0}
-        device = torch_device
-
-        rope_kwargs = {
-            "rope_type": "linear",
-            "dim": config.hidden_size // config.num_attention_heads,
-            "max_position_embeddings": config.max_position_embeddings,
-            "base": config.rope_theta,
-            "factor": 10.0,
-        }
-
-        rope_fn = ROPE_INIT_FUNCTIONS["linear"]
-        config_freqs = rope_fn(config=config, device=device)[0]
-        kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
-        torch.testing.assert_close(config_freqs, kwargs_freqs)
-
-    def test_dynamic_rope_function_bc(self):
-        config = LlamaConfig()
-        config.rope_scaling = {"rope_type": "dynamic", "factor": 10.0}
-        device = torch_device
-
-        rope_kwargs = {
-            "rope_type": "dynamic",
-            "dim": config.hidden_size // config.num_attention_heads,
-            "max_position_embeddings": config.max_position_embeddings,
-            "base": config.rope_theta,
-            "factor": 10.0,
-        }
-
-        rope_fn = ROPE_INIT_FUNCTIONS["dynamic"]
-        config_freqs = rope_fn(config=config, device=device)[0]
-        kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
-        torch.testing.assert_close(config_freqs, kwargs_freqs)
-
     def test_default_rope_numerically(self):
         # Note: some RoPE scaling methods start off by calling the default RoPE frequencies. If this test fails, then
         # multiple RoPE strategies will fail.
diff --git a/tests/utils/test_processing_utils.py b/tests/utils/test_processing_utils.py
deleted file mode 100644
index 7b32b534a70e..000000000000
--- a/tests/utils/test_processing_utils.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-
-from transformers import is_torch_available, is_vision_available
-from transformers.processing_utils import _validate_images_text_input_order
-from transformers.testing_utils import require_torch, require_vision
-
-
-if is_vision_available():
-    import PIL
-
-if is_torch_available():
-    import torch
-
-
-@require_vision
-class ProcessingUtilTester(unittest.TestCase):
-    def test_validate_images_text_input_order(self):
-        # text string and PIL images inputs
-        images = PIL.Image.new("RGB", (224, 224))
-        text = "text"
-        # test correct text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
-        self.assertEqual(valid_images, images)
-        self.assertEqual(valid_text, text)
-        # test incorrect text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
-        self.assertEqual(valid_images, images)
-        self.assertEqual(valid_text, text)
-
-        # text list of string and numpy images inputs
-        images = np.random.rand(224, 224, 3)
-        text = ["text1", "text2"]
-        # test correct text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
-        self.assertTrue(np.array_equal(valid_images, images))
-        self.assertEqual(valid_text, text)
-        # test incorrect text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
-        self.assertTrue(np.array_equal(valid_images, images))
-        self.assertEqual(valid_text, text)
-
-        # text nested list of string and list of pil images inputs
-        images = [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))]
-        text = [["text1", "text2, text3"], ["text3", "text4"]]
-        # test correct text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
-        self.assertEqual(valid_images, images)
-        self.assertEqual(valid_text, text)
-        # test incorrect text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
-        self.assertEqual(valid_images, images)
-        self.assertEqual(valid_text, text)
-
-        # list of strings and list of numpy images inputs
-        images = [np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)]
-        text = ["text1", "text2"]
-        # test correct text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
-        self.assertTrue(np.array_equal(valid_images[0], images[0]))
-        self.assertEqual(valid_text, text)
-        # test incorrect text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
-        self.assertTrue(np.array_equal(valid_images[0], images[0]))
-        self.assertEqual(valid_text, text)
-
-        # list of strings and list of url images inputs
-        images = ["https://url1", "https://url2"]
-        text = ["text1", "text2"]
-        # test correct text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
-        self.assertEqual(valid_images, images)
-        self.assertEqual(valid_text, text)
-        # test incorrect text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
-        self.assertEqual(valid_images, images)
-        self.assertEqual(valid_text, text)
-
-        # list of strings and nested list of numpy images inputs
-        images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]]
-        text = ["text1", "text2"]
-        # test correct text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
-        self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
-        self.assertEqual(valid_text, text)
-        # test incorrect text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
-        self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
-        self.assertEqual(valid_text, text)
-
-        # nested list of strings and nested list of PIL images inputs
-        images = [
-            [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))],
-            [PIL.Image.new("RGB", (224, 224))],
-        ]
-        text = [["text1", "text2, text3"], ["text3", "text4"]]
-        # test correct text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
-        self.assertEqual(valid_images, images)
-        self.assertEqual(valid_text, text)
-        # test incorrect text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
-        self.assertEqual(valid_images, images)
-        self.assertEqual(valid_text, text)
-
-        # None images
-        images = None
-        text = "text"
-        # test correct text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
-        self.assertEqual(images, None)
-        self.assertEqual(text, text)
-        # test incorrect text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
-        self.assertEqual(images, None)
-        self.assertEqual(text, text)
-
-        # None text
-        images = PIL.Image.new("RGB", (224, 224))
-        text = None
-        # test correct text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
-        self.assertEqual(images, images)
-        self.assertEqual(text, None)
-        # test incorrect text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
-        self.assertEqual(images, images)
-        self.assertEqual(text, None)
-
-        # incorrect inputs
-        images = "text"
-        text = "text"
-        with self.assertRaises(ValueError):
-            _validate_images_text_input_order(images=images, text=text)
-
-    @require_torch
-    def test_validate_images_text_input_order_torch(self):
-        # text string and torch images inputs
-        images = torch.rand(224, 224, 3)
-        text = "text"
-        # test correct text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
-        self.assertTrue(torch.equal(valid_images, images))
-        self.assertEqual(valid_text, text)
-        # test incorrect text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
-        self.assertTrue(torch.equal(valid_images, images))
-        self.assertEqual(valid_text, text)
-
-        # text list of string and list of torch images inputs
-        images = [torch.rand(224, 224, 3), torch.rand(224, 224, 3)]
-        text = ["text1", "text2"]
-        # test correct text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
-        self.assertTrue(torch.equal(valid_images[0], images[0]))
-        self.assertEqual(valid_text, text)
-        # test incorrect text and images order
-        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
-        self.assertTrue(torch.equal(valid_images[0], images[0]))
-        self.assertEqual(valid_text, text)
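The deleted `*_rope_function_bc` tests covered the standalone-kwargs call style (`dim`, `base`, `max_position_embeddings` passed directly); the config-driven call they compared against is the one the remaining tests keep exercising. For reference, a sketch of that config-driven path:

```python
from transformers import LlamaConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

config = LlamaConfig()
config.rope_scaling = {"rope_type": "linear", "factor": 10.0}

# Frequencies are derived from the config alone; no dim/base/max_position_embeddings kwargs.
inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS["linear"](config=config, device="cpu")
print(inv_freq.shape, attention_scaling)
```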