Split `on_receive_pdu` in half (#10640)

Here we split on_receive_pdu into two functions (on_receive_pdu and process_pulled_event), rather than having both cases in the same method. There's a tiny bit of overlap, but not that much.
This commit is contained in:
Richard van der Hoff 2021-08-19 18:05:12 +01:00 committed by GitHub
parent 50af1efe4b
commit e81d62009e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 146 additions and 113 deletions

1
changelog.d/10640.misc Normal file
View File

@ -0,0 +1 @@
Clean up some of the federation event authentication code for clarity.

View File

@ -1005,9 +1005,7 @@ class FederationServer(FederationBase):
async with lock: async with lock:
logger.info("handling received PDU: %s", event) logger.info("handling received PDU: %s", event)
try: try:
await self.handler.on_receive_pdu( await self.handler.on_receive_pdu(origin, event)
origin, event, sent_to_us_directly=True
)
except FederationError as e: except FederationError as e:
# XXX: Ideally we'd inform the remote we failed to process # XXX: Ideally we'd inform the remote we failed to process
# the event, but we can't return an error in the transaction # the event, but we can't return an error in the transaction

View File

@ -203,18 +203,13 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
async def on_receive_pdu( async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False """Process a PDU received via a federation /send/ transaction
) -> None:
"""Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
Args: Args:
origin: server which initiated the /send/ transaction. Will origin: server which initiated the /send/ transaction. Will
be used to fetch missing events or state. be used to fetch missing events or state.
pdu: received PDU pdu: received PDU
sent_to_us_directly: True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
""" """
room_id = pdu.room_id room_id = pdu.room_id
@ -276,8 +271,6 @@ class FederationHandler(BaseHandler):
) )
return None return None
state = None
# Check that the event passes auth based on the state at the event. This is # Check that the event passes auth based on the state at the event. This is
# done for events that are to be added to the timeline (non-outliers). # done for events that are to be added to the timeline (non-outliers).
# #
@ -285,7 +278,6 @@ class FederationHandler(BaseHandler):
# - Fetching any missing prev events to fill in gaps in the graph # - Fetching any missing prev events to fill in gaps in the graph
# - Fetching state if we have a hole in the graph # - Fetching state if we have a hole in the graph
if not pdu.internal_metadata.is_outlier(): if not pdu.internal_metadata.is_outlier():
if sent_to_us_directly:
prevs = set(pdu.prev_event_ids()) prevs = set(pdu.prev_event_ids())
seen = await self.store.have_events_in_timeline(prevs) seen = await self.store.have_events_in_timeline(prevs)
missing_prevs = prevs - seen missing_prevs = prevs - seen
@ -351,17 +343,7 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id, affected=pdu.event_id,
) )
else: await self._process_received_pdu(origin, pdu, state=None)
state = await self._resolve_state_at_missing_prevs(origin, pdu)
# A second round of checks for all events. Check that the event passes auth
# based on `auth_events`, this allows us to assert that the event would
# have been allowed at some point. If an event passes this check its OK
# for it to be used as part of a returned `/state` request, as either
# a) we received the event as part of the original join and so trust it, or
# b) we'll do a state resolution with existing state before it becomes
# part of the "current state", which adds more protection.
await self._process_received_pdu(origin, pdu, state=state)
async def _get_missing_events_for_pdu( async def _get_missing_events_for_pdu(
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
@ -461,24 +443,7 @@ class FederationHandler(BaseHandler):
return return
logger.info("Got %d prev_events", len(missing_events)) logger.info("Got %d prev_events", len(missing_events))
await self._process_pulled_events(origin, missing_events)
# We want to sort these by depth so we process them and
# tell clients about them in order.
missing_events.sort(key=lambda x: x.depth)
for ev in missing_events:
logger.info("Handling received prev_event %s", ev)
with nested_logging_context(ev.event_id):
try:
await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
logger.warning(
"Received prev_event %s failed history check.",
ev.event_id,
)
else:
raise
async def _get_state_for_room( async def _get_state_for_room(
self, self,
@ -1395,6 +1360,81 @@ class FederationHandler(BaseHandler):
event_infos, event_infos,
) )
async def _process_pulled_events(
self, origin: str, events: Iterable[EventBase]
) -> None:
"""Process a batch of events we have pulled from a remote server
Pulls in any events required to auth the events, persists the received events,
and notifies clients, if appropriate.
Assumes the events have already had their signatures and hashes checked.
Params:
origin: The server we received these events from
events: The received events.
"""
# We want to sort these by depth so we process them and
# tell clients about them in order.
sorted_events = sorted(events, key=lambda x: x.depth)
for ev in sorted_events:
with nested_logging_context(ev.event_id):
await self._process_pulled_event(origin, ev)
async def _process_pulled_event(self, origin: str, event: EventBase) -> None:
"""Process a single event that we have pulled from a remote server
Pulls in any events required to auth the event, persists the received event,
and notifies clients, if appropriate.
Assumes the event has already had its signatures and hashes checked.
This is somewhat equivalent to on_receive_pdu, but applies somewhat different
logic in the case that we are missing prev_events (in particular, it just
requests the state at that point, rather than triggering a get_missing_events) -
so is appropriate when we have pulled the event from a remote server, rather
than having it pushed to us.
Params:
origin: The server we received this event from
events: The received event
"""
logger.info("Processing pulled event %s", event)
# these should not be outliers.
assert not event.internal_metadata.is_outlier()
event_id = event.event_id
existing = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
if existing:
if not existing.internal_metadata.is_outlier():
logger.info(
"Ignoring received event %s which we have already seen",
event_id,
)
return
logger.info("De-outliering event %s", event_id)
try:
self._sanity_check_event(event)
except SynapseError as err:
logger.warning("Event %s failed sanity check: %s", event_id, err)
return
try:
state = await self._resolve_state_at_missing_prevs(origin, event)
await self._process_received_pdu(origin, event, state=state)
except FederationError as e:
if e.code == 403:
logger.warning("Pulled event %s failed history check.", event_id)
else:
raise
async def _resolve_state_at_missing_prevs( async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase self, dest: str, event: EventBase
) -> Optional[Iterable[EventBase]]: ) -> Optional[Iterable[EventBase]]:
@ -1780,7 +1820,7 @@ class FederationHandler(BaseHandler):
p, p,
) )
with nested_logging_context(p.event_id): with nested_logging_context(p.event_id):
await self.on_receive_pdu(origin, p, sent_to_us_directly=True) await self.on_receive_pdu(origin, p)
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e "Error handling queued PDU %s from %s: %s", p.event_id, origin, e

View File

@ -85,11 +85,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Send the join, it should return None (which is not an error) # Send the join, it should return None (which is not an error)
self.assertEqual( self.assertEqual(
self.get_success( self.get_success(self.handler.on_receive_pdu("test.serv", join_event)),
self.handler.on_receive_pdu(
"test.serv", join_event, sent_to_us_directly=True
)
),
None, None,
) )
@ -135,9 +131,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
with LoggingContext("test-context"): with LoggingContext("test-context"):
failure = self.get_failure( failure = self.get_failure(
self.handler.on_receive_pdu( self.handler.on_receive_pdu("test.serv", lying_event),
"test.serv", lying_event, sent_to_us_directly=True
),
FederationError, FederationError,
) )