Port synapse.handlers.initial_sync to async/await

This commit is contained in:
Erik Johnston 2019-12-09 13:46:45 +00:00
parent f166a8d1f5
commit a1f8ea9051
1 changed files with 44 additions and 52 deletions

View File

@ -89,8 +89,7 @@ class InitialSyncHandler(BaseHandler):
include_archived, include_archived,
) )
@defer.inlineCallbacks async def _snapshot_all_rooms(
def _snapshot_all_rooms(
self, self,
user_id=None, user_id=None,
pagin_config=None, pagin_config=None,
@ -102,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
if include_archived: if include_archived:
memberships.append(Membership.LEAVE) memberships.append(Membership.LEAVE)
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = await self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=memberships user_id=user_id, membership_list=memberships
) )
@ -110,33 +109,32 @@ class InitialSyncHandler(BaseHandler):
rooms_ret = [] rooms_ret = []
now_token = yield self.hs.get_event_sources().get_current_token() now_token = await self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"] presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token) pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows( presence, _ = await presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None user, pagination_config.get_source_config("presence"), None
) )
receipt_stream = self.hs.get_event_sources().sources["receipt"] receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows( receipt, _ = await receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None user, pagination_config.get_source_config("receipt"), None
) )
tags_by_room = yield self.store.get_tags_for_user(user_id) tags_by_room = await self.store.get_tags_for_user(user_id)
account_data, account_data_by_room = yield self.store.get_account_data_for_user( account_data, account_data_by_room = await self.store.get_account_data_for_user(
user_id user_id
) )
public_room_ids = yield self.store.get_public_room_ids() public_room_ids = await self.store.get_public_room_ids()
limit = pagin_config.limit limit = pagin_config.limit
if limit is None: if limit is None:
limit = 10 limit = 10
@defer.inlineCallbacks async def handle_room(event):
def handle_room(event):
d = { d = {
"room_id": event.room_id, "room_id": event.room_id,
"membership": event.membership, "membership": event.membership,
@ -149,8 +147,8 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
d["inviter"] = event.sender d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id) invite_event = await self.store.get_event(event.event_id)
d["invite"] = yield self._event_serializer.serialize_event( d["invite"] = await self._event_serializer.serialize_event(
invite_event, time_now, as_client_event invite_event, time_now, as_client_event
) )
@ -174,7 +172,7 @@ class InitialSyncHandler(BaseHandler):
lambda states: states[event.event_id] lambda states: states[event.event_id]
) )
(messages, token), current_state = yield make_deferred_yieldable( (messages, token), current_state = await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
run_in_background( run_in_background(
@ -188,7 +186,7 @@ class InitialSyncHandler(BaseHandler):
) )
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield filter_events_for_client( messages = await filter_events_for_client(
self.storage, user_id, messages self.storage, user_id, messages
) )
@ -198,7 +196,7 @@ class InitialSyncHandler(BaseHandler):
d["messages"] = { d["messages"] = {
"chunk": ( "chunk": (
yield self._event_serializer.serialize_events( await self._event_serializer.serialize_events(
messages, time_now=time_now, as_client_event=as_client_event messages, time_now=time_now, as_client_event=as_client_event
) )
), ),
@ -206,7 +204,7 @@ class InitialSyncHandler(BaseHandler):
"end": end_token.to_string(), "end": end_token.to_string(),
} }
d["state"] = yield self._event_serializer.serialize_events( d["state"] = await self._event_serializer.serialize_events(
current_state.values(), current_state.values(),
time_now=time_now, time_now=time_now,
as_client_event=as_client_event, as_client_event=as_client_event,
@ -229,7 +227,7 @@ class InitialSyncHandler(BaseHandler):
except Exception: except Exception:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
yield concurrently_execute(handle_room, room_list, 10) await concurrently_execute(handle_room, room_list, 10)
account_data_events = [] account_data_events = []
for account_data_type, content in account_data.items(): for account_data_type, content in account_data.items():
@ -253,8 +251,7 @@ class InitialSyncHandler(BaseHandler):
return ret return ret
@defer.inlineCallbacks async def room_initial_sync(self, requester, room_id, pagin_config=None):
def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of """Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left. the room this will be what was in the room when they left.
@ -271,32 +268,32 @@ class InitialSyncHandler(BaseHandler):
A JSON serialisable dict with the snapshot of the room. A JSON serialisable dict with the snapshot of the room.
""" """
blocked = yield self.store.is_room_blocked(room_id) blocked = await self.store.is_room_blocked(room_id)
if blocked: if blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
user_id = requester.user.to_string() user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable( membership, member_event_id = await self._check_in_room_or_world_readable(
room_id, user_id room_id, user_id
) )
is_peeking = member_event_id is None is_peeking = member_event_id is None
if membership == Membership.JOIN: if membership == Membership.JOIN:
result = yield self._room_initial_sync_joined( result = await self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking user_id, room_id, pagin_config, membership, is_peeking
) )
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted( result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking user_id, room_id, pagin_config, membership, member_event_id, is_peeking
) )
account_data_events = [] account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id) tags = await self.store.get_tags_for_room(user_id, room_id)
if tags: if tags:
account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
account_data = yield self.store.get_account_data_for_room(user_id, room_id) account_data = await self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items(): for account_data_type, content in account_data.items():
account_data_events.append({"type": account_data_type, "content": content}) account_data_events.append({"type": account_data_type, "content": content})
@ -304,11 +301,10 @@ class InitialSyncHandler(BaseHandler):
return result return result
@defer.inlineCallbacks async def _room_initial_sync_parted(
def _room_initial_sync_parted(
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
): ):
room_state = yield self.state_store.get_state_for_events([member_event_id]) room_state = await self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id] room_state = room_state[member_event_id]
@ -316,13 +312,13 @@ class InitialSyncHandler(BaseHandler):
if limit is None: if limit is None:
limit = 10 limit = 10
stream_token = yield self.store.get_stream_token_for_event(member_event_id) stream_token = await self.store.get_stream_token_for_event(member_event_id)
messages, token = yield self.store.get_recent_events_for_room( messages, token = await self.store.get_recent_events_for_room(
room_id, limit=limit, end_token=stream_token room_id, limit=limit, end_token=stream_token
) )
messages = yield filter_events_for_client( messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking self.storage, user_id, messages, is_peeking=is_peeking
) )
@ -336,13 +332,13 @@ class InitialSyncHandler(BaseHandler):
"room_id": room_id, "room_id": room_id,
"messages": { "messages": {
"chunk": ( "chunk": (
yield self._event_serializer.serialize_events(messages, time_now) await self._event_serializer.serialize_events(messages, time_now)
), ),
"start": start_token.to_string(), "start": start_token.to_string(),
"end": end_token.to_string(), "end": end_token.to_string(),
}, },
"state": ( "state": (
yield self._event_serializer.serialize_events( await self._event_serializer.serialize_events(
room_state.values(), time_now room_state.values(), time_now
) )
), ),
@ -350,19 +346,18 @@ class InitialSyncHandler(BaseHandler):
"receipts": [], "receipts": [],
} }
@defer.inlineCallbacks async def _room_initial_sync_joined(
def _room_initial_sync_joined(
self, user_id, room_id, pagin_config, membership, is_peeking self, user_id, room_id, pagin_config, membership, is_peeking
): ):
current_state = yield self.state.get_current_state(room_id=room_id) current_state = await self.state.get_current_state(room_id=room_id)
# TODO: These concurrently # TODO: These concurrently
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
state = yield self._event_serializer.serialize_events( state = await self._event_serializer.serialize_events(
current_state.values(), time_now current_state.values(), time_now
) )
now_token = yield self.hs.get_event_sources().get_current_token() now_token = await self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None limit = pagin_config.limit if pagin_config else None
if limit is None: if limit is None:
@ -377,28 +372,26 @@ class InitialSyncHandler(BaseHandler):
presence_handler = self.hs.get_presence_handler() presence_handler = self.hs.get_presence_handler()
@defer.inlineCallbacks async def get_presence():
def get_presence():
# If presence is disabled, return an empty list # If presence is disabled, return an empty list
if not self.hs.config.use_presence: if not self.hs.config.use_presence:
return [] return []
states = yield presence_handler.get_states( states = await presence_handler.get_states(
[m.user_id for m in room_members], as_event=True [m.user_id for m in room_members], as_event=True
) )
return states return states
@defer.inlineCallbacks async def get_receipts():
def get_receipts(): receipts = await self.store.get_linearized_receipts_for_room(
receipts = yield self.store.get_linearized_receipts_for_room(
room_id, to_key=now_token.receipt_key room_id, to_key=now_token.receipt_key
) )
if not receipts: if not receipts:
receipts = [] receipts = []
return receipts return receipts
presence, receipts, (messages, token) = yield make_deferred_yieldable( presence, receipts, (messages, token) = await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
run_in_background(get_presence), run_in_background(get_presence),
@ -414,7 +407,7 @@ class InitialSyncHandler(BaseHandler):
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
) )
messages = yield filter_events_for_client( messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking self.storage, user_id, messages, is_peeking=is_peeking
) )
@ -427,7 +420,7 @@ class InitialSyncHandler(BaseHandler):
"room_id": room_id, "room_id": room_id,
"messages": { "messages": {
"chunk": ( "chunk": (
yield self._event_serializer.serialize_events(messages, time_now) await self._event_serializer.serialize_events(messages, time_now)
), ),
"start": start_token.to_string(), "start": start_token.to_string(),
"end": end_token.to_string(), "end": end_token.to_string(),
@ -441,18 +434,17 @@ class InitialSyncHandler(BaseHandler):
return ret return ret
@defer.inlineCallbacks async def _check_in_room_or_world_readable(self, room_id, user_id):
def _check_in_room_or_world_readable(self, room_id, user_id):
try: try:
# check_user_was_in_room will return the most recent membership # check_user_was_in_room will return the most recent membership
# event for the user if: # event for the user if:
# * The user is a non-guest user, and was ever in the room # * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room # * The user is a guest user, and has joined the room
# else it will throw. # else it will throw.
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) member_event = await self.auth.check_user_was_in_room(room_id, user_id)
return member_event.membership, member_event.event_id return member_event.membership, member_event.event_id
except AuthError: except AuthError:
visibility = yield self.state_handler.get_current_state( visibility = await self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, "" room_id, EventTypes.RoomHistoryVisibility, ""
) )
if ( if (