Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit e8bce89

Browse files
authored
Aggregate unread notif count query for badge count calculation (#14255)
Fetch the unread notification counts used by the badge counts in push notifications for all rooms at once (instead of fetching them per room).
1 parent 4569eda commit e8bce89

File tree

4 files changed

+198
-27
lines changed

4 files changed

+198
-27
lines changed

changelog.d/14255.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Optimise push badge count calculations. Contributed by Nick @ Beeper (@fizzadar).

synapse/push/push_tools.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
1818
from synapse.storage.controllers import StorageControllers
1919
from synapse.storage.databases.main import DataStore
20-
from synapse.util.async_helpers import concurrently_execute
2120

2221

2322
async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
@@ -26,23 +25,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
2625

2726
badge = len(invites)
2827

29-
room_notifs = []
30-
31-
async def get_room_unread_count(room_id: str) -> None:
32-
room_notifs.append(
33-
await store.get_unread_event_push_actions_by_room_for_user(
34-
room_id,
35-
user_id,
36-
)
37-
)
38-
39-
await concurrently_execute(get_room_unread_count, joins, 10)
40-
41-
for notifs in room_notifs:
42-
# Combine the counts from all the threads.
43-
notify_count = notifs.main_timeline.notify_count + sum(
44-
n.notify_count for n in notifs.threads.values()
45-
)
28+
room_to_count = await store.get_unread_counts_by_room_for_user(user_id)
29+
for room_id, notify_count in room_to_count.items():
30+
# room_to_count may include rooms which the user has left,
31+
# ignore those.
32+
if room_id not in joins:
33+
continue
4634

4735
if notify_count == 0:
4836
continue
@@ -51,8 +39,10 @@ async def get_room_unread_count(room_id: str) -> None:
5139
# return one badge count per conversation
5240
badge += 1
5341
else:
54-
# increment the badge count by the number of unread messages in the room
42+
# Increase badge by number of notifications in room
43+
# NOTE: this includes threaded and unthreaded notifications.
5544
badge += notify_count
45+
5646
return badge
5747

5848

synapse/storage/databases/main/event_push_actions.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
"""
7575

7676
import logging
77+
from collections import defaultdict
7778
from typing import (
7879
TYPE_CHECKING,
7980
Collection,
@@ -95,6 +96,7 @@
9596
DatabasePool,
9697
LoggingDatabaseConnection,
9798
LoggingTransaction,
99+
PostgresEngine,
98100
)
99101
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
100102
from synapse.storage.databases.main.stream import StreamWorkerStore
@@ -463,6 +465,153 @@ def add_thread_id_summary_txn(txn: LoggingTransaction) -> int:
463465

464466
return result
465467

468+
async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
469+
"""Get the notification count by room for a user. Only considers notifications,
470+
not highlight or unread counts, and threads are currently aggregated under their room.
471+
472+
This function is intentionally not cached because it is called to calculate the
473+
unread badge for push notifications and thus the result is expected to change.
474+
475+
Note that this function assumes the user is a member of the room. Because
476+
summary rows are not removed when a user leaves a room, the caller must
477+
filter those rooms out of the result.
478+
479+
Returns:
480+
A map of room ID to notification counts for the given user.
481+
"""
482+
return await self.db_pool.runInteraction(
483+
"get_unread_counts_by_room_for_user",
484+
self._get_unread_counts_by_room_for_user_txn,
485+
user_id,
486+
)
487+
488+
def _get_unread_counts_by_room_for_user_txn(
489+
self, txn: LoggingTransaction, user_id: str
490+
) -> Dict[str, int]:
491+
receipt_types_clause, args = make_in_list_sql_clause(
492+
self.database_engine,
493+
"receipt_type",
494+
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
495+
)
496+
args.extend([user_id, user_id])
497+
498+
receipts_cte = f"""
499+
WITH all_receipts AS (
500+
SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
501+
FROM receipts_linearized
502+
LEFT JOIN events USING (room_id, event_id)
503+
WHERE
504+
{receipt_types_clause}
505+
AND user_id = ?
506+
GROUP BY room_id, thread_id
507+
)
508+
"""
509+
510+
receipts_joins = """
511+
LEFT JOIN (
512+
SELECT room_id, thread_id,
513+
max_receipt_stream_ordering AS threaded_receipt_stream_ordering
514+
FROM all_receipts
515+
WHERE thread_id IS NOT NULL
516+
) AS threaded_receipts USING (room_id, thread_id)
517+
LEFT JOIN (
518+
SELECT room_id, thread_id,
519+
max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering
520+
FROM all_receipts
521+
WHERE thread_id IS NULL
522+
) AS unthreaded_receipts USING (room_id)
523+
"""
524+
525+
# First get summary counts by room / thread for the user. We use the max receipt
526+
# stream ordering of both threaded & unthreaded receipts to compare against the
527+
# summary table.
528+
#
529+
# PostgreSQL and SQLite differ in comparing scalar numerics.
530+
if isinstance(self.database_engine, PostgresEngine):
531+
# GREATEST ignores NULLs.
532+
max_clause = """GREATEST(
533+
threaded_receipt_stream_ordering,
534+
unthreaded_receipt_stream_ordering
535+
)"""
536+
else:
537+
# MAX returns NULL if any are NULL, so COALESCE to 0 first.
538+
max_clause = """MAX(
539+
COALESCE(threaded_receipt_stream_ordering, 0),
540+
COALESCE(unthreaded_receipt_stream_ordering, 0)
541+
)"""
542+
543+
sql = f"""
544+
{receipts_cte}
545+
SELECT eps.room_id, eps.thread_id, notif_count
546+
FROM event_push_summary AS eps
547+
{receipts_joins}
548+
WHERE user_id = ?
549+
AND notif_count != 0
550+
AND (
551+
(last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause})
552+
OR last_receipt_stream_ordering = {max_clause}
553+
)
554+
"""
555+
txn.execute(sql, args)
556+
557+
seen_thread_ids = set()
558+
room_to_count: Dict[str, int] = defaultdict(int)
559+
560+
for room_id, thread_id, notif_count in txn:
561+
room_to_count[room_id] += notif_count
562+
seen_thread_ids.add(thread_id)
563+
564+
# Now get any event push actions that haven't been rotated yet, using the same
565+
# receipt joins and filtering by receipt and by the stream ordering the summary was rotated up to.
566+
sql = f"""
567+
{receipts_cte}
568+
SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
569+
FROM event_push_actions AS epa
570+
{receipts_joins}
571+
WHERE user_id = ?
572+
AND epa.notif = 1
573+
AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering)
574+
AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
575+
AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
576+
GROUP BY epa.room_id, epa.thread_id
577+
"""
578+
txn.execute(sql, args)
579+
580+
for room_id, thread_id, notif_count in txn:
581+
# Note: only count push actions for threads that have a valid summary with an up-to-date receipt.
582+
if thread_id not in seen_thread_ids:
583+
continue
584+
room_to_count[room_id] += notif_count
585+
586+
thread_id_clause, thread_ids_args = make_in_list_sql_clause(
587+
self.database_engine, "epa.thread_id", seen_thread_ids
588+
)
589+
590+
# Finally re-check event_push_actions for any threads not in the summary, ignoring
591+
# the rotated up-to position. This handles the case where a read receipt has arrived
592+
# but not been rotated meaning the summary table is out of date, so we go back to
593+
# the push actions table.
594+
sql = f"""
595+
{receipts_cte}
596+
SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
597+
FROM event_push_actions AS epa
598+
{receipts_joins}
599+
WHERE user_id = ?
600+
AND NOT {thread_id_clause}
601+
AND epa.notif = 1
602+
AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
603+
AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
604+
GROUP BY epa.room_id
605+
"""
606+
607+
args.extend(thread_ids_args)
608+
txn.execute(sql, args)
609+
610+
for room_id, notif_count in txn:
611+
room_to_count[room_id] += notif_count
612+
613+
return room_to_count
614+
466615
@cached(tree=True, max_entries=5000, iterable=True)
467616
async def get_unread_event_push_actions_by_room_for_user(
468617
self,

tests/storage/test_event_push_actions.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def test_count_aggregation(self) -> None:
156156

157157
last_event_id: str
158158

159-
def _assert_counts(noitf_count: int, highlight_count: int) -> None:
159+
def _assert_counts(notif_count: int, highlight_count: int) -> None:
160160
counts = self.get_success(
161161
self.store.db_pool.runInteraction(
162162
"get-unread-counts",
@@ -168,13 +168,22 @@ def _assert_counts(noitf_count: int, highlight_count: int) -> None:
168168
self.assertEqual(
169169
counts.main_timeline,
170170
NotifCounts(
171-
notify_count=noitf_count,
171+
notify_count=notif_count,
172172
unread_count=0,
173173
highlight_count=highlight_count,
174174
),
175175
)
176176
self.assertEqual(counts.threads, {})
177177

178+
aggregate_counts = self.get_success(
179+
self.store.db_pool.runInteraction(
180+
"get-aggregate-unread-counts",
181+
self.store._get_unread_counts_by_room_for_user_txn,
182+
user_id,
183+
)
184+
)
185+
self.assertEqual(aggregate_counts[room_id], notif_count)
186+
178187
def _create_event(highlight: bool = False) -> str:
179188
result = self.helper.send_event(
180189
room_id,
@@ -283,7 +292,7 @@ def test_count_aggregation_threads(self) -> None:
283292
last_event_id: str
284293

285294
def _assert_counts(
286-
noitf_count: int,
295+
notif_count: int,
287296
highlight_count: int,
288297
thread_notif_count: int,
289298
thread_highlight_count: int,
@@ -299,7 +308,7 @@ def _assert_counts(
299308
self.assertEqual(
300309
counts.main_timeline,
301310
NotifCounts(
302-
notify_count=noitf_count,
311+
notify_count=notif_count,
303312
unread_count=0,
304313
highlight_count=highlight_count,
305314
),
@@ -318,6 +327,17 @@ def _assert_counts(
318327
else:
319328
self.assertEqual(counts.threads, {})
320329

330+
aggregate_counts = self.get_success(
331+
self.store.db_pool.runInteraction(
332+
"get-aggregate-unread-counts",
333+
self.store._get_unread_counts_by_room_for_user_txn,
334+
user_id,
335+
)
336+
)
337+
self.assertEqual(
338+
aggregate_counts[room_id], notif_count + thread_notif_count
339+
)
340+
321341
def _create_event(
322342
highlight: bool = False, thread_id: Optional[str] = None
323343
) -> str:
@@ -454,7 +474,7 @@ def test_count_aggregation_mixed(self) -> None:
454474
last_event_id: str
455475

456476
def _assert_counts(
457-
noitf_count: int,
477+
notif_count: int,
458478
highlight_count: int,
459479
thread_notif_count: int,
460480
thread_highlight_count: int,
@@ -470,7 +490,7 @@ def _assert_counts(
470490
self.assertEqual(
471491
counts.main_timeline,
472492
NotifCounts(
473-
notify_count=noitf_count,
493+
notify_count=notif_count,
474494
unread_count=0,
475495
highlight_count=highlight_count,
476496
),
@@ -489,6 +509,17 @@ def _assert_counts(
489509
else:
490510
self.assertEqual(counts.threads, {})
491511

512+
aggregate_counts = self.get_success(
513+
self.store.db_pool.runInteraction(
514+
"get-aggregate-unread-counts",
515+
self.store._get_unread_counts_by_room_for_user_txn,
516+
user_id,
517+
)
518+
)
519+
self.assertEqual(
520+
aggregate_counts[room_id], notif_count + thread_notif_count
521+
)
522+
492523
def _create_event(
493524
highlight: bool = False, thread_id: Optional[str] = None
494525
) -> str:
@@ -646,7 +677,7 @@ def _create_event(type: str, content: JsonDict) -> str:
646677
)
647678
return result["event_id"]
648679

649-
def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
680+
def _assert_counts(notif_count: int, thread_notif_count: int) -> None:
650681
counts = self.get_success(
651682
self.store.db_pool.runInteraction(
652683
"get-unread-counts",
@@ -658,7 +689,7 @@ def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
658689
self.assertEqual(
659690
counts.main_timeline,
660691
NotifCounts(
661-
notify_count=noitf_count, unread_count=0, highlight_count=0
692+
notify_count=notif_count, unread_count=0, highlight_count=0
662693
),
663694
)
664695
if thread_notif_count:

0 commit comments

Comments
 (0)