Convert the registration handler to async/await. (#7649)

This commit is contained in:
Patrick Cloke 2020-06-08 11:15:02 -04:00 committed by GitHub
parent 375ca0cceb
commit 3c45a78090
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 68 deletions

1
changelog.d/7649.misc Normal file
View File

@ -0,0 +1 @@
Convert registration handler to async/await.

View File

@ -16,8 +16,6 @@
"""Contains functions for registering clients.""" """Contains functions for registering clients."""
import logging import logging
from twisted.internet import defer
from synapse import types from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, LoginType from synapse.api.constants import MAX_USERID_LENGTH, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
@ -75,8 +73,9 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime self.session_lifetime = hs.config.session_lifetime
@defer.inlineCallbacks async def check_username(
def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): self, localpart, guest_access_token=None, assigned_user_id=None
):
if types.contains_invalid_mxid_characters(localpart): if types.contains_invalid_mxid_characters(localpart):
raise SynapseError( raise SynapseError(
400, 400,
@ -113,13 +112,13 @@ class RegistrationHandler(BaseHandler):
Codes.INVALID_USERNAME, Codes.INVALID_USERNAME,
) )
users = yield self.store.get_users_by_id_case_insensitive(user_id) users = await self.store.get_users_by_id_case_insensitive(user_id)
if users: if users:
if not guest_access_token: if not guest_access_token:
raise SynapseError( raise SynapseError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE 400, "User ID already taken.", errcode=Codes.USER_IN_USE
) )
user_data = yield self.auth.get_user_by_access_token(guest_access_token) user_data = await self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart: if not user_data["is_guest"] or user_data["user"].localpart != localpart:
raise AuthError( raise AuthError(
403, 403,
@ -137,8 +136,7 @@ class RegistrationHandler(BaseHandler):
except ValueError: except ValueError:
pass pass
@defer.inlineCallbacks async def register_user(
def register_user(
self, self,
localpart=None, localpart=None,
password_hash=None, password_hash=None,
@ -169,18 +167,18 @@ class RegistrationHandler(BaseHandler):
by_admin (bool): True if this registration is being made via the by_admin (bool): True if this registration is being made via the
admin api, otherwise False. admin api, otherwise False.
Returns: Returns:
Deferred[str]: user_id str: user_id
Raises: Raises:
SynapseError if there was a problem registering. SynapseError if there was a problem registering.
""" """
yield self.check_registration_ratelimit(address) self.check_registration_ratelimit(address)
# do not check_auth_blocking if the call is coming through the Admin API # do not check_auth_blocking if the call is coming through the Admin API
if not by_admin: if not by_admin:
yield self.auth.check_auth_blocking(threepid=threepid) await self.auth.check_auth_blocking(threepid=threepid)
if localpart is not None: if localpart is not None:
yield self.check_username(localpart, guest_access_token=guest_access_token) await self.check_username(localpart, guest_access_token=guest_access_token)
was_guest = guest_access_token is not None was_guest = guest_access_token is not None
@ -194,7 +192,7 @@ class RegistrationHandler(BaseHandler):
elif default_display_name is None: elif default_display_name is None:
default_display_name = localpart default_display_name = localpart
yield self.register_with_store( await self.register_with_store(
user_id=user_id, user_id=user_id,
password_hash=password_hash, password_hash=password_hash,
was_guest=was_guest, was_guest=was_guest,
@ -206,11 +204,9 @@ class RegistrationHandler(BaseHandler):
) )
if self.hs.config.user_directory_search_all_users: if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(localpart) profile = await self.store.get_profileinfo(localpart)
yield defer.ensureDeferred( await self.user_directory_handler.handle_local_profile_change(
self.user_directory_handler.handle_local_profile_change( user_id, profile
user_id, profile
)
) )
else: else:
@ -222,14 +218,14 @@ class RegistrationHandler(BaseHandler):
if fail_count > 10: if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID") raise SynapseError(500, "Unable to find a suitable guest user ID")
localpart = yield self._generate_user_id() localpart = await self._generate_user_id()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id) self.check_user_id_not_appservice_exclusive(user_id)
if default_display_name is None: if default_display_name is None:
default_display_name = localpart default_display_name = localpart
try: try:
yield self.register_with_store( await self.register_with_store(
user_id=user_id, user_id=user_id,
password_hash=password_hash, password_hash=password_hash,
make_guest=make_guest, make_guest=make_guest,
@ -252,7 +248,7 @@ class RegistrationHandler(BaseHandler):
user_id, user_id,
) )
else: else:
yield defer.ensureDeferred(self._auto_join_rooms(user_id)) await self._auto_join_rooms(user_id)
else: else:
logger.info( logger.info(
"Skipping auto-join for %s because consent is required at registration", "Skipping auto-join for %s because consent is required at registration",
@ -270,7 +266,7 @@ class RegistrationHandler(BaseHandler):
} }
# Bind email to new account # Bind email to new account
yield self._register_email_threepid(user_id, threepid_dict, None) await self._register_email_threepid(user_id, threepid_dict, None)
return user_id return user_id
@ -335,8 +331,7 @@ class RegistrationHandler(BaseHandler):
""" """
await self._auto_join_rooms(user_id) await self._auto_join_rooms(user_id)
@defer.inlineCallbacks async def appservice_register(self, user_localpart, as_token):
def appservice_register(self, user_localpart, as_token):
user = UserID(user_localpart, self.hs.hostname) user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token) service = self.store.get_app_service_by_token(as_token)
@ -351,11 +346,9 @@ class RegistrationHandler(BaseHandler):
service_id = service.id if service.is_exclusive_user(user_id) else None service_id = service.id if service.is_exclusive_user(user_id) else None
yield self.check_user_id_not_appservice_exclusive( self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service)
user_id, allowed_appservice=service
)
yield self.register_with_store( await self.register_with_store(
user_id=user_id, user_id=user_id,
password_hash="", password_hash="",
appservice_id=service_id, appservice_id=service_id,
@ -387,13 +380,12 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE, errcode=Codes.EXCLUSIVE,
) )
@defer.inlineCallbacks async def _generate_user_id(self):
def _generate_user_id(self):
if self._next_generated_user_id is None: if self._next_generated_user_id is None:
with (yield self._generate_user_id_linearizer.queue(())): with await self._generate_user_id_linearizer.queue(()):
if self._next_generated_user_id is None: if self._next_generated_user_id is None:
self._next_generated_user_id = ( self._next_generated_user_id = (
yield self.store.find_next_generated_user_id_localpart() await self.store.find_next_generated_user_id_localpart()
) )
id = self._next_generated_user_id id = self._next_generated_user_id
@ -496,8 +488,9 @@ class RegistrationHandler(BaseHandler):
user_type=user_type, user_type=user_type,
) )
@defer.inlineCallbacks async def register_device(
def register_device(self, user_id, device_id, initial_display_name, is_guest=False): self, user_id, device_id, initial_display_name, is_guest=False
):
"""Register a device for a user and generate an access token. """Register a device for a user and generate an access token.
The access token will be limited by the homeserver's session_lifetime config. The access token will be limited by the homeserver's session_lifetime config.
@ -511,11 +504,11 @@ class RegistrationHandler(BaseHandler):
is_guest (bool): Whether this is a guest account is_guest (bool): Whether this is a guest account
Returns: Returns:
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token tuple[str, str]: Tuple of device ID and access token
""" """
if self.hs.config.worker_app: if self.hs.config.worker_app:
r = yield self._register_device_client( r = await self._register_device_client(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,
initial_display_name=initial_display_name, initial_display_name=initial_display_name,
@ -531,7 +524,7 @@ class RegistrationHandler(BaseHandler):
) )
valid_until_ms = self.clock.time_msec() + self.session_lifetime valid_until_ms = self.clock.time_msec() + self.session_lifetime
device_id = yield self.device_handler.check_device_registered( device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )
if is_guest: if is_guest:
@ -540,10 +533,8 @@ class RegistrationHandler(BaseHandler):
user_id, ["guest = true"] user_id, ["guest = true"]
) )
else: else:
access_token = yield defer.ensureDeferred( access_token = await self._auth_handler.get_access_token_for_user_id(
self._auth_handler.get_access_token_for_user_id( user_id, device_id=device_id, valid_until_ms=valid_until_ms
user_id, device_id=device_id, valid_until_ms=valid_until_ms
)
) )
return (device_id, access_token) return (device_id, access_token)
@ -594,8 +585,7 @@ class RegistrationHandler(BaseHandler):
await self.store.user_set_consent_version(user_id, consent_version) await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id) await self.post_consent_actions(user_id)
@defer.inlineCallbacks async def _register_email_threepid(self, user_id, threepid, token):
def _register_email_threepid(self, user_id, threepid, token):
"""Add an email address as a 3pid identifier """Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the Also adds an email pusher for the email address, if configured in the
@ -608,8 +598,6 @@ class RegistrationHandler(BaseHandler):
threepid (object): m.login.email.identity auth response threepid (object): m.login.email.identity auth response
token (str|None): access_token for the user, or None if not logged token (str|None): access_token for the user, or None if not logged
in. in.
Returns:
defer.Deferred:
""" """
reqd = ("medium", "address", "validated_at") reqd = ("medium", "address", "validated_at")
if any(x not in threepid for x in reqd): if any(x not in threepid for x in reqd):
@ -617,13 +605,8 @@ class RegistrationHandler(BaseHandler):
logger.info("Can't add incomplete 3pid") logger.info("Can't add incomplete 3pid")
return return
yield defer.ensureDeferred( await self._auth_handler.add_threepid(
self._auth_handler.add_threepid( user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
)
) )
# And we add an email pusher for them by default, but only # And we add an email pusher for them by default, but only
@ -639,10 +622,10 @@ class RegistrationHandler(BaseHandler):
# It would really make more sense for this to be passed # It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an # up when the access token is saved, but that's quite an
# invasive change I'd rather do separately. # invasive change I'd rather do separately.
user_tuple = yield self.store.get_user_by_access_token(token) user_tuple = await self.store.get_user_by_access_token(token)
token_id = user_tuple["token_id"] token_id = user_tuple["token_id"]
yield self.pusher_pool.add_pusher( await self.pusher_pool.add_pusher(
user_id=user_id, user_id=user_id,
access_token=token_id, access_token=token_id,
kind="email", kind="email",
@ -654,8 +637,7 @@ class RegistrationHandler(BaseHandler):
data={}, data={},
) )
@defer.inlineCallbacks async def _register_msisdn_threepid(self, user_id, threepid):
def _register_msisdn_threepid(self, user_id, threepid):
"""Add a phone number as a 3pid identifier """Add a phone number as a 3pid identifier
Must be called on master. Must be called on master.
@ -663,8 +645,6 @@ class RegistrationHandler(BaseHandler):
Args: Args:
user_id (str): id of user user_id (str): id of user
threepid (object): m.login.msisdn auth response threepid (object): m.login.msisdn auth response
Returns:
defer.Deferred:
""" """
try: try:
assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
@ -675,11 +655,6 @@ class RegistrationHandler(BaseHandler):
return None return None
raise raise
yield defer.ensureDeferred( await self._auth_handler.add_threepid(
self._auth_handler.add_threepid( user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
)
) )

View File

@ -128,8 +128,12 @@ class ModuleApi(object):
Returns: Returns:
Deferred[str]: user_id Deferred[str]: user_id
""" """
return self._hs.get_registration_handler().register_user( return defer.ensureDeferred(
localpart=localpart, default_display_name=displayname, bind_emails=emails self._hs.get_registration_handler().register_user(
localpart=localpart,
default_display_name=displayname,
bind_emails=emails,
)
) )
def register_device(self, user_id, device_id=None, initial_display_name=None): def register_device(self, user_id, device_id=None, initial_display_name=None):