Skip to content

More flexible output_shape computation in keras.layers.MultiHeadAttention #20503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion keras/src/layers/attention/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,10 @@ def compute_output_shape(
)

if self._output_shape:
return query_shape[:-1] + self._output_shape
if isinstance(self._output_shape, tuple):
return query_shape[:-1] + self._output_shape
else:
return query_shape[:-1] + (self._output_shape,)

return query_shape

Expand Down
25 changes: 25 additions & 0 deletions keras/src/layers/attention/multi_head_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from keras.src import testing
from keras.src.layers.attention.attention import disable_flash_attention
from keras.src.layers.attention.attention import enable_flash_attention
from keras.src.layers.attention.multi_head_attention import MultiHeadAttention


class MultiHeadAttentionTest(testing.TestCase):
Expand Down Expand Up @@ -593,3 +594,27 @@ def test_flash_attention_numerical_correctness(self):
)

self.assertAllClose(output_with_flash, output_without_flash)




def test_multi_head_attention_output_shape_as_int():
"""Test MultiHeadAttention with output_shape as an int."""
mha = MultiHeadAttention(num_heads=2, key_dim=16, output_shape=8)
query = random.uniform((2, 4, 16))
value = random.uniform((2, 4, 16))
output = mha(query=query, value=value)

assert output.shape == (2, 4, 8), (f"Expected shape (2, 4, 8),"
f" got {output.shape}")


def test_multi_head_attention_output_shape_as_tuple():
"""Test MultiHeadAttention with output_shape as a tuple."""
mha = MultiHeadAttention(num_heads=2, key_dim=16, output_shape=(8, 8))
query = random.uniform((2, 4, 16))
value = random.uniform((2, 4, 16))
output = mha(query=query, value=value)

assert output.shape == (2, 4, 8, 8), (f"Expected shape (2, 4, 8, 8),"
f" got {output.shape}")
Loading