diff --git a/keras/src/layers/attention/multi_head_attention.py b/keras/src/layers/attention/multi_head_attention.py index 6f23b895cb10..dfac5c4f8e02 100644 --- a/keras/src/layers/attention/multi_head_attention.py +++ b/keras/src/layers/attention/multi_head_attention.py @@ -669,7 +669,10 @@ def compute_output_shape( ) if self._output_shape: - return query_shape[:-1] + self._output_shape + if isinstance(self._output_shape, tuple): + return query_shape[:-1] + self._output_shape + else: + return query_shape[:-1] + (self._output_shape,) return query_shape diff --git a/keras/src/layers/attention/multi_head_attention_test.py b/keras/src/layers/attention/multi_head_attention_test.py index 4707cb893dd5..67e2e35d3c4f 100644 --- a/keras/src/layers/attention/multi_head_attention_test.py +++ b/keras/src/layers/attention/multi_head_attention_test.py @@ -16,6 +16,7 @@ from keras.src import testing from keras.src.layers.attention.attention import disable_flash_attention from keras.src.layers.attention.attention import enable_flash_attention +from keras.src.layers.attention.multi_head_attention import MultiHeadAttention class MultiHeadAttentionTest(testing.TestCase): @@ -593,3 +594,27 @@ def test_flash_attention_numerical_correctness(self): ) self.assertAllClose(output_with_flash, output_without_flash) + + + + +def test_multi_head_attention_output_shape_as_int(): + """Test MultiHeadAttention with output_shape as an int.""" + mha = MultiHeadAttention(num_heads=2, key_dim=16, output_shape=8) + query = random.uniform((2, 4, 16)) + value = random.uniform((2, 4, 16)) + output = mha(query=query, value=value) + + assert output.shape == (2, 4, 8), (f"Expected shape (2, 4, 8)," + f" got {output.shape}") + + +def test_multi_head_attention_output_shape_as_tuple(): + """Test MultiHeadAttention with output_shape as a tuple.""" + mha = MultiHeadAttention(num_heads=2, key_dim=16, output_shape=(8, 8)) + query = random.uniform((2, 4, 16)) + value = random.uniform((2, 4, 16)) + output = mha(query=query, value=value) + + assert output.shape == (2, 4, 8, 8), (f"Expected shape (2, 4, 8, 8)," + f" got {output.shape}")