Avoid blocking lazy-loading `/sync`s during partial joins (#13477)

Use a state filter or accept partial state in a few places where we
request state, to avoid blocking.

To make lazy-loading `/sync`s work, we need to provide the memberships
of event senders, which are not guaranteed to be in the room state.
Instead we dig through auth events for memberships to present to
clients. The auth events of an event are guaranteed to contain a
passable membership event, otherwise the event would have been rejected.

Note that this only covers the common code paths encountered during
testing. There has been no exhaustive checking of all sync code paths.

Fixes #13146.

Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
Sean Quah 2022-08-18 11:53:02 +01:00 committed by GitHub
parent 49d04e43df
commit 84169a82dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 245 additions and 35 deletions

1
changelog.d/13477.misc Normal file
View File

@ -0,0 +1 @@
Faster room joins: Avoid blocking lazy-loading `/sync`s during partial joins due to remote memberships. Pull remote memberships from auth events instead of the room state.

View File

@ -16,9 +16,11 @@ import logging
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Collection,
Dict, Dict,
FrozenSet, FrozenSet,
List, List,
Mapping,
Optional, Optional,
Sequence, Sequence,
Set, Set,
@ -517,10 +519,17 @@ class SyncHandler:
# ensure that we always include current state in the timeline # ensure that we always include current state in the timeline
current_state_ids: FrozenSet[str] = frozenset() current_state_ids: FrozenSet[str] = frozenset()
if any(e.is_state() for e in recents): if any(e.is_state() for e in recents):
# FIXME(faster_joins): We use the partial state here as
# we don't want to block `/sync` on finishing a lazy join.
# Which should be fine once
# https://github.com/matrix-org/synapse/issues/12989 is resolved,
# since we shouldn't reach here anymore?
# Note that we use the current state as a whitelist for filtering
# `recents`, so partial state is only a problem when a membership
# event turns up in `recents` but has not made it into the current
# state.
current_state_ids_map = ( current_state_ids_map = (
await self._state_storage_controller.get_current_state_ids( await self.store.get_partial_current_state_ids(room_id)
room_id
)
) )
current_state_ids = frozenset(current_state_ids_map.values()) current_state_ids = frozenset(current_state_ids_map.values())
@ -589,7 +598,13 @@ class SyncHandler:
if any(e.is_state() for e in loaded_recents): if any(e.is_state() for e in loaded_recents):
# FIXME(faster_joins): We use the partial state here as # FIXME(faster_joins): We use the partial state here as
# we don't want to block `/sync` on finishing a lazy join. # we don't want to block `/sync` on finishing a lazy join.
# Is this the correct way of doing it? # Which should be fine once
# https://github.com/matrix-org/synapse/issues/12989 is resolved,
# since we shouldn't reach here anymore?
# Note that we use the current state as a whitelist for filtering
# `loaded_recents`, so partial state is only a problem when a
# membership event turns up in `loaded_recents` but has not made it
# into the current state.
current_state_ids_map = ( current_state_ids_map = (
await self.store.get_partial_current_state_ids(room_id) await self.store.get_partial_current_state_ids(room_id)
) )
@ -637,7 +652,10 @@ class SyncHandler:
) )
async def get_state_after_event( async def get_state_after_event(
self, event_id: str, state_filter: Optional[StateFilter] = None self,
event_id: str,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]: ) -> StateMap[str]:
""" """
Get the room state after the given event Get the room state after the given event
@ -645,9 +663,14 @@ class SyncHandler:
Args: Args:
event_id: event of interest event_id: event of interest
state_filter: The state filter used to fetch state from the database. state_filter: The state filter used to fetch state from the database.
await_full_state: if `True`, will block if we do not yet have complete state
at the event and `state_filter` is not satisfied by partial state.
Defaults to `True`.
""" """
state_ids = await self._state_storage_controller.get_state_ids_for_event( state_ids = await self._state_storage_controller.get_state_ids_for_event(
event_id, state_filter=state_filter or StateFilter.all() event_id,
state_filter=state_filter or StateFilter.all(),
await_full_state=await_full_state,
) )
# using get_metadata_for_events here (instead of get_event) sidesteps an issue # using get_metadata_for_events here (instead of get_event) sidesteps an issue
@ -670,6 +693,7 @@ class SyncHandler:
room_id: str, room_id: str,
stream_position: StreamToken, stream_position: StreamToken,
state_filter: Optional[StateFilter] = None, state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]: ) -> StateMap[str]:
"""Get the room state at a particular stream position """Get the room state at a particular stream position
@ -677,6 +701,9 @@ class SyncHandler:
room_id: room for which to get state room_id: room for which to get state
stream_position: point at which to get state stream_position: point at which to get state
state_filter: The state filter used to fetch state from the database. state_filter: The state filter used to fetch state from the database.
await_full_state: if `True`, will block if we do not yet have complete state
at the last event in the room before `stream_position` and
`state_filter` is not satisfied by partial state. Defaults to `True`.
""" """
# FIXME: This gets the state at the latest event before the stream ordering, # FIXME: This gets the state at the latest event before the stream ordering,
# which might not be the same as the "current state" of the room at the time # which might not be the same as the "current state" of the room at the time
@ -688,7 +715,9 @@ class SyncHandler:
if last_event_id: if last_event_id:
state = await self.get_state_after_event( state = await self.get_state_after_event(
last_event_id, state_filter=state_filter or StateFilter.all() last_event_id,
state_filter=state_filter or StateFilter.all(),
await_full_state=await_full_state,
) )
else: else:
@ -891,7 +920,15 @@ class SyncHandler:
with Measure(self.clock, "compute_state_delta"): with Measure(self.clock, "compute_state_delta"):
# The memberships needed for events in the timeline. # The memberships needed for events in the timeline.
# Only calculated when `lazy_load_members` is on. # Only calculated when `lazy_load_members` is on.
members_to_fetch = None members_to_fetch: Optional[Set[str]] = None
# A dictionary mapping user IDs to the first event in the timeline sent by
# them. Only calculated when `lazy_load_members` is on.
first_event_by_sender_map: Optional[Dict[str, EventBase]] = None
# The contribution to the room state from state events in the timeline.
# Only contains the last event for any given state key.
timeline_state: StateMap[str]
lazy_load_members = sync_config.filter_collection.lazy_load_members() lazy_load_members = sync_config.filter_collection.lazy_load_members()
include_redundant_members = ( include_redundant_members = (
@ -902,10 +939,23 @@ class SyncHandler:
# We only request state for the members needed to display the # We only request state for the members needed to display the
# timeline: # timeline:
members_to_fetch = { timeline_state = {}
event.sender # FIXME: we also care about invite targets etc.
for event in batch.events members_to_fetch = set()
} first_event_by_sender_map = {}
for event in batch.events:
# Build the map from user IDs to the first timeline event they sent.
if event.sender not in first_event_by_sender_map:
first_event_by_sender_map[event.sender] = event
# We need the event's sender, unless their membership was in a
# previous timeline event.
if (EventTypes.Member, event.sender) not in timeline_state:
members_to_fetch.add(event.sender)
# FIXME: we also care about invite targets etc.
if event.is_state():
timeline_state[(event.type, event.state_key)] = event.event_id
if full_state: if full_state:
# always make sure we LL ourselves so we know we're in the room # always make sure we LL ourselves so we know we're in the room
@ -915,16 +965,21 @@ class SyncHandler:
members_to_fetch.add(sync_config.user.to_string()) members_to_fetch.add(sync_config.user.to_string())
state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch) state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
else:
state_filter = StateFilter.all()
# The contribution to the room state from state events in the timeline. # We are happy to use partial state to compute the `/sync` response.
# Only contains the last event for any given state key. # Since partial state may not include the lazy-loaded memberships we
timeline_state = { # require, we fix up the state response afterwards with memberships from
(event.type, event.state_key): event.event_id # auth events.
for event in batch.events await_full_state = False
if event.is_state() else:
} timeline_state = {
(event.type, event.state_key): event.event_id
for event in batch.events
if event.is_state()
}
state_filter = StateFilter.all()
await_full_state = True
# Now calculate the state to return in the sync response for the room. # Now calculate the state to return in the sync response for the room.
# This is more or less the change in state between the end of the previous # This is more or less the change in state between the end of the previous
@ -936,19 +991,26 @@ class SyncHandler:
if batch: if batch:
state_at_timeline_end = ( state_at_timeline_end = (
await self._state_storage_controller.get_state_ids_for_event( await self._state_storage_controller.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter batch.events[-1].event_id,
state_filter=state_filter,
await_full_state=await_full_state,
) )
) )
state_at_timeline_start = ( state_at_timeline_start = (
await self._state_storage_controller.get_state_ids_for_event( await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter batch.events[0].event_id,
state_filter=state_filter,
await_full_state=await_full_state,
) )
) )
else: else:
state_at_timeline_end = await self.get_state_at( state_at_timeline_end = await self.get_state_at(
room_id, stream_position=now_token, state_filter=state_filter room_id,
stream_position=now_token,
state_filter=state_filter,
await_full_state=await_full_state,
) )
state_at_timeline_start = state_at_timeline_end state_at_timeline_start = state_at_timeline_end
@ -964,14 +1026,19 @@ class SyncHandler:
if batch: if batch:
state_at_timeline_start = ( state_at_timeline_start = (
await self._state_storage_controller.get_state_ids_for_event( await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter batch.events[0].event_id,
state_filter=state_filter,
await_full_state=await_full_state,
) )
) )
else: else:
# We can get here if the user has ignored the senders of all # We can get here if the user has ignored the senders of all
# the recent events. # the recent events.
state_at_timeline_start = await self.get_state_at( state_at_timeline_start = await self.get_state_at(
room_id, stream_position=now_token, state_filter=state_filter room_id,
stream_position=now_token,
state_filter=state_filter,
await_full_state=await_full_state,
) )
# for now, we disable LL for gappy syncs - see # for now, we disable LL for gappy syncs - see
@ -993,20 +1060,28 @@ class SyncHandler:
# is indeed the case. # is indeed the case.
assert since_token is not None assert since_token is not None
state_at_previous_sync = await self.get_state_at( state_at_previous_sync = await self.get_state_at(
room_id, stream_position=since_token, state_filter=state_filter room_id,
stream_position=since_token,
state_filter=state_filter,
await_full_state=await_full_state,
) )
if batch: if batch:
state_at_timeline_end = ( state_at_timeline_end = (
await self._state_storage_controller.get_state_ids_for_event( await self._state_storage_controller.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter batch.events[-1].event_id,
state_filter=state_filter,
await_full_state=await_full_state,
) )
) )
else: else:
# We can get here if the user has ignored the senders of all # We can get here if the user has ignored the senders of all
# the recent events. # the recent events.
state_at_timeline_end = await self.get_state_at( state_at_timeline_end = await self.get_state_at(
room_id, stream_position=now_token, state_filter=state_filter room_id,
stream_position=now_token,
state_filter=state_filter,
await_full_state=await_full_state,
) )
state_ids = _calculate_state( state_ids = _calculate_state(
@ -1036,8 +1111,23 @@ class SyncHandler:
(EventTypes.Member, member) (EventTypes.Member, member)
for member in members_to_fetch for member in members_to_fetch
), ),
await_full_state=False,
) )
# If we only have partial state for the room, `state_ids` may be missing the
# memberships we wanted. We attempt to find some by digging through the auth
# events of timeline events.
if lazy_load_members and await self.store.is_partial_state_room(room_id):
assert members_to_fetch is not None
assert first_event_by_sender_map is not None
additional_state_ids = (
await self._find_missing_partial_state_memberships(
room_id, members_to_fetch, first_event_by_sender_map, state_ids
)
)
state_ids = {**state_ids, **additional_state_ids}
# At this point, if `lazy_load_members` is enabled, `state_ids` includes # At this point, if `lazy_load_members` is enabled, `state_ids` includes
# the memberships of all event senders in the timeline. This is because we # the memberships of all event senders in the timeline. This is because we
# may not have sent the memberships in a previous sync. # may not have sent the memberships in a previous sync.
@ -1086,6 +1176,99 @@ class SyncHandler:
if e.type != EventTypes.Aliases # until MSC2261 or alternative solution if e.type != EventTypes.Aliases # until MSC2261 or alternative solution
} }
async def _find_missing_partial_state_memberships(
self,
room_id: str,
members_to_fetch: Collection[str],
events_with_membership_auth: Mapping[str, EventBase],
found_state_ids: StateMap[str],
) -> StateMap[str]:
"""Finds missing memberships from a set of auth events and returns them as a
state map.
Args:
room_id: The partial state room to find the remaining memberships for.
members_to_fetch: The memberships to find.
events_with_membership_auth: A mapping from user IDs to events whose auth
events are known to contain their membership.
found_state_ids: A dict from (type, state_key) -> state_event_id, containing
memberships that have been previously found. Entries in
`members_to_fetch` that have a membership in `found_state_ids` are
ignored.
Returns:
A dict from ("m.room.member", state_key) -> state_event_id, containing the
memberships missing from `found_state_ids`.
Raises:
KeyError: if `events_with_membership_auth` does not have an entry for a
missing membership. Memberships in `found_state_ids` do not need an
entry in `events_with_membership_auth`.
"""
additional_state_ids: MutableStateMap[str] = {}
# Tracks the missing members for logging purposes.
missing_members = set()
# Identify memberships missing from `found_state_ids` and pick out the auth
# events in which to look for them.
auth_event_ids: Set[str] = set()
for member in members_to_fetch:
if (EventTypes.Member, member) in found_state_ids:
continue
missing_members.add(member)
event_with_membership_auth = events_with_membership_auth[member]
auth_event_ids.update(event_with_membership_auth.auth_event_ids())
auth_events = await self.store.get_events(auth_event_ids)
# Run through the missing memberships once more, picking out the memberships
# from the pile of auth events we have just fetched.
for member in members_to_fetch:
if (EventTypes.Member, member) in found_state_ids:
continue
event_with_membership_auth = events_with_membership_auth[member]
# Dig through the auth events to find the desired membership.
for auth_event_id in event_with_membership_auth.auth_event_ids():
# We only store events once we have all their auth events,
# so the auth event must be in the pile we have just
# fetched.
auth_event = auth_events[auth_event_id]
if (
auth_event.type == EventTypes.Member
and auth_event.state_key == member
):
missing_members.remove(member)
additional_state_ids[
(EventTypes.Member, member)
] = auth_event.event_id
break
if missing_members:
# There really shouldn't be any missing memberships now. Either:
# * we couldn't find an auth event, which shouldn't happen because we do
# not persist events with persisting their auth events first, or
# * the set of auth events did not contain a membership we wanted, which
# means our caller didn't compute the events in `members_to_fetch`
# correctly, or we somehow accepted an event whose auth events were
# dodgy.
logger.error(
"Failed to find memberships for %s in partial state room "
"%s in the auth events of %s.",
missing_members,
room_id,
[
events_with_membership_auth[member].event_id
for member in missing_members
],
)
return additional_state_ids
async def unread_notifs_for_room_id( async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig self, room_id: str, sync_config: SyncConfig
) -> NotifCounts: ) -> NotifCounts:
@ -1730,7 +1913,11 @@ class SyncHandler:
continue continue
if room_id in sync_result_builder.joined_room_ids or has_join: if room_id in sync_result_builder.joined_room_ids or has_join:
old_state_ids = await self.get_state_at(room_id, since_token) old_state_ids = await self.get_state_at(
room_id,
since_token,
state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
)
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None old_mem_ev = None
if old_mem_ev_id: if old_mem_ev_id:
@ -1756,7 +1943,13 @@ class SyncHandler:
newly_left_rooms.append(room_id) newly_left_rooms.append(room_id)
else: else:
if not old_state_ids: if not old_state_ids:
old_state_ids = await self.get_state_at(room_id, since_token) old_state_ids = await self.get_state_at(
room_id,
since_token,
state_filter=StateFilter.from_types(
[(EventTypes.Member, user_id)]
),
)
old_mem_ev_id = old_state_ids.get( old_mem_ev_id = old_state_ids.get(
(EventTypes.Member, user_id), None (EventTypes.Member, user_id), None
) )

