@@ -16,6 +16,7 @@
 from keras.src import testing
 from keras.src.layers.attention.attention import disable_flash_attention
 from keras.src.layers.attention.attention import enable_flash_attention
+from keras.src.layers.attention.multi_head_attention import MultiHeadAttention


 class MultiHeadAttentionTest(testing.TestCase):
@@ -593,3 +594,27 @@ def test_flash_attention_numerical_correctness(self):
         )

         self.assertAllClose(output_with_flash, output_without_flash)
+
+
+
+
+def test_multi_head_attention_output_shape_as_int():
+    """Test MultiHeadAttention with output_shape as an int."""
+    mha = MultiHeadAttention(num_heads=2, key_dim=16, output_shape=8)
+    query = random.uniform((2, 4, 16))
+    value = random.uniform((2, 4, 16))
+    output = mha(query=query, value=value)
+
+    assert output.shape == (2, 4, 8), (f"Expected shape (2, 4, 8),"
+                                       f" got {output.shape}")
+
+
+def test_multi_head_attention_output_shape_as_tuple():
+    """Test MultiHeadAttention with output_shape as a tuple."""
+    mha = MultiHeadAttention(num_heads=2, key_dim=16, output_shape=(8, 8))
+    query = random.uniform((2, 4, 16))
+    value = random.uniform((2, 4, 16))
+    output = mha(query=query, value=value)
+
+    assert output.shape == (2, 4, 8, 8), (f"Expected shape (2, 4, 8, 8),"
+                                          f" got {output.shape}")
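For reviewers, a minimal sketch of the behavior these tests pin down, written against the public keras namespace rather than the keras.src internals used above (the new tests themselves rely on the test module's existing random import). With output_shape=None, the default, the output projection restores the query's feature dimension; an int replaces that last axis, and a tuple expands it into several trailing axes:

import keras
from keras import layers

query = keras.random.uniform((2, 4, 16))
value = keras.random.uniform((2, 4, 16))

# Default (output_shape=None): last dim matches the query's feature size.
default_mha = layers.MultiHeadAttention(num_heads=2, key_dim=16)
print(default_mha(query=query, value=value).shape)  # (2, 4, 16)

# An int output_shape projects the last axis to that size.
int_mha = layers.MultiHeadAttention(num_heads=2, key_dim=16, output_shape=8)
print(int_mha(query=query, value=value).shape)  # (2, 4, 8)

# A tuple output_shape produces multiple trailing axes.
tuple_mha = layers.MultiHeadAttention(
    num_heads=2, key_dim=16, output_shape=(8, 8)
)
print(tuple_mha(query=query, value=value).shape)  # (2, 4, 8, 8)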