Skip to content

Add serialization function for StaticCache #38879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 66 additions & 5 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from packaging import version

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6, is_torch_greater_or_equal_than_2_7
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're moving all these flags to is_torch_greater_or_equal, which is already imported here
e.g. is_torch_greater_or_equal("2.7.0"), or is_torch_greater_or_equal("2.7.0", accept_dev=True) if you also want to accept dev versions of 2.7

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw this snippet of code. Do you want me to remove the flag I inserted and use is_torch_greater_or_equal("2.7", accept_dev=True) directly in the code?

is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)  # the line I added
is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)


from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
Expand Down Expand Up @@ -651,9 +651,7 @@ def batch_select_indices(self, indices: torch.Tensor):


# Utilities for `DynamicCache` <> torch.export support
def _flatten_dynamic_cache(
dynamic_cache: DynamicCache,
):
def _flatten_dynamic_cache(dynamic_cache: DynamicCache):
"""Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume"""
if not isinstance(dynamic_cache, DynamicCache):
raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
Expand Down Expand Up @@ -682,7 +680,7 @@ def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
def _unflatten_dynamic_cache(
values,
context: torch.utils._pytree.Context,
):
) -> DynamicCache:
dictionary = torch.utils._pytree._dict_unflatten(values, context)
cache = DynamicCache()
for k, v in dictionary.items():
Expand Down Expand Up @@ -1230,6 +1228,69 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[
return kv_length, 0


# Utilities for `StaticCache` <> torch.export support
def _flatten_static_cache(cache: StaticCache):
"""Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume"""
if not isinstance(cache, StaticCache):
raise RuntimeError("This pytree flattening function should only be applied to StaticCache")

if not is_torch_greater_or_equal_than_2_7:
logger.warning_once(
"StaticCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
)

dictionary = {
"key_cache": getattr(cache, "key_cache"),
"value_cache": getattr(cache, "value_cache"),
}
return torch.utils._pytree._dict_flatten(dictionary)


def _flatten_with_keys_static_cache(cache: StaticCache):
dictionary = {
"key_cache": getattr(cache, "key_cache"),
"value_cache": getattr(cache, "value_cache"),
}
return torch.utils._pytree._dict_flatten_with_keys(dictionary)


def _make_static_cache(key_value_pairs):
class _config:
def __init__(self):
self.head_dim = key_value_pairs[0][0].shape[-1]
self.num_attention_heads = key_value_pairs[0][0].shape[1]
self.num_hidden_layers = len(key_value_pairs)

cache = StaticCache(
_config(),
max_batch_size=key_value_pairs[0][0].shape[0],
device=key_value_pairs[0][0].device,
dtype=key_value_pairs[0][0].dtype,
max_cache_len=key_value_pairs[0][0].shape[2],
)
for i in range(len(key_value_pairs)):
cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
return cache


def _unflatten_static_cache(
values,
context: torch.utils._pytree.Context,
) -> StaticCache:
return _make_static_cache(list(zip(values[0], values[1])))


if is_torch_greater_or_equal("2.7"):
torch.utils._pytree.register_pytree_node(
StaticCache,
_flatten_static_cache,
_unflatten_static_cache,
serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}",
flatten_with_keys_fn=_flatten_with_keys_static_cache,
)


class SlidingWindowCache(StaticCache):
"""
Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
Expand Down
1 change: 1 addition & 0 deletions src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

logger = logging.get_logger(__name__)

is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
Expand Down
35 changes: 35 additions & 0 deletions tests/utils/test_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,39 @@ def _random_kvs(config):
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
self.assertTrue(cached_values.shape == (1, 1, 10, 128))

def test_unflatten_flatten_static_cache(self):
def make_static_cache(key_value_pairs):
class _config:
def __init__(self):
self.head_dim = key_value_pairs[0][0].shape[-1]
self.num_attention_heads = key_value_pairs[0][0].shape[1]
self.num_hidden_layers = len(key_value_pairs)

cache = StaticCache(
_config(),
max_batch_size=key_value_pairs[0][0].shape[0],
device=key_value_pairs[0][0].device,
dtype=key_value_pairs[0][0].dtype,
max_cache_len=key_value_pairs[0][0].shape[2],
)
for i in range(len(key_value_pairs)):
cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
return cache

cache = make_static_cache(
[
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
]
)
flat, spec = torch.utils._pytree.tree_flatten(cache)
self.assertIsInstance(flat, list)
self.assertEqual(len(flat), 6)
cache2 = torch.utils._pytree.tree_unflatten(flat, spec)
self.assertTrue(isinstance(cache2, StaticCache))


def _skip_on_failed_cache_prerequisites(test, cache_implementation):
"""Function to skip tests on failed cache prerequisites, given a cache implementation"""
Expand Down Expand Up @@ -726,6 +759,8 @@ def test_static_cache_exportability(self):
"""
Tests that static cache works with `torch.export()`
"""
# TODO: Another test should be implemented to follow the same pattern
# as the one implemented for DynamicCache.
if not is_torch_greater_or_equal("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")

Expand Down