1
1
# Copyright 2014-2016 OpenMarket Ltd
2
+ # Copyright 2022 The Matrix.org Foundation C.I.C.
2
3
#
3
4
# Licensed under the Apache License, Version 2.0 (the "License");
4
5
# you may not use this file except in compliance with the License.
15
16
from typing import (
16
17
TYPE_CHECKING ,
17
18
Awaitable ,
19
+ Callable ,
18
20
Collection ,
19
21
Dict ,
20
22
Iterable ,
@@ -532,6 +534,40 @@ def approx_difference(self, other: "StateFilter") -> "StateFilter":
532
534
new_all , new_excludes , new_wildcards , new_concrete_keys
533
535
)
534
536
537
+ def must_await_full_state (self , is_mine_id : Callable [[str ], bool ]) -> bool :
538
+ """Check if we need to wait for full state to complete to calculate this state
539
+
540
+ If we have a state filter which is completely satisfied even with partial
541
+ state, then we don't need to await_full_state before we can return it.
542
+
543
+ Args:
544
+ is_mine_id: a callable which confirms if a given state_key matches a mxid
545
+ of a local user
546
+ """
547
+
548
+ # XXX: can we be certain that the state at an event never changes (only gets
549
+ # enlarged)?
550
+
551
+ # if we haven't requested membership events, then it depends on the value of
552
+ # 'include_others'
553
+ if EventTypes .Member not in self .types :
554
+ return self .include_others
555
+
556
+ # if we're looking for *all* membership events, then we have to wait
557
+ member_state_keys = self .types [EventTypes .Member ]
558
+ if member_state_keys is None :
559
+ return True
560
+
561
+ # otherwise, consider whose membership we are looking for. If it's entirely
562
+ # local users, then we don't need to wait.
563
+ for state_key in member_state_keys :
564
+ if not is_mine_id (state_key ):
565
+ # remote user
566
+ return True
567
+
568
+ # local users only
569
+ return False
570
+
535
571
536
572
_ALL_STATE_FILTER = StateFilter (types = frozendict (), include_others = True )
537
573
_ALL_NON_MEMBER_STATE_FILTER = StateFilter (
@@ -544,6 +580,7 @@ class StateGroupStorage:
544
580
"""High level interface to fetching state for event."""
545
581
546
582
def __init__ (self , hs : "HomeServer" , stores : "Databases" ):
583
+ self ._is_mine_id = hs .is_mine_id
547
584
self .stores = stores
548
585
self ._partial_state_events_tracker = PartialStateEventsTracker (stores .main )
549
586
@@ -675,7 +712,13 @@ async def get_state_for_events(
675
712
RuntimeError if we don't have a state group for one or more of the events
676
713
(ie they are outliers or unknown)
677
714
"""
678
- event_to_groups = await self ._get_state_group_for_events (event_ids )
715
+ await_full_state = True
716
+ if state_filter and not state_filter .must_await_full_state (self ._is_mine_id ):
717
+ await_full_state = False
718
+
719
+ event_to_groups = await self ._get_state_group_for_events (
720
+ event_ids , await_full_state = await_full_state
721
+ )
679
722
680
723
groups = set (event_to_groups .values ())
681
724
group_to_state = await self .stores .state ._get_state_for_groups (
@@ -699,7 +742,9 @@ async def get_state_for_events(
699
742
return {event : event_to_state [event ] for event in event_ids }
700
743
701
744
async def get_state_ids_for_events (
702
- self , event_ids : Collection [str ], state_filter : Optional [StateFilter ] = None
745
+ self ,
746
+ event_ids : Collection [str ],
747
+ state_filter : Optional [StateFilter ] = None ,
703
748
) -> Dict [str , StateMap [str ]]:
704
749
"""
705
750
Get the state dicts corresponding to a list of events, containing the event_ids
@@ -716,7 +761,13 @@ async def get_state_ids_for_events(
716
761
RuntimeError if we don't have a state group for one or more of the events
717
762
(ie they are outliers or unknown)
718
763
"""
719
- event_to_groups = await self ._get_state_group_for_events (event_ids )
764
+ await_full_state = True
765
+ if state_filter and not state_filter .must_await_full_state (self ._is_mine_id ):
766
+ await_full_state = False
767
+
768
+ event_to_groups = await self ._get_state_group_for_events (
769
+ event_ids , await_full_state = await_full_state
770
+ )
720
771
721
772
groups = set (event_to_groups .values ())
722
773
group_to_state = await self .stores .state ._get_state_for_groups (
@@ -802,7 +853,7 @@ async def _get_state_group_for_events(
802
853
Args:
803
854
event_ids: events to get state groups for
804
855
await_full_state: if true, will block if we do not yet have complete
805
- state at this event .
856
+ state at these events .
806
857
"""
807
858
if await_full_state :
808
859
await self ._partial_state_events_tracker .await_full_state (event_ids )
0 commit comments