Skip to content

Commit 8a79442

Browse files
authored
More flexible output_shape computation in keras.layers.MultiHeadAttention (#20503)
* Made the compute_output_shape method more flexible; _output_shape can now be either an integer or a tuple (previously only a tuple was accepted). Fix discussed in #19769 * Added unit test * Minor changes to comments in unit test * Minor changes to comments in unit test
1 parent 70b7044 commit 8a79442

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

keras/src/layers/attention/multi_head_attention.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,10 @@ def compute_output_shape(
669669
)
670670

671671
if self._output_shape:
672-
return query_shape[:-1] + self._output_shape
672+
if isinstance(self._output_shape, tuple):
673+
return query_shape[:-1] + self._output_shape
674+
else:
675+
return query_shape[:-1] + (self._output_shape,)
673676

674677
return query_shape
675678

keras/src/layers/attention/multi_head_attention_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from keras.src import testing
1717
from keras.src.layers.attention.attention import disable_flash_attention
1818
from keras.src.layers.attention.attention import enable_flash_attention
19+
from keras.src.layers.attention.multi_head_attention import MultiHeadAttention
1920

2021

2122
class MultiHeadAttentionTest(testing.TestCase):
@@ -593,3 +594,27 @@ def test_flash_attention_numerical_correctness(self):
593594
)
594595

595596
self.assertAllClose(output_with_flash, output_without_flash)
597+
598+
599+
600+
601+
def test_multi_head_attention_output_shape_as_int():
    """Test MultiHeadAttention with output_shape as an int."""
    # An integer output_shape should project the last dimension to that size.
    layer = MultiHeadAttention(num_heads=2, key_dim=16, output_shape=8)
    q = random.uniform((2, 4, 16))
    v = random.uniform((2, 4, 16))
    result = layer(query=q, value=v)
    assert result.shape == (2, 4, 8), (
        f"Expected shape (2, 4, 8), got {result.shape}"
    )
610+
611+
612+
def test_multi_head_attention_output_shape_as_tuple():
    """Test MultiHeadAttention with output_shape as a tuple."""
    # A tuple output_shape replaces the last dimension with the full tuple.
    layer = MultiHeadAttention(num_heads=2, key_dim=16, output_shape=(8, 8))
    q = random.uniform((2, 4, 16))
    v = random.uniform((2, 4, 16))
    result = layer(query=q, value=v)
    assert result.shape == (2, 4, 8, 8), (
        f"Expected shape (2, 4, 8, 8), got {result.shape}"
    )

0 commit comments

Comments
 (0)