convert to async: FederationHandler.on_receive_pdu
and associated functions: * on_receive_pdu * handle_queued_pdus * get_missing_events_for_pdu
This commit is contained in:
parent
7712e751b8
commit
e77237b935
|
@ -165,8 +165,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
|
||||
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
|
||||
""" Process a PDU received via a federation /send/ transaction, or
|
||||
via backfill of missing prev_events
|
||||
|
||||
|
@ -176,8 +175,6 @@ class FederationHandler(BaseHandler):
|
|||
pdu (FrozenEvent): received PDU
|
||||
sent_to_us_directly (bool): True if this event was pushed to us; False if
|
||||
we pulled it as the result of a missing prev_event.
|
||||
|
||||
Returns (Deferred): completes with None
|
||||
"""
|
||||
|
||||
room_id = pdu.room_id
|
||||
|
@ -186,7 +183,7 @@ class FederationHandler(BaseHandler):
|
|||
logger.info("handling received PDU: %s", pdu)
|
||||
|
||||
# We reprocess pdus when we have seen them only as outliers
|
||||
existing = yield self.store.get_event(
|
||||
existing = await self.store.get_event(
|
||||
event_id, allow_none=True, allow_rejected=True
|
||||
)
|
||||
|
||||
|
@ -230,7 +227,7 @@ class FederationHandler(BaseHandler):
|
|||
#
|
||||
# Note that if we were never in the room then we would have already
|
||||
# dropped the event, since we wouldn't know the room version.
|
||||
is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
|
||||
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
|
||||
if not is_in_room:
|
||||
logger.info(
|
||||
"[%s %s] Ignoring PDU from %s as we're not in the room",
|
||||
|
@ -246,12 +243,12 @@ class FederationHandler(BaseHandler):
|
|||
# Get missing pdus if necessary.
|
||||
if not pdu.internal_metadata.is_outlier():
|
||||
# We only backfill backwards to the min depth.
|
||||
min_depth = yield self.get_min_depth_for_context(pdu.room_id)
|
||||
min_depth = await self.get_min_depth_for_context(pdu.room_id)
|
||||
|
||||
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
|
||||
|
||||
prevs = set(pdu.prev_event_ids())
|
||||
seen = yield self.store.have_seen_events(prevs)
|
||||
seen = await self.store.have_seen_events(prevs)
|
||||
|
||||
if min_depth and pdu.depth < min_depth:
|
||||
# This is so that we don't notify the user about this
|
||||
|
@ -271,7 +268,7 @@ class FederationHandler(BaseHandler):
|
|||
len(missing_prevs),
|
||||
shortstr(missing_prevs),
|
||||
)
|
||||
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
|
||||
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
|
||||
logger.info(
|
||||
"[%s %s] Acquired room lock to fetch %d missing prev_events",
|
||||
room_id,
|
||||
|
@ -280,7 +277,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
try:
|
||||
yield self._get_missing_events_for_pdu(
|
||||
await self._get_missing_events_for_pdu(
|
||||
origin, pdu, prevs, min_depth
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -291,7 +288,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# Update the set of things we've seen after trying to
|
||||
# fetch the missing stuff
|
||||
seen = yield self.store.have_seen_events(prevs)
|
||||
seen = await self.store.have_seen_events(prevs)
|
||||
|
||||
if not prevs - seen:
|
||||
logger.info(
|
||||
|
@ -355,7 +352,7 @@ class FederationHandler(BaseHandler):
|
|||
event_map = {event_id: pdu}
|
||||
try:
|
||||
# Get the state of the events we know about
|
||||
ours = yield self.state_store.get_state_groups_ids(room_id, seen)
|
||||
ours = await self.state_store.get_state_groups_ids(room_id, seen)
|
||||
|
||||
# state_maps is a list of mappings from (type, state_key) to event_id
|
||||
state_maps = list(
|
||||
|
@ -372,7 +369,7 @@ class FederationHandler(BaseHandler):
|
|||
"Requesting state at missing prev_event %s", event_id,
|
||||
)
|
||||
|
||||
room_version = yield self.store.get_room_version(room_id)
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
with nested_logging_context(p):
|
||||
# note that if any of the missing prevs share missing state or
|
||||
|
@ -381,11 +378,11 @@ class FederationHandler(BaseHandler):
|
|||
(
|
||||
remote_state,
|
||||
got_auth_chain,
|
||||
) = yield self._get_state_for_room(origin, room_id, p)
|
||||
) = await self._get_state_for_room(origin, room_id, p)
|
||||
|
||||
# we want the state *after* p; _get_state_for_room returns the
|
||||
# state *before* p.
|
||||
remote_event = yield self.federation_client.get_pdu(
|
||||
remote_event = await self.federation_client.get_pdu(
|
||||
[origin], p, room_version, outlier=True
|
||||
)
|
||||
|
||||
|
@ -410,7 +407,7 @@ class FederationHandler(BaseHandler):
|
|||
for x in remote_state:
|
||||
event_map[x.event_id] = x
|
||||
|
||||
state_map = yield resolve_events_with_store(
|
||||
state_map = await resolve_events_with_store(
|
||||
room_version,
|
||||
state_maps,
|
||||
event_map,
|
||||
|
@ -422,7 +419,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# First though we need to fetch all the events that are in
|
||||
# state_map, so we can build up the state below.
|
||||
evs = yield self.store.get_events(
|
||||
evs = await self.store.get_events(
|
||||
list(state_map.values()),
|
||||
get_prev_content=False,
|
||||
redact_behaviour=EventRedactBehaviour.AS_IS,
|
||||
|
@ -446,12 +443,11 @@ class FederationHandler(BaseHandler):
|
|||
affected=event_id,
|
||||
)
|
||||
|
||||
yield self._process_received_pdu(
|
||||
await self._process_received_pdu(
|
||||
origin, pdu, state=state, auth_chain=auth_chain
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
|
||||
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
|
||||
"""
|
||||
Args:
|
||||
origin (str): Origin of the pdu. Will be called to get the missing events
|
||||
|
@ -463,12 +459,12 @@ class FederationHandler(BaseHandler):
|
|||
room_id = pdu.room_id
|
||||
event_id = pdu.event_id
|
||||
|
||||
seen = yield self.store.have_seen_events(prevs)
|
||||
seen = await self.store.have_seen_events(prevs)
|
||||
|
||||
if not prevs - seen:
|
||||
return
|
||||
|
||||
latest = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
latest = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
# We add the prev events that we have seen to the latest
|
||||
# list to ensure the remote server doesn't give them to us
|
||||
|
@ -532,7 +528,7 @@ class FederationHandler(BaseHandler):
|
|||
# All that said: Let's try increasing the timout to 60s and see what happens.
|
||||
|
||||
try:
|
||||
missing_events = yield self.federation_client.get_missing_events(
|
||||
missing_events = await self.federation_client.get_missing_events(
|
||||
origin,
|
||||
room_id,
|
||||
earliest_events_ids=list(latest),
|
||||
|
@ -571,7 +567,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
with nested_logging_context(ev.event_id):
|
||||
try:
|
||||
yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
|
||||
await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
|
||||
except FederationError as e:
|
||||
if e.code == 403:
|
||||
logger.warning(
|
||||
|
@ -1328,8 +1324,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_queued_pdus(self, room_queue):
|
||||
async def _handle_queued_pdus(self, room_queue):
|
||||
"""Process PDUs which got queued up while we were busy send_joining.
|
||||
|
||||
Args:
|
||||
|
@ -1345,7 +1340,7 @@ class FederationHandler(BaseHandler):
|
|||
p.room_id,
|
||||
)
|
||||
with nested_logging_context(p.event_id):
|
||||
yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
|
||||
await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from mock import Mock
|
||||
|
||||
from twisted.internet.defer import maybeDeferred, succeed
|
||||
from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
|
||||
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.logging.context import LoggingContext
|
||||
|
@ -70,8 +70,10 @@ class MessageAcceptTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
# Send the join, it should return None (which is not an error)
|
||||
d = self.handler.on_receive_pdu(
|
||||
"test.serv", join_event, sent_to_us_directly=True
|
||||
d = ensureDeferred(
|
||||
self.handler.on_receive_pdu(
|
||||
"test.serv", join_event, sent_to_us_directly=True
|
||||
)
|
||||
)
|
||||
self.reactor.advance(1)
|
||||
self.assertEqual(self.successResultOf(d), None)
|
||||
|
@ -119,8 +121,10 @@ class MessageAcceptTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
with LoggingContext(request="lying_event"):
|
||||
d = self.handler.on_receive_pdu(
|
||||
"test.serv", lying_event, sent_to_us_directly=True
|
||||
d = ensureDeferred(
|
||||
self.handler.on_receive_pdu(
|
||||
"test.serv", lying_event, sent_to_us_directly=True
|
||||
)
|
||||
)
|
||||
|
||||
# Step the reactor, so the database fetches come back
|
||||
|
|
Loading…
Reference in New Issue