Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit b76f1a4

Browse files
authored
Add some type hints to datastore (#12485)
1 parent 63ba9ba commit b76f1a4

File tree

12 files changed

+188
-84
lines changed

12 files changed

+188
-84
lines changed

changelog.d/12485.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add some type hints to datastore.

synapse/storage/databases/main/__init__.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@
1515
# limitations under the License.
1616

1717
import logging
18-
from typing import TYPE_CHECKING, List, Optional, Tuple
18+
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
1919

2020
from synapse.config.homeserver import HomeServerConfig
21-
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
21+
from synapse.storage.database import (
22+
DatabasePool,
23+
LoggingDatabaseConnection,
24+
LoggingTransaction,
25+
)
2226
from synapse.storage.databases.main.stats import UserSortOrder
23-
from synapse.storage.engines import PostgresEngine
27+
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
28+
from synapse.storage.types import Cursor
2429
from synapse.storage.util.id_generators import (
2530
IdGenerator,
2631
MultiWriterIdGenerator,
@@ -266,7 +271,9 @@ async def get_users_paginate(
266271
A tuple of a list of mappings from user to information and a count of total users.
267272
"""
268273

269-
def get_users_paginate_txn(txn):
274+
def get_users_paginate_txn(
275+
txn: LoggingTransaction,
276+
) -> Tuple[List[JsonDict], int]:
270277
filters = []
271278
args = [self.hs.config.server.server_name]
272279

@@ -301,7 +308,7 @@ def get_users_paginate_txn(txn):
301308
"""
302309
sql = "SELECT COUNT(*) as total_users " + sql_base
303310
txn.execute(sql, args)
304-
count = txn.fetchone()[0]
311+
count = cast(Tuple[int], txn.fetchone())[0]
305312

306313
sql = f"""
307314
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
@@ -338,7 +345,9 @@ async def search_users(self, term: str) -> Optional[List[JsonDict]]:
338345
)
339346

340347

341-
def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
348+
def check_database_before_upgrade(
349+
cur: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
350+
) -> None:
342351
"""Called before upgrading an existing database to check that it is broadly sane
343352
compared with the configuration.
344353
"""

synapse/storage/databases/main/appservice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
import logging
1616
import re
17-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
17+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast
1818

1919
from synapse.appservice import (
2020
ApplicationService,
@@ -83,7 +83,7 @@ def get_max_as_txn_id(txn: Cursor) -> int:
8383
txn.execute(
8484
"SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
8585
)
86-
return txn.fetchone()[0] # type: ignore
86+
return cast(Tuple[int], txn.fetchone())[0]
8787

8888
self._as_txn_seq_gen = build_sequence_generator(
8989
db_conn,

synapse/storage/databases/main/deviceinbox.py

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,17 @@
1414
# limitations under the License.
1515

1616
import logging
17-
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
17+
from typing import (
18+
TYPE_CHECKING,
19+
Collection,
20+
Dict,
21+
Iterable,
22+
List,
23+
Optional,
24+
Set,
25+
Tuple,
26+
cast,
27+
)
1828

1929
from synapse.logging import issue9533_logger
2030
from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -118,7 +128,13 @@ def __init__(
118128
prefilled_cache=device_outbox_prefill,
119129
)
120130

121-
def process_replication_rows(self, stream_name, instance_name, token, rows):
131+
def process_replication_rows(
132+
self,
133+
stream_name: str,
134+
instance_name: str,
135+
token: int,
136+
rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
137+
) -> None:
122138
if stream_name == ToDeviceStream.NAME:
123139
# If replication is happening than postgres must be being used.
124140
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
@@ -134,7 +150,7 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
134150
)
135151
return super().process_replication_rows(stream_name, instance_name, token, rows)
136152

137-
def get_to_device_stream_token(self):
153+
def get_to_device_stream_token(self) -> int:
138154
return self._device_inbox_id_gen.get_current_token()
139155

140156
async def get_messages_for_user_devices(
@@ -301,7 +317,9 @@ async def _get_device_messages(
301317
if not user_ids_to_query:
302318
return {}, to_stream_id
303319

304-
def get_device_messages_txn(txn: LoggingTransaction):
320+
def get_device_messages_txn(
321+
txn: LoggingTransaction,
322+
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
305323
# Build a query to select messages from any of the given devices that
306324
# are between the given stream id bounds.
307325

@@ -428,7 +446,7 @@ async def delete_messages_for_device(
428446
log_kv({"message": "No changes in cache since last check"})
429447
return 0
430448

431-
def delete_messages_for_device_txn(txn):
449+
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
432450
sql = (
433451
"DELETE FROM device_inbox"
434452
" WHERE user_id = ? AND device_id = ?"
@@ -455,15 +473,14 @@ def delete_messages_for_device_txn(txn):
455473

456474
@trace
457475
async def get_new_device_msgs_for_remote(
458-
self, destination, last_stream_id, current_stream_id, limit
459-
) -> Tuple[List[dict], int]:
476+
self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
477+
) -> Tuple[List[JsonDict], int]:
460478
"""
461479
Args:
462-
destination(str): The name of the remote server.
463-
last_stream_id(int|long): The last position of the device message stream
480+
destination: The name of the remote server.
481+
last_stream_id: The last position of the device message stream
464482
that the server sent up to.
465-
current_stream_id(int|long): The current position of the device
466-
message stream.
483+
current_stream_id: The current position of the device message stream.
467484
Returns:
468485
A list of messages for the device and where in the stream the messages got to.
469486
"""
@@ -485,7 +502,9 @@ async def get_new_device_msgs_for_remote(
485502
return [], last_stream_id
486503

487504
@trace
488-
def get_new_messages_for_remote_destination_txn(txn):
505+
def get_new_messages_for_remote_destination_txn(
506+
txn: LoggingTransaction,
507+
) -> Tuple[List[JsonDict], int]:
489508
sql = (
490509
"SELECT stream_id, messages_json FROM device_federation_outbox"
491510
" WHERE destination = ?"
@@ -527,7 +546,7 @@ async def delete_device_msgs_for_remote(
527546
up_to_stream_id: Where to delete messages up to.
528547
"""
529548

530-
def delete_messages_for_remote_destination_txn(txn):
549+
def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None:
531550
sql = (
532551
"DELETE FROM device_federation_outbox"
533552
" WHERE destination = ?"
@@ -566,7 +585,9 @@ async def get_all_new_device_messages(
566585
if last_id == current_id:
567586
return [], current_id, False
568587

569-
def get_all_new_device_messages_txn(txn):
588+
def get_all_new_device_messages_txn(
589+
txn: LoggingTransaction,
590+
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
570591
# We limit like this as we might have multiple rows per stream_id, and
571592
# we want to make sure we always get all entries for any stream_id
572593
# we return.
@@ -607,8 +628,8 @@ def get_all_new_device_messages_txn(txn):
607628
@trace
608629
async def add_messages_to_device_inbox(
609630
self,
610-
local_messages_by_user_then_device: dict,
611-
remote_messages_by_destination: dict,
631+
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
632+
remote_messages_by_destination: Dict[str, JsonDict],
612633
) -> int:
613634
"""Used to send messages from this server.
614635
@@ -624,7 +645,9 @@ async def add_messages_to_device_inbox(
624645

625646
assert self._can_write_to_device
626647

627-
def add_messages_txn(txn, now_ms, stream_id):
648+
def add_messages_txn(
649+
txn: LoggingTransaction, now_ms: int, stream_id: int
650+
) -> None:
628651
# Add the local messages directly to the local inbox.
629652
self._add_messages_to_local_device_inbox_txn(
630653
txn, stream_id, local_messages_by_user_then_device
@@ -677,11 +700,16 @@ def add_messages_txn(txn, now_ms, stream_id):
677700
return self._device_inbox_id_gen.get_current_token()
678701

679702
async def add_messages_from_remote_to_device_inbox(
680-
self, origin: str, message_id: str, local_messages_by_user_then_device: dict
703+
self,
704+
origin: str,
705+
message_id: str,
706+
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
681707
) -> int:
682708
assert self._can_write_to_device
683709

684-
def add_messages_txn(txn, now_ms, stream_id):
710+
def add_messages_txn(
711+
txn: LoggingTransaction, now_ms: int, stream_id: int
712+
) -> None:
685713
# Check if we've already inserted a matching message_id for that
686714
# origin. This can happen if the origin doesn't receive our
687715
# acknowledgement from the first time we received the message.
@@ -727,8 +755,11 @@ def add_messages_txn(txn, now_ms, stream_id):
727755
return stream_id
728756

729757
def _add_messages_to_local_device_inbox_txn(
730-
self, txn, stream_id, messages_by_user_then_device
731-
):
758+
self,
759+
txn: LoggingTransaction,
760+
stream_id: int,
761+
messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
762+
) -> None:
732763
assert self._can_write_to_device
733764

734765
local_by_user_then_device = {}
@@ -840,8 +871,10 @@ def __init__(
840871
self._remove_dead_devices_from_device_inbox,
841872
)
842873

843-
async def _background_drop_index_device_inbox(self, progress, batch_size):
844-
def reindex_txn(conn):
874+
async def _background_drop_index_device_inbox(
875+
self, progress: JsonDict, batch_size: int
876+
) -> int:
877+
def reindex_txn(conn: LoggingDatabaseConnection) -> None:
845878
txn = conn.cursor()
846879
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
847880
txn.close()

0 commit comments

Comments
 (0)