Split `on_receive_pdu` in half (#10640)
Here we split on_receive_pdu into two functions (on_receive_pdu and process_pulled_event), rather than having both cases in the same method. There's a tiny bit of overlap, but not that much.
This commit is contained in:
parent
50af1efe4b
commit
e81d62009e
|
@ -0,0 +1 @@
|
||||||
|
Clean up some of the federation event authentication code for clarity.
|
|
@ -1005,9 +1005,7 @@ class FederationServer(FederationBase):
|
||||||
async with lock:
|
async with lock:
|
||||||
logger.info("handling received PDU: %s", event)
|
logger.info("handling received PDU: %s", event)
|
||||||
try:
|
try:
|
||||||
await self.handler.on_receive_pdu(
|
await self.handler.on_receive_pdu(origin, event)
|
||||||
origin, event, sent_to_us_directly=True
|
|
||||||
)
|
|
||||||
except FederationError as e:
|
except FederationError as e:
|
||||||
# XXX: Ideally we'd inform the remote we failed to process
|
# XXX: Ideally we'd inform the remote we failed to process
|
||||||
# the event, but we can't return an error in the transaction
|
# the event, but we can't return an error in the transaction
|
||||||
|
|
|
@ -203,18 +203,13 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||||
|
|
||||||
async def on_receive_pdu(
|
async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
|
||||||
self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
|
"""Process a PDU received via a federation /send/ transaction
|
||||||
) -> None:
|
|
||||||
"""Process a PDU received via a federation /send/ transaction, or
|
|
||||||
via backfill of missing prev_events
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
origin: server which initiated the /send/ transaction. Will
|
origin: server which initiated the /send/ transaction. Will
|
||||||
be used to fetch missing events or state.
|
be used to fetch missing events or state.
|
||||||
pdu: received PDU
|
pdu: received PDU
|
||||||
sent_to_us_directly: True if this event was pushed to us; False if
|
|
||||||
we pulled it as the result of a missing prev_event.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
room_id = pdu.room_id
|
room_id = pdu.room_id
|
||||||
|
@ -276,8 +271,6 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
state = None
|
|
||||||
|
|
||||||
# Check that the event passes auth based on the state at the event. This is
|
# Check that the event passes auth based on the state at the event. This is
|
||||||
# done for events that are to be added to the timeline (non-outliers).
|
# done for events that are to be added to the timeline (non-outliers).
|
||||||
#
|
#
|
||||||
|
@ -285,7 +278,6 @@ class FederationHandler(BaseHandler):
|
||||||
# - Fetching any missing prev events to fill in gaps in the graph
|
# - Fetching any missing prev events to fill in gaps in the graph
|
||||||
# - Fetching state if we have a hole in the graph
|
# - Fetching state if we have a hole in the graph
|
||||||
if not pdu.internal_metadata.is_outlier():
|
if not pdu.internal_metadata.is_outlier():
|
||||||
if sent_to_us_directly:
|
|
||||||
prevs = set(pdu.prev_event_ids())
|
prevs = set(pdu.prev_event_ids())
|
||||||
seen = await self.store.have_events_in_timeline(prevs)
|
seen = await self.store.have_events_in_timeline(prevs)
|
||||||
missing_prevs = prevs - seen
|
missing_prevs = prevs - seen
|
||||||
|
@ -351,17 +343,7 @@ class FederationHandler(BaseHandler):
|
||||||
affected=pdu.event_id,
|
affected=pdu.event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
await self._process_received_pdu(origin, pdu, state=None)
|
||||||
state = await self._resolve_state_at_missing_prevs(origin, pdu)
|
|
||||||
|
|
||||||
# A second round of checks for all events. Check that the event passes auth
|
|
||||||
# based on `auth_events`, this allows us to assert that the event would
|
|
||||||
# have been allowed at some point. If an event passes this check its OK
|
|
||||||
# for it to be used as part of a returned `/state` request, as either
|
|
||||||
# a) we received the event as part of the original join and so trust it, or
|
|
||||||
# b) we'll do a state resolution with existing state before it becomes
|
|
||||||
# part of the "current state", which adds more protection.
|
|
||||||
await self._process_received_pdu(origin, pdu, state=state)
|
|
||||||
|
|
||||||
async def _get_missing_events_for_pdu(
|
async def _get_missing_events_for_pdu(
|
||||||
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
|
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
|
||||||
|
@ -461,24 +443,7 @@ class FederationHandler(BaseHandler):
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Got %d prev_events", len(missing_events))
|
logger.info("Got %d prev_events", len(missing_events))
|
||||||
|
await self._process_pulled_events(origin, missing_events)
|
||||||
# We want to sort these by depth so we process them and
|
|
||||||
# tell clients about them in order.
|
|
||||||
missing_events.sort(key=lambda x: x.depth)
|
|
||||||
|
|
||||||
for ev in missing_events:
|
|
||||||
logger.info("Handling received prev_event %s", ev)
|
|
||||||
with nested_logging_context(ev.event_id):
|
|
||||||
try:
|
|
||||||
await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
|
|
||||||
except FederationError as e:
|
|
||||||
if e.code == 403:
|
|
||||||
logger.warning(
|
|
||||||
"Received prev_event %s failed history check.",
|
|
||||||
ev.event_id,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def _get_state_for_room(
|
async def _get_state_for_room(
|
||||||
self,
|
self,
|
||||||
|
@ -1395,6 +1360,81 @@ class FederationHandler(BaseHandler):
|
||||||
event_infos,
|
event_infos,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _process_pulled_events(
|
||||||
|
self, origin: str, events: Iterable[EventBase]
|
||||||
|
) -> None:
|
||||||
|
"""Process a batch of events we have pulled from a remote server
|
||||||
|
|
||||||
|
Pulls in any events required to auth the events, persists the received events,
|
||||||
|
and notifies clients, if appropriate.
|
||||||
|
|
||||||
|
Assumes the events have already had their signatures and hashes checked.
|
||||||
|
|
||||||
|
Params:
|
||||||
|
origin: The server we received these events from
|
||||||
|
events: The received events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We want to sort these by depth so we process them and
|
||||||
|
# tell clients about them in order.
|
||||||
|
sorted_events = sorted(events, key=lambda x: x.depth)
|
||||||
|
|
||||||
|
for ev in sorted_events:
|
||||||
|
with nested_logging_context(ev.event_id):
|
||||||
|
await self._process_pulled_event(origin, ev)
|
||||||
|
|
||||||
|
async def _process_pulled_event(self, origin: str, event: EventBase) -> None:
|
||||||
|
"""Process a single event that we have pulled from a remote server
|
||||||
|
|
||||||
|
Pulls in any events required to auth the event, persists the received event,
|
||||||
|
and notifies clients, if appropriate.
|
||||||
|
|
||||||
|
Assumes the event has already had its signatures and hashes checked.
|
||||||
|
|
||||||
|
This is somewhat equivalent to on_receive_pdu, but applies somewhat different
|
||||||
|
logic in the case that we are missing prev_events (in particular, it just
|
||||||
|
requests the state at that point, rather than triggering a get_missing_events) -
|
||||||
|
so is appropriate when we have pulled the event from a remote server, rather
|
||||||
|
than having it pushed to us.
|
||||||
|
|
||||||
|
Params:
|
||||||
|
origin: The server we received this event from
|
||||||
|
events: The received event
|
||||||
|
"""
|
||||||
|
logger.info("Processing pulled event %s", event)
|
||||||
|
|
||||||
|
# these should not be outliers.
|
||||||
|
assert not event.internal_metadata.is_outlier()
|
||||||
|
|
||||||
|
event_id = event.event_id
|
||||||
|
|
||||||
|
existing = await self.store.get_event(
|
||||||
|
event_id, allow_none=True, allow_rejected=True
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
if not existing.internal_metadata.is_outlier():
|
||||||
|
logger.info(
|
||||||
|
"Ignoring received event %s which we have already seen",
|
||||||
|
event_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
logger.info("De-outliering event %s", event_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._sanity_check_event(event)
|
||||||
|
except SynapseError as err:
|
||||||
|
logger.warning("Event %s failed sanity check: %s", event_id, err)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
state = await self._resolve_state_at_missing_prevs(origin, event)
|
||||||
|
await self._process_received_pdu(origin, event, state=state)
|
||||||
|
except FederationError as e:
|
||||||
|
if e.code == 403:
|
||||||
|
logger.warning("Pulled event %s failed history check.", event_id)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
async def _resolve_state_at_missing_prevs(
|
async def _resolve_state_at_missing_prevs(
|
||||||
self, dest: str, event: EventBase
|
self, dest: str, event: EventBase
|
||||||
) -> Optional[Iterable[EventBase]]:
|
) -> Optional[Iterable[EventBase]]:
|
||||||
|
@ -1780,7 +1820,7 @@ class FederationHandler(BaseHandler):
|
||||||
p,
|
p,
|
||||||
)
|
)
|
||||||
with nested_logging_context(p.event_id):
|
with nested_logging_context(p.event_id):
|
||||||
await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
|
await self.on_receive_pdu(origin, p)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
|
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
|
||||||
|
|
|
@ -85,11 +85,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Send the join, it should return None (which is not an error)
|
# Send the join, it should return None (which is not an error)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.get_success(
|
self.get_success(self.handler.on_receive_pdu("test.serv", join_event)),
|
||||||
self.handler.on_receive_pdu(
|
|
||||||
"test.serv", join_event, sent_to_us_directly=True
|
|
||||||
)
|
|
||||||
),
|
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -135,9 +131,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
with LoggingContext("test-context"):
|
with LoggingContext("test-context"):
|
||||||
failure = self.get_failure(
|
failure = self.get_failure(
|
||||||
self.handler.on_receive_pdu(
|
self.handler.on_receive_pdu("test.serv", lying_event),
|
||||||
"test.serv", lying_event, sent_to_us_directly=True
|
|
||||||
),
|
|
||||||
FederationError,
|
FederationError,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue