14
14
# limitations under the License.
15
15
16
16
import logging
17
- from typing import TYPE_CHECKING , Collection , Dict , List , Optional , Set , Tuple , cast
17
+ from typing import (
18
+ TYPE_CHECKING ,
19
+ Collection ,
20
+ Dict ,
21
+ Iterable ,
22
+ List ,
23
+ Optional ,
24
+ Set ,
25
+ Tuple ,
26
+ cast ,
27
+ )
18
28
19
29
from synapse .logging import issue9533_logger
20
30
from synapse .logging .opentracing import log_kv , set_tag , trace
@@ -118,7 +128,13 @@ def __init__(
118
128
prefilled_cache = device_outbox_prefill ,
119
129
)
120
130
121
- def process_replication_rows (self , stream_name , instance_name , token , rows ):
131
+ def process_replication_rows (
132
+ self ,
133
+ stream_name : str ,
134
+ instance_name : str ,
135
+ token : int ,
136
+ rows : Iterable [ToDeviceStream .ToDeviceStreamRow ],
137
+ ) -> None :
122
138
if stream_name == ToDeviceStream .NAME :
123
139
# If replication is happening than postgres must be being used.
124
140
assert isinstance (self ._device_inbox_id_gen , MultiWriterIdGenerator )
@@ -134,7 +150,7 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
134
150
)
135
151
return super ().process_replication_rows (stream_name , instance_name , token , rows )
136
152
137
- def get_to_device_stream_token (self ):
153
+ def get_to_device_stream_token (self ) -> int :
138
154
return self ._device_inbox_id_gen .get_current_token ()
139
155
140
156
async def get_messages_for_user_devices (
@@ -301,7 +317,9 @@ async def _get_device_messages(
301
317
if not user_ids_to_query :
302
318
return {}, to_stream_id
303
319
304
- def get_device_messages_txn (txn : LoggingTransaction ):
320
+ def get_device_messages_txn (
321
+ txn : LoggingTransaction ,
322
+ ) -> Tuple [Dict [Tuple [str , str ], List [JsonDict ]], int ]:
305
323
# Build a query to select messages from any of the given devices that
306
324
# are between the given stream id bounds.
307
325
@@ -428,7 +446,7 @@ async def delete_messages_for_device(
428
446
log_kv ({"message" : "No changes in cache since last check" })
429
447
return 0
430
448
431
- def delete_messages_for_device_txn (txn ) :
449
+ def delete_messages_for_device_txn (txn : LoggingTransaction ) -> int :
432
450
sql = (
433
451
"DELETE FROM device_inbox"
434
452
" WHERE user_id = ? AND device_id = ?"
@@ -455,15 +473,14 @@ def delete_messages_for_device_txn(txn):
455
473
456
474
@trace
457
475
async def get_new_device_msgs_for_remote (
458
- self , destination , last_stream_id , current_stream_id , limit
459
- ) -> Tuple [List [dict ], int ]:
476
+ self , destination : str , last_stream_id : int , current_stream_id : int , limit : int
477
+ ) -> Tuple [List [JsonDict ], int ]:
460
478
"""
461
479
Args:
462
- destination(str) : The name of the remote server.
463
- last_stream_id(int|long) : The last position of the device message stream
480
+ destination: The name of the remote server.
481
+ last_stream_id: The last position of the device message stream
464
482
that the server sent up to.
465
- current_stream_id(int|long): The current position of the device
466
- message stream.
483
+ current_stream_id: The current position of the device message stream.
467
484
Returns:
468
485
A list of messages for the device and where in the stream the messages got to.
469
486
"""
@@ -485,7 +502,9 @@ async def get_new_device_msgs_for_remote(
485
502
return [], last_stream_id
486
503
487
504
@trace
488
- def get_new_messages_for_remote_destination_txn (txn ):
505
+ def get_new_messages_for_remote_destination_txn (
506
+ txn : LoggingTransaction ,
507
+ ) -> Tuple [List [JsonDict ], int ]:
489
508
sql = (
490
509
"SELECT stream_id, messages_json FROM device_federation_outbox"
491
510
" WHERE destination = ?"
@@ -527,7 +546,7 @@ async def delete_device_msgs_for_remote(
527
546
up_to_stream_id: Where to delete messages up to.
528
547
"""
529
548
530
- def delete_messages_for_remote_destination_txn (txn ) :
549
+ def delete_messages_for_remote_destination_txn (txn : LoggingTransaction ) -> None :
531
550
sql = (
532
551
"DELETE FROM device_federation_outbox"
533
552
" WHERE destination = ?"
@@ -566,7 +585,9 @@ async def get_all_new_device_messages(
566
585
if last_id == current_id :
567
586
return [], current_id , False
568
587
569
- def get_all_new_device_messages_txn (txn ):
588
+ def get_all_new_device_messages_txn (
589
+ txn : LoggingTransaction ,
590
+ ) -> Tuple [List [Tuple [int , tuple ]], int , bool ]:
570
591
# We limit like this as we might have multiple rows per stream_id, and
571
592
# we want to make sure we always get all entries for any stream_id
572
593
# we return.
@@ -607,8 +628,8 @@ def get_all_new_device_messages_txn(txn):
607
628
@trace
608
629
async def add_messages_to_device_inbox (
609
630
self ,
610
- local_messages_by_user_then_device : dict ,
611
- remote_messages_by_destination : dict ,
631
+ local_messages_by_user_then_device : Dict [ str , Dict [ str , JsonDict ]] ,
632
+ remote_messages_by_destination : Dict [ str , JsonDict ] ,
612
633
) -> int :
613
634
"""Used to send messages from this server.
614
635
@@ -624,7 +645,9 @@ async def add_messages_to_device_inbox(
624
645
625
646
assert self ._can_write_to_device
626
647
627
- def add_messages_txn (txn , now_ms , stream_id ):
648
+ def add_messages_txn (
649
+ txn : LoggingTransaction , now_ms : int , stream_id : int
650
+ ) -> None :
628
651
# Add the local messages directly to the local inbox.
629
652
self ._add_messages_to_local_device_inbox_txn (
630
653
txn , stream_id , local_messages_by_user_then_device
@@ -677,11 +700,16 @@ def add_messages_txn(txn, now_ms, stream_id):
677
700
return self ._device_inbox_id_gen .get_current_token ()
678
701
679
702
async def add_messages_from_remote_to_device_inbox (
680
- self , origin : str , message_id : str , local_messages_by_user_then_device : dict
703
+ self ,
704
+ origin : str ,
705
+ message_id : str ,
706
+ local_messages_by_user_then_device : Dict [str , Dict [str , JsonDict ]],
681
707
) -> int :
682
708
assert self ._can_write_to_device
683
709
684
- def add_messages_txn (txn , now_ms , stream_id ):
710
+ def add_messages_txn (
711
+ txn : LoggingTransaction , now_ms : int , stream_id : int
712
+ ) -> None :
685
713
# Check if we've already inserted a matching message_id for that
686
714
# origin. This can happen if the origin doesn't receive our
687
715
# acknowledgement from the first time we received the message.
@@ -727,8 +755,11 @@ def add_messages_txn(txn, now_ms, stream_id):
727
755
return stream_id
728
756
729
757
def _add_messages_to_local_device_inbox_txn (
730
- self , txn , stream_id , messages_by_user_then_device
731
- ):
758
+ self ,
759
+ txn : LoggingTransaction ,
760
+ stream_id : int ,
761
+ messages_by_user_then_device : Dict [str , Dict [str , JsonDict ]],
762
+ ) -> None :
732
763
assert self ._can_write_to_device
733
764
734
765
local_by_user_then_device = {}
@@ -840,8 +871,10 @@ def __init__(
840
871
self ._remove_dead_devices_from_device_inbox ,
841
872
)
842
873
843
- async def _background_drop_index_device_inbox (self , progress , batch_size ):
844
- def reindex_txn (conn ):
874
+ async def _background_drop_index_device_inbox (
875
+ self , progress : JsonDict , batch_size : int
876
+ ) -> int :
877
+ def reindex_txn (conn : LoggingDatabaseConnection ) -> None :
845
878
txn = conn .cursor ()
846
879
txn .execute ("DROP INDEX IF EXISTS device_inbox_stream_id" )
847
880
txn .close ()
0 commit comments