Pull out less state when handling gaps mk2 (#12852)

This commit is contained in:
Erik Johnston 2022-05-26 10:48:12 +01:00 committed by GitHub
parent 1b338476af
commit b83bc5fab5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 236 additions and 127 deletions

1
changelog.d/12852.misc Normal file
View File

@ -0,0 +1 @@
Pull out less state when handling gaps in the room DAG.

View File

@ -274,7 +274,7 @@ class FederationEventHandler:
affected=pdu.event_id,
)
await self._process_received_pdu(origin, pdu, state=None)
await self._process_received_pdu(origin, pdu, state_ids=None)
async def on_send_membership_event(
self, origin: str, event: EventBase
@ -463,7 +463,9 @@ class FederationEventHandler:
with nested_logging_context(suffix=event.event_id):
context = await self._state_handler.compute_event_context(
event,
old_state=state,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in state
},
partial_state=partial_state,
)
@ -512,12 +514,12 @@ class FederationEventHandler:
#
# This is the same operation as we do when we receive a regular event
# over federation.
state = await self._resolve_state_at_missing_prevs(destination, event)
state_ids = await self._resolve_state_at_missing_prevs(destination, event)
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
old_state=state,
state_ids_before_event=state_ids,
)
if context.partial_state:
# this can happen if some or all of the event's prev_events still have
@ -767,11 +769,12 @@ class FederationEventHandler:
return
try:
state = await self._resolve_state_at_missing_prevs(origin, event)
state_ids = await self._resolve_state_at_missing_prevs(origin, event)
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
# not return partial state
await self._process_received_pdu(
origin, event, state=state, backfilled=backfilled
origin, event, state_ids=state_ids, backfilled=backfilled
)
except FederationError as e:
if e.code == 403:
@ -781,7 +784,7 @@ class FederationEventHandler:
async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase
) -> Optional[Iterable[EventBase]]:
) -> Optional[StateMap[str]]:
"""Calculate the state at an event with missing prev_events.
This is used when we have pulled a batch of events from a remote server, and
@ -808,8 +811,8 @@ class FederationEventHandler:
event: an event to check for missing prevs.
Returns:
if we already had all the prev events, `None`. Otherwise, returns a list of
the events in the state at `event`.
if we already had all the prev events, `None`. Otherwise, returns
the event ids of the state at `event`.
"""
room_id = event.room_id
event_id = event.event_id
@ -829,7 +832,7 @@ class FederationEventHandler:
)
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
event_map = {event_id: event}
try:
# Get the state of the events we know about
ours = await self._state_storage.get_state_groups_ids(room_id, seen)
@ -849,40 +852,23 @@ class FederationEventHandler:
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
remote_state = await self._get_state_after_missing_prev_event(
dest, room_id, p
remote_state_map = (
await self._get_state_ids_after_missing_prev_event(
dest, room_id, p
)
)
remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
state_maps.append(remote_state_map)
for x in remote_state:
event_map[x.event_id] = x
room_version = await self._store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id,
room_version,
state_maps,
event_map,
event_map={event_id: event},
state_res_store=StateResolutionStore(self._store),
)
# We need to give _process_received_pdu the actual state events
# rather than event ids, so generate that now.
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self._store.get_events(
list(state_map.values()),
get_prev_content=False,
redact_behaviour=EventRedactBehaviour.as_is,
)
event_map.update(evs)
state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"Error attempting to resolve state at missing prev_events",
@ -894,14 +880,14 @@ class FederationEventHandler:
"We can't get valid state history.",
affected=event_id,
)
return state
return state_map
async def _get_state_after_missing_prev_event(
async def _get_state_ids_after_missing_prev_event(
self,
destination: str,
room_id: str,
event_id: str,
) -> List[EventBase]:
) -> StateMap[str]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
@ -910,7 +896,7 @@ class FederationEventHandler:
event_id: The id of the event we want the state at.
Returns:
A list of events in the state, including the event itself
The event ids of the state *after* the given event.
"""
(
state_event_ids,
@ -925,19 +911,17 @@ class FederationEventHandler:
len(auth_event_ids),
)
# start by just trying to fetch the events from the store
# Start by checking events we already have in the DB
desired_events = set(state_event_ids)
desired_events.add(event_id)
logger.debug("Fetching %i events from cache/store", len(desired_events))
fetched_events = await self._store.get_events(
desired_events, allow_rejected=True
)
have_events = await self._store.have_seen_events(room_id, desired_events)
missing_desired_events = desired_events - fetched_events.keys()
missing_desired_events = desired_events - have_events
logger.debug(
"We are missing %i events (got %i)",
len(missing_desired_events),
len(fetched_events),
len(have_events),
)
# We probably won't need most of the auth events, so let's just check which
@ -948,7 +932,7 @@ class FederationEventHandler:
# already have a bunch of the state events. It would be nice if the
# federation api gave us a way of finding out which we actually need.
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
missing_auth_events = set(auth_event_ids) - have_events
missing_auth_events.difference_update(
await self._store.have_seen_events(room_id, missing_auth_events)
)
@ -974,47 +958,51 @@ class FederationEventHandler:
destination=destination, room_id=room_id, event_ids=missing_events
)
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
await self._store.get_events(missing_desired_events, allow_rejected=True)
)
# We now need to fill out the state map, which involves fetching the
# type and state key for each event ID in the state.
state_map = {}
# check for events which were in the wrong room.
#
# this can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
for state_event_id, metadata in event_metadata.items():
if metadata.room_id != room_id:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set.
#
# This can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
state_event_id,
metadata.room_id,
room_id,
)
continue
bad_events = [
(event_id, event.room_id)
for event_id, event in fetched_events.items()
if event.room_id != room_id
]
if metadata.state_key is None:
logger.warning(
"Remote server gave us non-state event in state: %s", state_event_id
)
continue
for bad_event_id, bad_room_id in bad_events:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
bad_event_id,
bad_room_id,
room_id,
)
del fetched_events[bad_event_id]
state_map[(metadata.event_type, metadata.state_key)] = state_event_id
# if we couldn't get the prev event in question, that's a problem.
remote_event = fetched_events.get(event_id)
remote_event = await self._store.get_event(
event_id,
allow_none=True,
allow_rejected=True,
redact_behaviour=EventRedactBehaviour.as_is,
)
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
# missing state at that event is a warning, not a blocker
# XXX: this doesn't sound right? it means that we'll end up with incomplete
# state.
failed_to_fetch = desired_events - fetched_events.keys()
failed_to_fetch = desired_events - event_metadata.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state events for %s %s",
@ -1022,14 +1010,12 @@ class FederationEventHandler:
failed_to_fetch,
)
remote_state = [
fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
]
if remote_event.is_state() and remote_event.rejected_reason is None:
remote_state.append(remote_event)
state_map[
(remote_event.type, remote_event.state_key)
] = remote_event.event_id
return remote_state
return state_map
async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str
@ -1056,7 +1042,7 @@ class FederationEventHandler:
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
state_ids: Optional[StateMap[str]],
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
@ -1078,7 +1064,7 @@ class FederationEventHandler:
event: event to be persisted
state: Normally None, but if we are handling a gap in the graph
state_ids: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event
@ -1090,7 +1076,8 @@ class FederationEventHandler:
try:
context = await self._state_handler.compute_event_context(
event, old_state=state
event,
state_ids_before_event=state_ids,
)
context = await self._check_event_auth(
origin,
@ -1107,7 +1094,7 @@ class FederationEventHandler:
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
await self._check_for_soft_fail(event, state, origin=origin)
await self._check_for_soft_fail(event, state_ids, origin=origin)
await self._run_push_actions_and_persist_event(event, context, backfilled)
@ -1589,7 +1576,7 @@ class FederationEventHandler:
async def _check_for_soft_fail(
self,
event: EventBase,
state: Optional[Iterable[EventBase]],
state_ids: Optional[StateMap[str]],
origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
@ -1597,7 +1584,7 @@ class FederationEventHandler:
Args:
event
state: The state at the event if we don't have all the event's prev events
state_ids: The state at the event if we don't have all the event's prev events
origin: The host the event originates from.
"""
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
@ -1613,7 +1600,7 @@ class FederationEventHandler:
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
# Calculate the "current state".
if state is not None:
if state_ids is not None:
# If we're explicitly given the state then we won't have all the
# prev events, and so we have a gap in the graph. In this case
# we want to be a little careful as we might have been down for
@ -1626,17 +1613,20 @@ class FederationEventHandler:
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
state_sets_d = await self._state_storage.get_state_groups(
state_sets_d = await self._state_storage.get_state_groups_ids(
event.room_id, extrem_ids
)
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
state_sets.append(state)
current_states = await self._state_handler.resolve_events(
room_version, state_sets, event
state_sets: List[StateMap[str]] = list(state_sets_d.values())
state_sets.append(state_ids)
current_state_ids = (
await self._state_resolution_handler.resolve_events_with_store(
event.room_id,
room_version,
state_sets,
event_map=None,
state_res_store=StateResolutionStore(self._store),
)
)
current_state_ids: StateMap[str] = {
k: e.event_id for k, e in current_states.items()
}
else:
current_state_ids = await self._state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids

View File

@ -55,7 +55,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.types import (
MutableStateMap,
Requester,
RoomAlias,
StreamToken,
UserID,
create_requester,
)
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
@ -1022,8 +1029,35 @@ class EventCreationHandler:
#
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
old_state = await self.store.get_events_as_list(state_event_ids)
context = await self.state.compute_event_context(event, old_state=old_state)
metadata = await self.store.get_metadata_for_events(state_event_ids)
state_map_for_event: MutableStateMap[str] = {}
for state_id in state_event_ids:
data = metadata.get(state_id)
if data is None:
# We're trying to persist a new historical batch of events
# with the given state, e.g. via
# `RoomBatchSendEventRestServlet`. The state can be inferred
# by Synapse or set directly by the client.
#
# Either way, we should have persisted all the state before
# getting here.
raise Exception(
f"State event {state_id} not found in DB,"
" Synapse should have persisted it before using it."
)
if data.state_key is None:
raise Exception(
f"Trying to set non-state event {state_id} as state"
)
state_map_for_event[(data.event_type, data.state_key)] = state_id
context = await self.state.compute_event_context(
event,
state_ids_before_event=state_map_for_event,
)
else:
context = await self.state.compute_event_context(event)

View File

@ -261,7 +261,7 @@ class StateHandler:
async def compute_event_context(
self,
event: EventBase,
old_state: Optional[Iterable[EventBase]] = None,
state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: bool = False,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.
@ -273,12 +273,12 @@ class StateHandler:
Args:
event:
old_state: The state at the event if it can't be
calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
partial_state: True if `old_state` is partial and omits non-critical
membership events
state_ids_before_event: The event ids of the state before the event if
it can't be calculated from existing events. This is normally
only specified when receiving an event from federation where we
don't have the prev events, e.g. when backfilling.
partial_state: True if `state_ids_before_event` is partial and omits
non-critical membership events
Returns:
The event context.
"""
@ -286,13 +286,11 @@ class StateHandler:
assert not event.internal_metadata.is_outlier()
#
# first of all, figure out the state before the event
# first of all, figure out the state before the event, unless we
# already have it.
#
if old_state:
if state_ids_before_event:
# if we're given the state before the event, then we use that
state_ids_before_event: StateMap[str] = {
(s.type, s.state_key): s.event_id for s in old_state
}
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None

View File

@ -16,6 +16,8 @@ import collections.abc
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
import attr
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@ -26,6 +28,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
@ -33,6 +36,7 @@ from synapse.storage.state import StateFilter
from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -43,6 +47,15 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventMetadata:
    """Returned by `get_metadata_for_events`"""

    # ID of the room the event belongs to (the `room_id` column of the
    # event's row).
    room_id: str
    # The event's type (the `type` column of the event's row).
    event_type: str
    # The event's state key, or None if the lookup found no state key for
    # the event. Callers treat None as meaning "this is not a state event".
    state_key: Optional[str]
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not v:
@ -133,6 +146,52 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return room_version
async def get_metadata_for_events(
    self, event_ids: Collection[str]
) -> Dict[str, EventMetadata]:
    """Get some metadata (room_id, type, state_key) for the given events.

    This method is a faster alternative than fetching the full events from
    the DB, and should be used when the full event is not needed.

    Returns metadata for rejected and redacted events. Events that have not
    been persisted are omitted from the returned dict.

    Args:
        event_ids: The IDs of the events to look up. They are queried in
            batches of 1000, each batch in its own database interaction.

    Returns:
        A dict mapping each event ID that was found to its `EventMetadata`.
        `state_key` is `None` for events with no matching `state_events`
        row.
    """

    def get_metadata_for_events_txn(
        txn: LoggingTransaction,
        batch_ids: Collection[str],
    ) -> Dict[str, EventMetadata]:
        # The event IDs are passed as bound query parameters (`args`); only
        # the placeholder clause is interpolated into the SQL text.
        clause, args = make_in_list_sql_clause(
            self.database_engine, "e.event_id", batch_ids
        )

        # Read `state_key` from the joined `state_events` row rather than
        # from `events`: callers interpret a NULL state key as "this is not
        # a state event", and the LEFT JOIN yields NULL exactly when the
        # event has no `state_events` row. (Previously the join was made
        # but never read.)
        sql = f"""
            SELECT e.event_id, e.room_id, e.type, se.state_key FROM events AS e
            LEFT JOIN state_events AS se USING (event_id)
            WHERE {clause}
        """

        txn.execute(sql, args)
        return {
            event_id: EventMetadata(
                room_id=room_id, event_type=event_type, state_key=state_key
            )
            for event_id, room_id, event_type, state_key in txn
        }

    result_map: Dict[str, EventMetadata] = {}
    # Query in bounded batches so that a very large set of event IDs does
    # not produce a single enormous IN-list.
    for batch_ids in batch_iter(event_ids, 1000):
        result_map.update(
            await self.db_pool.runInteraction(
                "get_metadata_for_events",
                get_metadata_for_events_txn,
                batch_ids=batch_ids,
            )
        )

    return result_map
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.

View File

@ -276,7 +276,11 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# federation handler wanting to backfill the fake event.
self.get_success(
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME, event, state=current_state
self.OTHER_SERVER_NAME,
event,
state_ids={
(e.type, e.state_key): e.event_id for e in current_state
},
)
)

View File

@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
def persist_event(self, event, state=None):
"""Persist the event, with optional state"""
context = self.get_success(
self.state.compute_event_context(event, old_state=state)
self.state.compute_event_context(event, state_ids_before_event=state)
)
self.get_success(self.persistence.persist_event(event, context))
@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
# setting. The state resolution across the old and new event will then
# include it, and so the resolved state won't match the new state.
state_before_gap = dict(
self.get_success(self.state.get_current_state(self.room_id))
self.get_success(self.state.get_current_state_ids(self.room_id))
)
state_before_gap.pop(("m.room.history_visibility", ""))
context = self.get_success(
self.state.compute_event_context(
remote_event_2, old_state=state_before_gap.values()
remote_event_2,
state_ids_before_event=state_before_gap,
)
)
@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([local_message_event_id, remote_event_2.event_id])

View File

@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase):
]
context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
)
)
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase):
]
context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
)
)
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())