Get rid of `_auth_and_persist_event` (#10781)

This is only called in two places, and the code seems much clearer without it.
This commit is contained in:
Richard van der Hoff 2021-09-08 19:03:08 +01:00 committed by GitHub
parent 03caba6577
commit abedf7d77f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 63 deletions

1
changelog.d/10781.misc Normal file
View File

@ -0,0 +1 @@
Clean up some of the federation event authentication code for clarity.

View File

@ -909,12 +909,18 @@ class FederationEventHandler:
context = await self._state_handler.compute_event_context(
event, old_state=state
)
await self._auth_and_persist_event(
origin, event, context, state=state, backfilled=backfilled
context = await self._check_event_auth(
origin,
event,
context,
state=state,
backfilled=backfilled,
)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
await self._run_push_actions_and_persist_event(event, context, backfilled)
if backfilled:
return
@ -1239,51 +1245,6 @@ class FederationEventHandler:
],
)
async def _auth_and_persist_event(
self,
origin: str,
event: EventBase,
context: EventContext,
state: Optional[Iterable[EventBase]] = None,
claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
backfilled: bool = False,
) -> None:
"""
Process an event by performing auth checks and then persisting to the database.
Args:
origin: The host the event originates from.
event: The event itself.
context:
The event context.
state:
The state events used to check the event for soft-fail. If this is
not provided the current state events will be used.
claimed_auth_event_map:
A map of (type, state_key) => event for the event's claimed auth_events.
Possibly incomplete, and possibly including events that are not yet
persisted, or authed, or in the right room.
Only populated when populating outliers.
backfilled: True if the event was backfilled.
"""
# claimed_auth_event_map should be given iff the event is an outlier
assert bool(claimed_auth_event_map) == event.internal_metadata.outlier
context = await self._check_event_auth(
origin,
event,
context,
state=state,
claimed_auth_event_map=claimed_auth_event_map,
backfilled=backfilled,
)
await self._run_push_actions_and_persist_event(event, context, backfilled)
async def _check_event_auth(
self,
origin: str,
@ -1558,39 +1519,45 @@ class FederationEventHandler:
event.room_id, [e.event_id for e in remote_auth_chain]
)
for e in remote_auth_chain:
if e.event_id in seen_remotes:
for auth_event in remote_auth_chain:
if auth_event.event_id in seen_remotes:
continue
if e.event_id == event.event_id:
if auth_event.event_id == event.event_id:
continue
try:
auth_ids = e.auth_event_ids()
auth_ids = auth_event.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in remote_auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
e.internal_metadata.outlier = True
auth_event.internal_metadata.outlier = True
logger.debug(
"_check_event_auth %s missing_auth: %s",
event.event_id,
e.event_id,
auth_event.event_id,
)
missing_auth_event_context = (
await self._state_handler.compute_event_context(e)
await self._state_handler.compute_event_context(auth_event)
)
await self._auth_and_persist_event(
missing_auth_event_context = await self._check_event_auth(
origin,
e,
auth_event,
missing_auth_event_context,
claimed_auth_event_map=auth,
)
await self.persist_events_and_notify(
event.room_id, [(auth_event, missing_auth_event_context)]
)
if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e
if auth_event.event_id in event_auth_events:
auth_events[
(auth_event.type, auth_event.state_key)
] = auth_event
except AuthError:
pass
@ -1733,10 +1700,13 @@ class FederationEventHandler:
context: The event context.
backfilled: True if the event was backfilled.
"""
# this method should not be called on outliers (those code paths call
# persist_events_and_notify directly.)
assert not event.internal_metadata.outlier
try:
if (
not event.internal_metadata.is_outlier()
and not backfilled
not backfilled
and not context.rejected
and (await self._store.get_min_depth(event.room_id)) <= event.depth
):

View File

@ -76,9 +76,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.handler = self.homeserver.get_federation_handler()
federation_event_handler = self.homeserver.get_federation_event_handler()
federation_event_handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
context
)
async def _check_event_auth(
origin,
event,
context,
state=None,
claimed_auth_event_map=None,
backfilled=False,
):
return context
federation_event_handler._check_event_auth = _check_event_auth
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus