From e77237b9350a7054657b1a641883a09c6b5d44f3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 10 Dec 2019 17:01:37 +0000 Subject: [PATCH] convert to async: FederationHandler.on_receive_pdu and associated functions: * on_receive_pdu * handle_queued_pdus * get_missing_events_for_pdu --- synapse/handlers/federation.py | 49 +++++++++++++++------------------- tests/test_federation.py | 14 ++++++---- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e54d509b62..01d9c5120e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -165,8 +165,7 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): + async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -176,8 +175,6 @@ class FederationHandler(BaseHandler): pdu (FrozenEvent): received PDU sent_to_us_directly (bool): True if this event was pushed to us; False if we pulled it as the result of a missing prev_event. - - Returns (Deferred): completes with None """ room_id = pdu.room_id @@ -186,7 +183,7 @@ class FederationHandler(BaseHandler): logger.info("handling received PDU: %s", pdu) # We reprocess pdus when we have seen them only as outliers - existing = yield self.store.get_event( + existing = await self.store.get_event( event_id, allow_none=True, allow_rejected=True ) @@ -230,7 +227,7 @@ class FederationHandler(BaseHandler): # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name) + is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( "[%s %s] Ignoring PDU from %s as we're not in the room", @@ -246,12 +243,12 @@ class FederationHandler(BaseHandler): # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. - min_depth = yield self.get_min_depth_for_context(pdu.room_id) + min_depth = await self.get_min_depth_for_context(pdu.room_id) logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) prevs = set(pdu.prev_event_ids()) - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if min_depth and pdu.depth < min_depth: # This is so that we don't notify the user about this @@ -271,7 +268,7 @@ class FederationHandler(BaseHandler): len(missing_prevs), shortstr(missing_prevs), ) - with (yield self._room_pdu_linearizer.queue(pdu.room_id)): + with (await self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( "[%s %s] Acquired room lock to fetch %d missing prev_events", room_id, @@ -280,7 +277,7 @@ class FederationHandler(BaseHandler): ) try: - yield self._get_missing_events_for_pdu( + await self._get_missing_events_for_pdu( origin, pdu, prevs, min_depth ) except Exception as e: @@ -291,7 +288,7 @@ class FederationHandler(BaseHandler): # Update the set of things we've seen after trying to # fetch the missing stuff - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: logger.info( @@ -355,7 +352,7 @@ class FederationHandler(BaseHandler): event_map = {event_id: pdu} try: # Get the state of the events we know about - ours = yield self.state_store.get_state_groups_ids(room_id, seen) + ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps = list( @@ -372,7 +369,7 @@ class FederationHandler(BaseHandler): "Requesting state at missing prev_event %s", event_id, ) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) with nested_logging_context(p): # note that if any of the missing prevs share missing state or @@ -381,11 +378,11 @@ class FederationHandler(BaseHandler): ( remote_state, got_auth_chain, - ) = yield self._get_state_for_room(origin, room_id, p) + ) = await self._get_state_for_room(origin, room_id, p) # we want the state *after* p; _get_state_for_room returns the # state *before* p. - remote_event = yield self.federation_client.get_pdu( + remote_event = await self.federation_client.get_pdu( [origin], p, room_version, outlier=True ) @@ -410,7 +407,7 @@ class FederationHandler(BaseHandler): for x in remote_state: event_map[x.event_id] = x - state_map = yield resolve_events_with_store( + state_map = await resolve_events_with_store( room_version, state_maps, event_map, @@ -422,7 +419,7 @@ class FederationHandler(BaseHandler): # First though we need to fetch all the events that are in # state_map, so we can build up the state below. - evs = yield self.store.get_events( + evs = await self.store.get_events( list(state_map.values()), get_prev_content=False, redact_behaviour=EventRedactBehaviour.AS_IS, @@ -446,12 +443,11 @@ class FederationHandler(BaseHandler): affected=event_id, ) - yield self._process_received_pdu( + await self._process_received_pdu( origin, pdu, state=state, auth_chain=auth_chain ) - @defer.inlineCallbacks - def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): + async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): """ Args: origin (str): Origin of the pdu. Will be called to get the missing events @@ -463,12 +459,12 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: return - latest = yield self.store.get_latest_event_ids_in_room(room_id) + latest = await self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us @@ -532,7 +528,7 @@ class FederationHandler(BaseHandler): # All that said: Let's try increasing the timout to 60s and see what happens. try: - missing_events = yield self.federation_client.get_missing_events( + missing_events = await self.federation_client.get_missing_events( origin, room_id, earliest_events_ids=list(latest), @@ -571,7 +567,7 @@ class FederationHandler(BaseHandler): ) with nested_logging_context(ev.event_id): try: - yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) + await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: logger.warning( @@ -1328,8 +1324,7 @@ class FederationHandler(BaseHandler): return True - @defer.inlineCallbacks - def _handle_queued_pdus(self, room_queue): + async def _handle_queued_pdus(self, room_queue): """Process PDUs which got queued up while we were busy send_joining. Args: @@ -1345,7 +1340,7 @@ class FederationHandler(BaseHandler): p.room_id, ) with nested_logging_context(p.event_id): - yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) + await self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e diff --git a/tests/test_federation.py b/tests/test_federation.py index ad165d7295..68684460c6 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,6 +1,6 @@ from mock import Mock -from twisted.internet.defer import maybeDeferred, succeed +from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed from synapse.events import FrozenEvent from synapse.logging.context import LoggingContext @@ -70,8 +70,10 @@ class MessageAcceptTests(unittest.TestCase): ) # Send the join, it should return None (which is not an error) - d = self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True + d = ensureDeferred( + self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) ) self.reactor.advance(1) self.assertEqual(self.successResultOf(d), None) @@ -119,8 +121,10 @@ class MessageAcceptTests(unittest.TestCase): ) with LoggingContext(request="lying_event"): - d = self.handler.on_receive_pdu( - "test.serv", lying_event, sent_to_us_directly=True + d = ensureDeferred( + self.handler.on_receive_pdu( + "test.serv", lying_event, sent_to_us_directly=True + ) ) # Step the reactor, so the database fetches come back