View File

@ -234,6 +234,7 @@ class StateStorageController:
self, self,
event_ids: Collection[str], event_ids: Collection[str],
state_filter: Optional[StateFilter] = None, state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> Dict[str, StateMap[str]]: ) -> Dict[str, StateMap[str]]:
""" """
Get the state dicts corresponding to a list of events, containing the event_ids Get the state dicts corresponding to a list of events, containing the event_ids
@ -242,6 +243,9 @@ class StateStorageController:
Args: Args:
event_ids: events whose state should be returned event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database. state_filter: The state filter used to fetch state from the database.
await_full_state: if `True`, will block if we do not yet have complete state
at these events and `state_filter` is not satisfied by partial state.
Defaults to `True`.
Returns: Returns:
A dict from event_id -> (type, state_key) -> event_id A dict from event_id -> (type, state_key) -> event_id
@ -250,8 +254,12 @@ class StateStorageController:
RuntimeError if we don't have a state group for one or more of the events RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown) (ie they are outliers or unknown)
""" """
await_full_state = True if (
if state_filter and not state_filter.must_await_full_state(self._is_mine_id): await_full_state
and state_filter
and not state_filter.must_await_full_state(self._is_mine_id)
):
# Full state is not required if the state filter is restrictive enough.
await_full_state = False await_full_state = False
event_to_groups = await self.get_state_group_for_events( event_to_groups = await self.get_state_group_for_events(
@ -294,7 +302,10 @@ class StateStorageController:
@trace @trace
async def get_state_ids_for_event( async def get_state_ids_for_event(
self, event_id: str, state_filter: Optional[StateFilter] = None self,
event_id: str,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]: ) -> StateMap[str]:
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
@ -302,6 +313,9 @@ class StateStorageController:
Args: Args:
event_id: event whose state should be returned event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database. state_filter: The state filter used to fetch state from the database.
await_full_state: if `True`, will block if we do not yet have complete state
at the event and `state_filter` is not satisfied by partial state.
Defaults to `True`.
Returns: Returns:
A dict from (type, state_key) -> state_event_id A dict from (type, state_key) -> state_event_id
@ -311,7 +325,9 @@ class StateStorageController:
outlier or is unknown) outlier or is unknown)
""" """
state_map = await self.get_state_ids_for_events( state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all() [event_id],
state_filter or StateFilter.all(),
await_full_state=await_full_state,
) )
return state_map[event_id] return state_map[event_id]