18
18
from typing import Dict , Iterable , List , Optional , Set , Tuple
19
19
20
20
from synapse .api .errors import StoreError
21
+ from synapse .events import EventBase
21
22
from synapse .metrics .background_process_metrics import run_as_background_process
22
23
from synapse .storage ._base import SQLBaseStore , make_in_list_sql_clause
23
- from synapse .storage .database import DatabasePool
24
+ from synapse .storage .database import DatabasePool , LoggingTransaction
24
25
from synapse .storage .databases .main .events_worker import EventsWorkerStore
25
26
from synapse .storage .databases .main .signatures import SignatureWorkerStore
27
+ from synapse .types import Collection
26
28
from synapse .util .caches .descriptors import cached
27
29
from synapse .util .iterutils import batch_iter
28
30
29
31
logger = logging .getLogger (__name__ )
30
32
31
33
32
34
class EventFederationWorkerStore (EventsWorkerStore , SignatureWorkerStore , SQLBaseStore ):
33
- async def get_auth_chain (self , event_ids , include_given = False ):
35
+ async def get_auth_chain (
36
+ self , event_ids : Collection [str ], include_given : bool = False
37
+ ) -> List [EventBase ]:
34
38
"""Get auth events for given event_ids. The events *must* be state events.
35
39
36
40
Args:
37
- event_ids (list) : state events
38
- include_given (bool) : include the given events in result
41
+ event_ids: state events
42
+ include_given: include the given events in result
39
43
40
44
Returns:
41
45
list of events
42
46
"""
43
- event_ids = await self .get_auth_chain_ids (
44
- event_ids , include_given = include_given
47
+ # get_events_as_list requires a list, so convert to a list here
48
+ event_ids = list (
49
+ await self .get_auth_chain_ids (event_ids , include_given = include_given )
45
50
)
46
51
return await self .get_events_as_list (event_ids )
47
52
48
- def get_auth_chain_ids (
49
- self ,
50
- event_ids : List [str ],
51
- include_given : bool = False ,
52
- ignore_events : Optional [Set [str ]] = None ,
53
- ):
53
+ async def get_auth_chain_ids (
54
+ self , event_ids : Collection [str ], include_given : bool = False ,
55
+ ) -> Set [str ]:
54
56
"""Get auth events for given event_ids. The events *must* be state events.
55
57
56
58
Args:
57
59
event_ids: state events
58
60
include_given: include the given events in result
59
- ignore_events: Set of events to exclude from the returned auth
60
- chain. This is useful if the caller will just discard the
61
- given events anyway, and saves us from figuring out their auth
62
- chains if not required.
63
61
64
62
Returns:
65
- list of event_ids
63
+ set of event_ids
66
64
"""
67
- return self .db_pool .runInteraction (
65
+ return await self .db_pool .runInteraction (
68
66
"get_auth_chain_ids" ,
69
67
self ._get_auth_chain_ids_txn ,
70
68
event_ids ,
71
69
include_given ,
72
- ignore_events ,
73
70
)
74
71
75
- def _get_auth_chain_ids_txn (self , txn , event_ids , include_given , ignore_events ):
76
- if ignore_events is None :
77
- ignore_events = set ()
78
-
72
+ def _get_auth_chain_ids_txn (
73
+ self , txn : LoggingTransaction , event_ids : Collection [str ], include_given : bool
74
+ ) -> Set [str ]:
79
75
if include_given :
80
76
results = set (event_ids )
81
77
else :
82
78
results = set ()
83
79
84
- base_sql = "SELECT auth_id FROM event_auth WHERE "
80
+ base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
85
81
86
82
front = set (event_ids )
87
83
while front :
@@ -93,13 +89,12 @@ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
93
89
txn .execute (base_sql + clause , args )
94
90
new_front .update (r [0 ] for r in txn )
95
91
96
- new_front -= ignore_events
97
92
new_front -= results
98
93
99
94
front = new_front
100
95
results .update (front )
101
96
102
- return list ( results )
97
+ return results
103
98
104
99
def get_auth_chain_difference (self , state_sets : List [Set [str ]]):
105
100
"""Given sets of state events figure out the auth chain difference (as
0 commit comments