Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit d00f7af

Browse files
committed
Micro-optimisations to get_auth_chain_ids
1 parent c9c544c commit d00f7af

File tree

2 files changed

+22
-26
lines changed

2 files changed

+22
-26
lines changed

changelog.d/8132.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Micro-optimisations to get_auth_chain_ids.

synapse/storage/databases/main/event_federation.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,70 +18,66 @@
1818
from typing import Dict, Iterable, List, Optional, Set, Tuple
1919

2020
from synapse.api.errors import StoreError
21+
from synapse.events import EventBase
2122
from synapse.metrics.background_process_metrics import run_as_background_process
2223
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
23-
from synapse.storage.database import DatabasePool
24+
from synapse.storage.database import DatabasePool, LoggingTransaction
2425
from synapse.storage.databases.main.events_worker import EventsWorkerStore
2526
from synapse.storage.databases.main.signatures import SignatureWorkerStore
27+
from synapse.types import Collection
2628
from synapse.util.caches.descriptors import cached
2729
from synapse.util.iterutils import batch_iter
2830

2931
logger = logging.getLogger(__name__)
3032

3133

3234
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
33-
async def get_auth_chain(self, event_ids, include_given=False):
35+
async def get_auth_chain(
36+
self, event_ids: Collection[str], include_given: bool = False
37+
) -> List[EventBase]:
3438
"""Get auth events for given event_ids. The events *must* be state events.
3539
3640
Args:
37-
event_ids (list): state events
38-
include_given (bool): include the given events in result
41+
event_ids: state events
42+
include_given: include the given events in result
3943
4044
Returns:
4145
list of events
4246
"""
43-
event_ids = await self.get_auth_chain_ids(
44-
event_ids, include_given=include_given
47+
# get_events_as_list requires a list, so convert to a list here
48+
event_ids = list(
49+
await self.get_auth_chain_ids(event_ids, include_given=include_given)
4550
)
4651
return await self.get_events_as_list(event_ids)
4752

48-
def get_auth_chain_ids(
49-
self,
50-
event_ids: List[str],
51-
include_given: bool = False,
52-
ignore_events: Optional[Set[str]] = None,
53-
):
53+
async def get_auth_chain_ids(
54+
self, event_ids: Collection[str], include_given: bool = False,
55+
) -> Set[str]:
5456
"""Get auth events for given event_ids. The events *must* be state events.
5557
5658
Args:
5759
event_ids: state events
5860
include_given: include the given events in result
59-
ignore_events: Set of events to exclude from the returned auth
60-
chain. This is useful if the caller will just discard the
61-
given events anyway, and saves us from figuring out their auth
62-
chains if not required.
6361
6462
Returns:
65-
list of event_ids
63+
set of event_ids
6664
"""
67-
return self.db_pool.runInteraction(
65+
return await self.db_pool.runInteraction(
6866
"get_auth_chain_ids",
6967
self._get_auth_chain_ids_txn,
7068
event_ids,
7169
include_given,
72-
ignore_events,
7370
)
7471

75-
def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
76-
if ignore_events is None:
77-
ignore_events = set()
78-
72+
def _get_auth_chain_ids_txn(
73+
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
74+
) -> Set[str]:
7975
if include_given:
8076
results = set(event_ids)
8177
else:
8278
results = set()
8379

84-
base_sql = "SELECT auth_id FROM event_auth WHERE "
80+
base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
8581

8682
front = set(event_ids)
8783
while front:
@@ -93,13 +89,12 @@ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
9389
txn.execute(base_sql + clause, args)
9490
new_front.update(r[0] for r in txn)
9591

96-
new_front -= ignore_events
9792
new_front -= results
9893

9994
front = new_front
10095
results.update(front)
10196

102-
return list(results)
97+
return results
10398

10499
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
105100
"""Given sets of state events figure out the auth chain difference (as

0 commit comments

Comments
 (0)