Factor out common code for persisting fetched auth events (#10896)

* Factor more stuff out of `_get_events_and_persist`

It turns out that the event-sorting algorithm in `_get_events_and_persist` is
also useful in other circumstances. Here we move the current
`_auth_and_persist_fetched_events` to `_auth_and_persist_fetched_events_inner`,
and then factor the sorting part out to `_auth_and_persist_fetched_events`.

* `_get_remote_auth_chain_for_event`: remove redundant `outlier` assignment

`get_event_auth` returns events with the outlier flag already set, so this is
redundant (though we need to update a test where `get_event_auth` is mocked).

* `_get_remote_auth_chain_for_event`: move existing-event tests earlier

Move a couple of tests outside the loop. This is a bit inefficient for now, but
a future commit will make it better. It should be functionally identical.

* `_get_remote_auth_chain_for_event`: use `_auth_and_persist_fetched_events`

We can use the same codepath for persisting the events fetched as part of an
auth chain as for those fetched individually by `_get_events_and_persist` for
building the state at a backwards extremity.

* `_get_remote_auth_chain_for_event`: use a dict for efficiency

`_auth_and_persist_fetched_events` sorts the events itself, so we no longer
need to care about maintaining the ordering from `get_event_auth` (and no
longer need to sort by depth in `get_event_auth`).

That means that we can use a map, making it easier to filter out events we
already have, etc.

* changelog

* `_auth_and_persist_fetched_events`: improve docstring
This commit is contained in:
Richard van der Hoff 2021-09-24 11:56:33 +01:00 committed by GitHub
parent 261c9763c4
commit 85551b7a85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 55 additions and 58 deletions

1
changelog.d/10896.misc Normal file
View File

@ -0,0 +1 @@
Clean up some of the federation event authentication code for clarity.

View File

@ -501,8 +501,6 @@ class FederationClient(FederationBase):
destination, auth_chain, outlier=True, room_version=room_version destination, auth_chain, outlier=True, room_version=room_version
) )
signed_auth.sort(key=lambda e: e.depth)
return signed_auth return signed_auth
def _is_unknown_endpoint( def _is_unknown_endpoint(

View File

@ -1080,7 +1080,7 @@ class FederationEventHandler:
room_version = await self._store.get_room_version(room_id) room_version = await self._store.get_room_version(room_id)
event_map: Dict[str, EventBase] = {} events: List[EventBase] = []
async def get_event(event_id: str) -> None: async def get_event(event_id: str) -> None:
with nested_logging_context(event_id): with nested_logging_context(event_id):
@ -1098,8 +1098,7 @@ class FederationEventHandler:
event_id, event_id,
) )
return return
events.append(event)
event_map[event.event_id] = event
except Exception as e: except Exception as e:
logger.warning( logger.warning(
@ -1110,11 +1109,29 @@ class FederationEventHandler:
) )
await concurrently_execute(get_event, event_ids, 5) await concurrently_execute(get_event, event_ids, 5)
logger.info("Fetched %i events of %i requested", len(event_map), len(event_ids)) logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
await self._auth_and_persist_fetched_events(destination, room_id, events)
async def _auth_and_persist_fetched_events(
self, origin: str, room_id: str, events: Iterable[EventBase]
) -> None:
"""Persist the events fetched by _get_events_and_persist or _get_remote_auth_chain_for_event
The events to be persisted must be outliers.
We first sort the events to make sure that we process each event's auth_events
before the event itself, and then auth and persist them.
Notifies about the events where appropriate.
Params:
origin: where the events came from
room_id: the room that the events are meant to be in (though this has
not yet been checked)
events: the events that have been fetched
"""
event_map = {event.event_id: event for event in events}
# we now need to auth the events in an order which ensures that each event's
# auth_events are authed before the event itself.
#
# XXX: it might be possible to kick this process off in parallel with fetching # XXX: it might be possible to kick this process off in parallel with fetching
# the events. # the events.
while event_map: while event_map:
@ -1141,22 +1158,18 @@ class FederationEventHandler:
"Persisting %i of %i remaining events", len(roots), len(event_map) "Persisting %i of %i remaining events", len(roots), len(event_map)
) )
await self._auth_and_persist_fetched_events(destination, room_id, roots) await self._auth_and_persist_fetched_events_inner(origin, room_id, roots)
for ev in roots: for ev in roots:
del event_map[ev.event_id] del event_map[ev.event_id]
async def _auth_and_persist_fetched_events( async def _auth_and_persist_fetched_events_inner(
self, origin: str, room_id: str, fetched_events: Collection[EventBase] self, origin: str, room_id: str, fetched_events: Collection[EventBase]
) -> None: ) -> None:
"""Persist the events fetched by _get_events_and_persist. """Helper for _auth_and_persist_fetched_events
The events should not depend on one another, e.g. this should be used to persist Persists a batch of events where we have (theoretically) already persisted all
a bunch of outliers, but not a chunk of individual events that depend of their auth events.
on each other for state calculations.
We also assume that all of the auth events for all of the events have already
been persisted.
Notifies about the events where appropriate. Notifies about the events where appropriate.
@ -1164,7 +1177,7 @@ class FederationEventHandler:
origin: where the events came from origin: where the events came from
room_id: the room that the events are meant to be in (though this has room_id: the room that the events are meant to be in (though this has
not yet been checked) not yet been checked)
event_id: map from event_id -> event for the fetched events fetched_events: the events to persist
""" """
# get all the auth events for all the events in this batch. By now, they should # get all the auth events for all the events in this batch. By now, they should
# have been persisted. # have been persisted.
@ -1558,53 +1571,33 @@ class FederationEventHandler:
event_id: the event for which we are lacking auth events event_id: the event for which we are lacking auth events
""" """
try: try:
remote_auth_chain = await self._federation_client.get_event_auth( remote_event_map = {
e.event_id: e
for e in await self._federation_client.get_event_auth(
destination, room_id, event_id destination, room_id, event_id
) )
}
except RequestSendFailed as e1: except RequestSendFailed as e1:
# The other side isn't around or doesn't implement the # The other side isn't around or doesn't implement the
# endpoint, so lets just bail out. # endpoint, so lets just bail out.
logger.info("Failed to get event auth from remote: %s", e1) logger.info("Failed to get event auth from remote: %s", e1)
return return
logger.info("/event_auth returned %i events", len(remote_event_map))
# `event` may be returned, but we should not yet process it.
remote_event_map.pop(event_id, None)
# nor should we reprocess any events we have already seen.
seen_remotes = await self._store.have_seen_events( seen_remotes = await self._store.have_seen_events(
room_id, [e.event_id for e in remote_auth_chain] room_id, remote_event_map.keys()
) )
for s in seen_remotes:
remote_event_map.pop(s, None)
for auth_event in remote_auth_chain: await self._auth_and_persist_fetched_events(
if auth_event.event_id in seen_remotes: destination, room_id, remote_event_map.values()
continue
if auth_event.event_id == event_id:
continue
try:
auth_ids = auth_event.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in remote_auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
auth_event.internal_metadata.outlier = True
logger.debug(
"_check_event_auth %s missing_auth: %s",
event_id,
auth_event.event_id,
) )
missing_auth_event_context = EventContext.for_outlier()
missing_auth_event_context = await self._check_event_auth(
destination,
auth_event,
missing_auth_event_context,
claimed_auth_event_map=auth,
)
await self.persist_events_and_notify(
room_id,
[(auth_event, missing_auth_event_context)],
)
except AuthError:
pass
async def _update_context_for_auth_events( async def _update_context_for_auth_events(
self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]

View File

@ -308,7 +308,12 @@ class FederationTestCase(unittest.HomeserverTestCase):
async def get_event_auth( async def get_event_auth(
destination: str, room_id: str, event_id: str destination: str, room_id: str, event_id: str
) -> List[EventBase]: ) -> List[EventBase]:
return auth_events return [
event_from_pdu_json(
ae.get_pdu_json(), room_version=room_version, outlier=True
)
for ae in auth_events
]
self.handler.federation_client.get_event_auth = get_event_auth self.handler.federation_client.get_event_auth = get_event_auth