Make registration idempotent, part 2: be idempotent if the client specifies a username.
This commit is contained in:
parent
48b2e853a8
commit
a7daa5ae13
|
@ -160,6 +160,20 @@ class AuthHandler(BaseHandler):
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
def get_session_id(self, clientdict):
|
||||||
|
"""
|
||||||
|
Gets the session ID for a client given the client dictionary
|
||||||
|
:param clientdict: The dictionary sent by the client in the request
|
||||||
|
:return: The string session ID the client sent. If the client did not
|
||||||
|
send a session ID, returns None.
|
||||||
|
"""
|
||||||
|
sid = None
|
||||||
|
if clientdict and 'auth' in clientdict:
|
||||||
|
authdict = clientdict['auth']
|
||||||
|
if 'session' in authdict:
|
||||||
|
sid = authdict['session']
|
||||||
|
return sid
|
||||||
|
|
||||||
def set_session_data(self, session_id, key, value):
|
def set_session_data(self, session_id, key, value):
|
||||||
"""
|
"""
|
||||||
Store a key-value pair into the sessions data associated with this
|
Store a key-value pair into the sessions data associated with this
|
||||||
|
|
|
@ -47,7 +47,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
self._next_generated_user_id = None
|
self._next_generated_user_id = None
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_username(self, localpart, guest_access_token=None):
|
def check_username(self, localpart, guest_access_token=None,
|
||||||
|
assigned_user_id=None):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
if urllib.quote(localpart.encode('utf-8')) != localpart:
|
if urllib.quote(localpart.encode('utf-8')) != localpart:
|
||||||
|
@ -60,6 +61,15 @@ class RegistrationHandler(BaseHandler):
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
|
if assigned_user_id:
|
||||||
|
if user_id == assigned_user_id:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"A different user ID has already been registered for this session",
|
||||||
|
)
|
||||||
|
|
||||||
yield self.check_user_id_not_appservice_exclusive(user_id)
|
yield self.check_user_id_not_appservice_exclusive(user_id)
|
||||||
|
|
||||||
users = yield self.store.get_users_by_id_case_insensitive(user_id)
|
users = yield self.store.get_users_by_id_case_insensitive(user_id)
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
|
from synapse.types import UserID
|
||||||
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
|
|
||||||
|
@ -122,10 +123,25 @@ class RegisterRestServlet(RestServlet):
|
||||||
|
|
||||||
guest_access_token = body.get("guest_access_token", None)
|
guest_access_token = body.get("guest_access_token", None)
|
||||||
|
|
||||||
|
session_id = self.auth_handler.get_session_id(body)
|
||||||
|
logger.error("session id: %r", session_id)
|
||||||
|
registered_user_id = None
|
||||||
|
if session_id:
|
||||||
|
# if we get a registered user id out of here, it means we previously
|
||||||
|
# registered a user for this session, so we could just return the
|
||||||
|
# user here. We carry on and go through the auth checks though,
|
||||||
|
# for paranoia.
|
||||||
|
registered_user_id = self.auth_handler.get_session_data(
|
||||||
|
session_id, "registered_user_id", None
|
||||||
|
)
|
||||||
|
logger.error("already regged: %r", registered_user_id)
|
||||||
|
logger.error("check: %r", desired_username)
|
||||||
|
|
||||||
if desired_username is not None:
|
if desired_username is not None:
|
||||||
yield self.registration_handler.check_username(
|
yield self.registration_handler.check_username(
|
||||||
desired_username,
|
desired_username,
|
||||||
guest_access_token=guest_access_token
|
guest_access_token=guest_access_token,
|
||||||
|
assigned_user_id=registered_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
|
@ -147,10 +163,6 @@ class RegisterRestServlet(RestServlet):
|
||||||
defer.returnValue((401, result))
|
defer.returnValue((401, result))
|
||||||
return
|
return
|
||||||
|
|
||||||
# have we already registered a user for this session
|
|
||||||
registered_user_id = self.auth_handler.get_session_data(
|
|
||||||
session_id, "registered_user_id", None
|
|
||||||
)
|
|
||||||
if registered_user_id is not None:
|
if registered_user_id is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Already registered user ID %r for this session",
|
"Already registered user ID %r for this session",
|
||||||
|
|
Loading…
Reference in New Issue