|
17 | 17 | from unittest import mock
|
18 | 18 |
|
19 | 19 | from twisted.internet import defer, reactor
|
| 20 | +from twisted.internet.defer import Deferred |
20 | 21 |
|
21 | 22 | from synapse.api.errors import SynapseError
|
22 | 23 | from synapse.logging.context import (
|
@@ -703,6 +704,48 @@ async def list_fn(self, args1, arg2):
|
703 | 704 | obj.mock.assert_called_once_with((40,), 2)
|
704 | 705 | self.assertEqual(r, {10: "fish", 40: "gravy"})
|
705 | 706 |
|
| 707 | + def test_concurrent_lookups(self): |
| 708 | + """All concurrent lookups should get the same result""" |
| 709 | + |
| 710 | + class Cls: |
| 711 | + def __init__(self): |
| 712 | + self.mock = mock.Mock() |
| 713 | + |
| 714 | + @descriptors.cached() |
| 715 | + def fn(self, arg1): |
| 716 | + pass |
| 717 | + |
| 718 | + @descriptors.cachedList("fn", "args1") |
| 719 | + def list_fn(self, args1) -> "Deferred[dict]": |
| 720 | + return self.mock(args1) |
| 721 | + |
| 722 | + obj = Cls() |
| 723 | + deferred_result = Deferred() |
| 724 | + obj.mock.return_value = deferred_result |
| 725 | + |
| 726 | + # start off several concurrent lookups of the same key |
| 727 | + d1 = obj.list_fn([10]) |
| 728 | + d2 = obj.list_fn([10]) |
| 729 | + d3 = obj.list_fn([10]) |
| 730 | + |
| 731 | + # the mock should have been called exactly once |
| 732 | + obj.mock.assert_called_once_with((10,)) |
| 733 | + obj.mock.reset_mock() |
| 734 | + |
| 735 | + # ... and none of the calls should yet be complete |
| 736 | + self.assertFalse(d1.called) |
| 737 | + self.assertFalse(d2.called) |
| 738 | + self.assertFalse(d3.called) |
| 739 | + |
| 740 | + # complete the lookup. @cachedList functions need to complete with a map |
| 741 | + # of input->result |
| 742 | + deferred_result.callback({10: "peas"}) |
| 743 | + |
| 744 | + # ... which should give the right result to all the callers |
| 745 | + self.assertEqual(self.successResultOf(d1), {10: "peas"}) |
| 746 | + self.assertEqual(self.successResultOf(d2), {10: "peas"}) |
| 747 | + self.assertEqual(self.successResultOf(d3), {10: "peas"}) |
| 748 | + |
706 | 749 | @defer.inlineCallbacks
|
707 | 750 | def test_invalidate(self):
|
708 | 751 | """Make sure that invalidation callbacks are called."""
|
|
0 commit comments