Merge pull request #6517 from matrix-org/rav/event_auth/13
Port some of FederationHandler to async/await
This commit is contained in:
commit
894d2addac
|
@ -0,0 +1 @@
|
|||
Port some of FederationHandler to async/await.
|
|
@ -19,7 +19,7 @@
|
|||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Dict, Iterable, Optional, Sequence, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
|
||||
import six
|
||||
from six import iteritems, itervalues
|
||||
|
@ -165,8 +165,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
|
||||
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
|
||||
""" Process a PDU received via a federation /send/ transaction, or
|
||||
via backfill of missing prev_events
|
||||
|
||||
|
@ -176,8 +175,6 @@ class FederationHandler(BaseHandler):
|
|||
pdu (FrozenEvent): received PDU
|
||||
sent_to_us_directly (bool): True if this event was pushed to us; False if
|
||||
we pulled it as the result of a missing prev_event.
|
||||
|
||||
Returns (Deferred): completes with None
|
||||
"""
|
||||
|
||||
room_id = pdu.room_id
|
||||
|
@ -186,7 +183,7 @@ class FederationHandler(BaseHandler):
|
|||
logger.info("handling received PDU: %s", pdu)
|
||||
|
||||
# We reprocess pdus when we have seen them only as outliers
|
||||
existing = yield self.store.get_event(
|
||||
existing = await self.store.get_event(
|
||||
event_id, allow_none=True, allow_rejected=True
|
||||
)
|
||||
|
||||
|
@ -230,7 +227,7 @@ class FederationHandler(BaseHandler):
|
|||
#
|
||||
# Note that if we were never in the room then we would have already
|
||||
# dropped the event, since we wouldn't know the room version.
|
||||
is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
|
||||
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
|
||||
if not is_in_room:
|
||||
logger.info(
|
||||
"[%s %s] Ignoring PDU from %s as we're not in the room",
|
||||
|
@ -246,12 +243,12 @@ class FederationHandler(BaseHandler):
|
|||
# Get missing pdus if necessary.
|
||||
if not pdu.internal_metadata.is_outlier():
|
||||
# We only backfill backwards to the min depth.
|
||||
min_depth = yield self.get_min_depth_for_context(pdu.room_id)
|
||||
min_depth = await self.get_min_depth_for_context(pdu.room_id)
|
||||
|
||||
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
|
||||
|
||||
prevs = set(pdu.prev_event_ids())
|
||||
seen = yield self.store.have_seen_events(prevs)
|
||||
seen = await self.store.have_seen_events(prevs)
|
||||
|
||||
if min_depth and pdu.depth < min_depth:
|
||||
# This is so that we don't notify the user about this
|
||||
|
@ -271,7 +268,7 @@ class FederationHandler(BaseHandler):
|
|||
len(missing_prevs),
|
||||
shortstr(missing_prevs),
|
||||
)
|
||||
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
|
||||
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
|
||||
logger.info(
|
||||
"[%s %s] Acquired room lock to fetch %d missing prev_events",
|
||||
room_id,
|
||||
|
@ -280,7 +277,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
try:
|
||||
yield self._get_missing_events_for_pdu(
|
||||
await self._get_missing_events_for_pdu(
|
||||
origin, pdu, prevs, min_depth
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -291,7 +288,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# Update the set of things we've seen after trying to
|
||||
# fetch the missing stuff
|
||||
seen = yield self.store.have_seen_events(prevs)
|
||||
seen = await self.store.have_seen_events(prevs)
|
||||
|
||||
if not prevs - seen:
|
||||
logger.info(
|
||||
|
@ -355,7 +352,7 @@ class FederationHandler(BaseHandler):
|
|||
event_map = {event_id: pdu}
|
||||
try:
|
||||
# Get the state of the events we know about
|
||||
ours = yield self.state_store.get_state_groups_ids(room_id, seen)
|
||||
ours = await self.state_store.get_state_groups_ids(room_id, seen)
|
||||
|
||||
# state_maps is a list of mappings from (type, state_key) to event_id
|
||||
state_maps = list(
|
||||
|
@ -372,7 +369,7 @@ class FederationHandler(BaseHandler):
|
|||
"Requesting state at missing prev_event %s", event_id,
|
||||
)
|
||||
|
||||
room_version = yield self.store.get_room_version(room_id)
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
with nested_logging_context(p):
|
||||
# note that if any of the missing prevs share missing state or
|
||||
|
@ -381,11 +378,11 @@ class FederationHandler(BaseHandler):
|
|||
(
|
||||
remote_state,
|
||||
got_auth_chain,
|
||||
) = yield self._get_state_for_room(origin, room_id, p)
|
||||
) = await self._get_state_for_room(origin, room_id, p)
|
||||
|
||||
# we want the state *after* p; _get_state_for_room returns the
|
||||
# state *before* p.
|
||||
remote_event = yield self.federation_client.get_pdu(
|
||||
remote_event = await self.federation_client.get_pdu(
|
||||
[origin], p, room_version, outlier=True
|
||||
)
|
||||
|
||||
|
@ -410,7 +407,7 @@ class FederationHandler(BaseHandler):
|
|||
for x in remote_state:
|
||||
event_map[x.event_id] = x
|
||||
|
||||
state_map = yield resolve_events_with_store(
|
||||
state_map = await resolve_events_with_store(
|
||||
room_version,
|
||||
state_maps,
|
||||
event_map,
|
||||
|
@ -422,7 +419,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# First though we need to fetch all the events that are in
|
||||
# state_map, so we can build up the state below.
|
||||
evs = yield self.store.get_events(
|
||||
evs = await self.store.get_events(
|
||||
list(state_map.values()),
|
||||
get_prev_content=False,
|
||||
redact_behaviour=EventRedactBehaviour.AS_IS,
|
||||
|
@ -446,12 +443,11 @@ class FederationHandler(BaseHandler):
|
|||
affected=event_id,
|
||||
)
|
||||
|
||||
yield self._process_received_pdu(
|
||||
await self._process_received_pdu(
|
||||
origin, pdu, state=state, auth_chain=auth_chain
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
|
||||
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
|
||||
"""
|
||||
Args:
|
||||
origin (str): Origin of the pdu. Will be called to get the missing events
|
||||
|
@ -463,12 +459,12 @@ class FederationHandler(BaseHandler):
|
|||
room_id = pdu.room_id
|
||||
event_id = pdu.event_id
|
||||
|
||||
seen = yield self.store.have_seen_events(prevs)
|
||||
seen = await self.store.have_seen_events(prevs)
|
||||
|
||||
if not prevs - seen:
|
||||
return
|
||||
|
||||
latest = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
latest = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
# We add the prev events that we have seen to the latest
|
||||
# list to ensure the remote server doesn't give them to us
|
||||
|
@ -532,7 +528,7 @@ class FederationHandler(BaseHandler):
|
|||
# All that said: Let's try increasing the timout to 60s and see what happens.
|
||||
|
||||
try:
|
||||
missing_events = yield self.federation_client.get_missing_events(
|
||||
missing_events = await self.federation_client.get_missing_events(
|
||||
origin,
|
||||
room_id,
|
||||
earliest_events_ids=list(latest),
|
||||
|
@ -571,7 +567,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
with nested_logging_context(ev.event_id):
|
||||
try:
|
||||
yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
|
||||
await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
|
||||
except FederationError as e:
|
||||
if e.code == 403:
|
||||
logger.warning(
|
||||
|
@ -583,30 +579,30 @@ class FederationHandler(BaseHandler):
|
|||
else:
|
||||
raise
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _get_state_for_room(self, destination, room_id, event_id):
|
||||
async def _get_state_for_room(
|
||||
self, destination: str, room_id: str, event_id: str
|
||||
) -> Tuple[List[EventBase], List[EventBase]]:
|
||||
"""Requests all of the room state at a given event from a remote homeserver.
|
||||
|
||||
Args:
|
||||
destination (str): The remote homeserver to query for the state.
|
||||
room_id (str): The id of the room we're interested in.
|
||||
event_id (str): The id of the event we want the state at.
|
||||
destination:: The remote homeserver to query for the state.
|
||||
room_id: The id of the room we're interested in.
|
||||
event_id: The id of the event we want the state at.
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[List[EventBase], List[EventBase]]]:
|
||||
A list of events in the state, and a list of events in the auth chain
|
||||
for the given event.
|
||||
"""
|
||||
(
|
||||
state_event_ids,
|
||||
auth_event_ids,
|
||||
) = yield self.federation_client.get_room_state_ids(
|
||||
) = await self.federation_client.get_room_state_ids(
|
||||
destination, room_id, event_id=event_id
|
||||
)
|
||||
|
||||
desired_events = set(state_event_ids + auth_event_ids)
|
||||
event_map = yield self._get_events_from_store_or_dest(
|
||||
event_map = await self._get_events_from_store_or_dest(
|
||||
destination, room_id, desired_events
|
||||
)
|
||||
|
||||
|
@ -625,20 +621,20 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
return pdus, auth_chain
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
|
||||
async def _get_events_from_store_or_dest(
|
||||
self, destination: str, room_id: str, event_ids: Iterable[str]
|
||||
) -> Dict[str, EventBase]:
|
||||
"""Fetch events from a remote destination, checking if we already have them.
|
||||
|
||||
Args:
|
||||
destination (str)
|
||||
room_id (str)
|
||||
event_ids (Iterable[str])
|
||||
destination
|
||||
room_id
|
||||
event_ids
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, EventBase]]: A deferred resolving to a map
|
||||
from event_id to event
|
||||
map from event_id to event
|
||||
"""
|
||||
fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
|
||||
fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
|
||||
|
||||
missing_events = set(event_ids) - fetched_events.keys()
|
||||
|
||||
|
@ -651,7 +647,7 @@ class FederationHandler(BaseHandler):
|
|||
event_ids,
|
||||
)
|
||||
|
||||
room_version = yield self.store.get_room_version(room_id)
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
# XXX 20 requests at once? really?
|
||||
for batch in batch_iter(missing_events, 20):
|
||||
|
@ -665,7 +661,7 @@ class FederationHandler(BaseHandler):
|
|||
for e_id in batch
|
||||
]
|
||||
|
||||
res = yield make_deferred_yieldable(
|
||||
res = await make_deferred_yieldable(
|
||||
defer.DeferredList(deferreds, consumeErrors=True)
|
||||
)
|
||||
for success, result in res:
|
||||
|
@ -674,8 +670,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
return fetched_events
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _process_received_pdu(self, origin, event, state, auth_chain):
|
||||
async def _process_received_pdu(self, origin, event, state, auth_chain):
|
||||
""" Called when we have a new pdu. We need to do auth checks and put it
|
||||
through the StateHandler.
|
||||
"""
|
||||
|
@ -690,7 +685,7 @@ class FederationHandler(BaseHandler):
|
|||
if auth_chain:
|
||||
event_ids |= {e.event_id for e in auth_chain}
|
||||
|
||||
seen_ids = yield self.store.have_seen_events(event_ids)
|
||||
seen_ids = await self.store.have_seen_events(event_ids)
|
||||
|
||||
if state and auth_chain is not None:
|
||||
# If we have any state or auth_chain given to us by the replication
|
||||
|
@ -717,18 +712,18 @@ class FederationHandler(BaseHandler):
|
|||
event_id,
|
||||
[e.event.event_id for e in event_infos],
|
||||
)
|
||||
yield self._handle_new_events(origin, event_infos)
|
||||
await self._handle_new_events(origin, event_infos)
|
||||
|
||||
try:
|
||||
context = yield self._handle_new_event(origin, event, state=state)
|
||||
context = await self._handle_new_event(origin, event, state=state)
|
||||
except AuthError as e:
|
||||
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
|
||||
|
||||
room = yield self.store.get_room(room_id)
|
||||
room = await self.store.get_room(room_id)
|
||||
|
||||
if not room:
|
||||
try:
|
||||
yield self.store.store_room(
|
||||
await self.store.store_room(
|
||||
room_id=room_id, room_creator_user_id="", is_public=False
|
||||
)
|
||||
except StoreError:
|
||||
|
@ -741,11 +736,11 @@ class FederationHandler(BaseHandler):
|
|||
# changing their profile info.
|
||||
newly_joined = True
|
||||
|
||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||
prev_state_ids = await context.get_prev_state_ids(self.store)
|
||||
|
||||
prev_state_id = prev_state_ids.get((event.type, event.state_key))
|
||||
if prev_state_id:
|
||||
prev_state = yield self.store.get_event(
|
||||
prev_state = await self.store.get_event(
|
||||
prev_state_id, allow_none=True
|
||||
)
|
||||
if prev_state and prev_state.membership == Membership.JOIN:
|
||||
|
@ -753,11 +748,10 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
if newly_joined:
|
||||
user = UserID.from_string(event.state_key)
|
||||
yield self.user_joined_room(user, room_id)
|
||||
await self.user_joined_room(user, room_id)
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def backfill(self, dest, room_id, limit, extremities):
|
||||
async def backfill(self, dest, room_id, limit, extremities):
|
||||
""" Trigger a backfill request to `dest` for the given `room_id`
|
||||
|
||||
This will attempt to get more events from the remote. If the other side
|
||||
|
@ -774,9 +768,9 @@ class FederationHandler(BaseHandler):
|
|||
if dest == self.server_name:
|
||||
raise SynapseError(400, "Can't backfill from self.")
|
||||
|
||||
room_version = yield self.store.get_room_version(room_id)
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
events = yield self.federation_client.backfill(
|
||||
events = await self.federation_client.backfill(
|
||||
dest, room_id, limit=limit, extremities=extremities
|
||||
)
|
||||
|
||||
|
@ -791,7 +785,7 @@ class FederationHandler(BaseHandler):
|
|||
# self._sanity_check_event(ev)
|
||||
|
||||
# Don't bother processing events we already have.
|
||||
seen_events = yield self.store.have_events_in_timeline(
|
||||
seen_events = await self.store.have_events_in_timeline(
|
||||
set(e.event_id for e in events)
|
||||
)
|
||||
|
||||
|
@ -814,7 +808,7 @@ class FederationHandler(BaseHandler):
|
|||
state_events = {}
|
||||
events_to_state = {}
|
||||
for e_id in edges:
|
||||
state, auth = yield self._get_state_for_room(
|
||||
state, auth = await self._get_state_for_room(
|
||||
destination=dest, room_id=room_id, event_id=e_id
|
||||
)
|
||||
auth_events.update({a.event_id: a for a in auth})
|
||||
|
@ -839,7 +833,7 @@ class FederationHandler(BaseHandler):
|
|||
# We repeatedly do this until we stop finding new auth events.
|
||||
while missing_auth - failed_to_fetch:
|
||||
logger.info("Missing auth for backfill: %r", missing_auth)
|
||||
ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
|
||||
ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
|
||||
auth_events.update(ret_events)
|
||||
|
||||
required_auth.update(
|
||||
|
@ -853,7 +847,7 @@ class FederationHandler(BaseHandler):
|
|||
missing_auth - failed_to_fetch,
|
||||
)
|
||||
|
||||
results = yield make_deferred_yieldable(
|
||||
results = await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(
|
||||
|
@ -880,7 +874,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
failed_to_fetch = missing_auth - set(auth_events)
|
||||
|
||||
seen_events = yield self.store.have_seen_events(
|
||||
seen_events = await self.store.have_seen_events(
|
||||
set(auth_events.keys()) | set(state_events.keys())
|
||||
)
|
||||
|
||||
|
@ -942,7 +936,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
)
|
||||
|
||||
yield self._handle_new_events(dest, ev_infos, backfilled=True)
|
||||
await self._handle_new_events(dest, ev_infos, backfilled=True)
|
||||
|
||||
# Step 2: Persist the rest of the events in the chunk one by one
|
||||
events.sort(key=lambda e: e.depth)
|
||||
|
@ -958,16 +952,15 @@ class FederationHandler(BaseHandler):
|
|||
# We store these one at a time since each event depends on the
|
||||
# previous to work out the state.
|
||||
# TODO: We can probably do something more clever here.
|
||||
yield self._handle_new_event(dest, event, backfilled=True)
|
||||
await self._handle_new_event(dest, event, backfilled=True)
|
||||
|
||||
return events
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def maybe_backfill(self, room_id, current_depth):
|
||||
async def maybe_backfill(self, room_id, current_depth):
|
||||
"""Checks the database to see if we should backfill before paginating,
|
||||
and if so do.
|
||||
"""
|
||||
extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
|
||||
extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
|
||||
|
||||
if not extremities:
|
||||
logger.debug("Not backfilling as no extremeties found.")
|
||||
|
@ -999,9 +992,9 @@ class FederationHandler(BaseHandler):
|
|||
# state *before* the event, ignoring the special casing certain event
|
||||
# types have.
|
||||
|
||||
forward_events = yield self.store.get_successor_events(list(extremities))
|
||||
forward_events = await self.store.get_successor_events(list(extremities))
|
||||
|
||||
extremities_events = yield self.store.get_events(
|
||||
extremities_events = await self.store.get_events(
|
||||
forward_events,
|
||||
redact_behaviour=EventRedactBehaviour.AS_IS,
|
||||
get_prev_content=False,
|
||||
|
@ -1009,7 +1002,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# We set `check_history_visibility_only` as we might otherwise get false
|
||||
# positives from users having been erased.
|
||||
filtered_extremities = yield filter_events_for_server(
|
||||
filtered_extremities = await filter_events_for_server(
|
||||
self.storage,
|
||||
self.server_name,
|
||||
list(extremities_events.values()),
|
||||
|
@ -1039,7 +1032,7 @@ class FederationHandler(BaseHandler):
|
|||
# First we try hosts that are already in the room
|
||||
# TODO: HEURISTIC ALERT.
|
||||
|
||||
curr_state = yield self.state_handler.get_current_state(room_id)
|
||||
curr_state = await self.state_handler.get_current_state(room_id)
|
||||
|
||||
def get_domains_from_state(state):
|
||||
"""Get joined domains from state
|
||||
|
@ -1078,12 +1071,11 @@ class FederationHandler(BaseHandler):
|
|||
domain for domain, depth in curr_domains if domain != self.server_name
|
||||
]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def try_backfill(domains):
|
||||
async def try_backfill(domains):
|
||||
# TODO: Should we try multiple of these at a time?
|
||||
for dom in domains:
|
||||
try:
|
||||
yield self.backfill(
|
||||
await self.backfill(
|
||||
dom, room_id, limit=100, extremities=extremities
|
||||
)
|
||||
# If this succeeded then we probably already have the
|
||||
|
@ -1114,7 +1106,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
return False
|
||||
|
||||
success = yield try_backfill(likely_domains)
|
||||
success = await try_backfill(likely_domains)
|
||||
if success:
|
||||
return True
|
||||
|
||||
|
@ -1128,7 +1120,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
||||
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
|
||||
states = yield make_deferred_yieldable(
|
||||
states = await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
|
||||
)
|
||||
|
@ -1138,7 +1130,7 @@ class FederationHandler(BaseHandler):
|
|||
# event_ids.
|
||||
states = dict(zip(event_ids, [s.state for s in states]))
|
||||
|
||||
state_map = yield self.store.get_events(
|
||||
state_map = await self.store.get_events(
|
||||
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
@ -1154,7 +1146,7 @@ class FederationHandler(BaseHandler):
|
|||
for e_id, _ in sorted_extremeties_tuple:
|
||||
likely_domains = get_domains_from_state(states[e_id])
|
||||
|
||||
success = yield try_backfill(
|
||||
success = await try_backfill(
|
||||
[dom for dom, _ in likely_domains if dom not in tried_domains]
|
||||
)
|
||||
if success:
|
||||
|
@ -1331,8 +1323,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_queued_pdus(self, room_queue):
|
||||
async def _handle_queued_pdus(self, room_queue):
|
||||
"""Process PDUs which got queued up while we were busy send_joining.
|
||||
|
||||
Args:
|
||||
|
@ -1348,7 +1339,7 @@ class FederationHandler(BaseHandler):
|
|||
p.room_id,
|
||||
)
|
||||
with nested_logging_context(p.event_id):
|
||||
yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
|
||||
await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
|
||||
|
@ -2907,7 +2898,7 @@ class FederationHandler(BaseHandler):
|
|||
room_id=room_id, user_id=user.to_string(), change="joined"
|
||||
)
|
||||
else:
|
||||
return user_joined_room(self.distributor, user, room_id)
|
||||
return defer.succeed(user_joined_room(self.distributor, user, room_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_complexity(self, remote_room_hosts, room_id):
|
||||
|
|
|
@ -280,8 +280,7 @@ class PaginationHandler(object):
|
|||
|
||||
await self.storage.purge_events.purge_room(room_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_messages(
|
||||
async def get_messages(
|
||||
self,
|
||||
requester,
|
||||
room_id=None,
|
||||
|
@ -307,7 +306,7 @@ class PaginationHandler(object):
|
|||
room_token = pagin_config.from_token.room_key
|
||||
else:
|
||||
pagin_config.from_token = (
|
||||
yield self.hs.get_event_sources().get_current_token_for_pagination()
|
||||
await self.hs.get_event_sources().get_current_token_for_pagination()
|
||||
)
|
||||
room_token = pagin_config.from_token.room_key
|
||||
|
||||
|
@ -319,11 +318,11 @@ class PaginationHandler(object):
|
|||
|
||||
source_config = pagin_config.get_source_config("room")
|
||||
|
||||
with (yield self.pagination_lock.read(room_id)):
|
||||
with (await self.pagination_lock.read(room_id)):
|
||||
(
|
||||
membership,
|
||||
member_event_id,
|
||||
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
|
||||
) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
|
||||
|
||||
if source_config.direction == "b":
|
||||
# if we're going backwards, we might need to backfill. This
|
||||
|
@ -331,7 +330,7 @@ class PaginationHandler(object):
|
|||
if room_token.topological:
|
||||
max_topo = room_token.topological
|
||||
else:
|
||||
max_topo = yield self.store.get_max_topological_token(
|
||||
max_topo = await self.store.get_max_topological_token(
|
||||
room_id, room_token.stream
|
||||
)
|
||||
|
||||
|
@ -339,18 +338,18 @@ class PaginationHandler(object):
|
|||
# If they have left the room then clamp the token to be before
|
||||
# they left the room, to save the effort of loading from the
|
||||
# database.
|
||||
leave_token = yield self.store.get_topological_token_for_event(
|
||||
leave_token = await self.store.get_topological_token_for_event(
|
||||
member_event_id
|
||||
)
|
||||
leave_token = RoomStreamToken.parse(leave_token)
|
||||
if leave_token.topological < max_topo:
|
||||
source_config.from_key = str(leave_token)
|
||||
|
||||
yield self.hs.get_handlers().federation_handler.maybe_backfill(
|
||||
await self.hs.get_handlers().federation_handler.maybe_backfill(
|
||||
room_id, max_topo
|
||||
)
|
||||
|
||||
events, next_key = yield self.store.paginate_room_events(
|
||||
events, next_key = await self.store.paginate_room_events(
|
||||
room_id=room_id,
|
||||
from_key=source_config.from_key,
|
||||
to_key=source_config.to_key,
|
||||
|
@ -365,7 +364,7 @@ class PaginationHandler(object):
|
|||
if event_filter:
|
||||
events = event_filter.filter(events)
|
||||
|
||||
events = yield filter_events_for_client(
|
||||
events = await filter_events_for_client(
|
||||
self.storage, user_id, events, is_peeking=(member_event_id is None)
|
||||
)
|
||||
|
||||
|
@ -385,19 +384,19 @@ class PaginationHandler(object):
|
|||
(EventTypes.Member, event.sender) for event in events
|
||||
)
|
||||
|
||||
state_ids = yield self.state_store.get_state_ids_for_event(
|
||||
state_ids = await self.state_store.get_state_ids_for_event(
|
||||
events[0].event_id, state_filter=state_filter
|
||||
)
|
||||
|
||||
if state_ids:
|
||||
state = yield self.store.get_events(list(state_ids.values()))
|
||||
state = await self.store.get_events(list(state_ids.values()))
|
||||
state = state.values()
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
chunk = {
|
||||
"chunk": (
|
||||
yield self._event_serializer.serialize_events(
|
||||
await self._event_serializer.serialize_events(
|
||||
events, time_now, as_client_event=as_client_event
|
||||
)
|
||||
),
|
||||
|
@ -406,7 +405,7 @@ class PaginationHandler(object):
|
|||
}
|
||||
|
||||
if state:
|
||||
chunk["state"] = yield self._event_serializer.serialize_events(
|
||||
chunk["state"] = await self._event_serializer.serialize_events(
|
||||
state, time_now, as_client_event=as_client_event
|
||||
)
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from mock import Mock
|
||||
|
||||
from twisted.internet.defer import maybeDeferred, succeed
|
||||
from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
|
||||
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.logging.context import LoggingContext
|
||||
|
@ -70,9 +70,11 @@ class MessageAcceptTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
# Send the join, it should return None (which is not an error)
|
||||
d = self.handler.on_receive_pdu(
|
||||
d = ensureDeferred(
|
||||
self.handler.on_receive_pdu(
|
||||
"test.serv", join_event, sent_to_us_directly=True
|
||||
)
|
||||
)
|
||||
self.reactor.advance(1)
|
||||
self.assertEqual(self.successResultOf(d), None)
|
||||
|
||||
|
@ -119,9 +121,11 @@ class MessageAcceptTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
with LoggingContext(request="lying_event"):
|
||||
d = self.handler.on_receive_pdu(
|
||||
d = ensureDeferred(
|
||||
self.handler.on_receive_pdu(
|
||||
"test.serv", lying_event, sent_to_us_directly=True
|
||||
)
|
||||
)
|
||||
|
||||
# Step the reactor, so the database fetches come back
|
||||
self.reactor.advance(1)
|
||||
|
|
Loading…
Reference in New Issue