Return user ID in use error straight away
This commit is contained in:
parent
766bd8e880
commit
ea1776f556
|
@ -201,6 +201,8 @@ class AuthHandler(BaseHandler):
|
||||||
logger.debug("Getting validated threepid. threepidcreds: %r" % (threepidCreds,))
|
logger.debug("Getting validated threepid. threepidcreds: %r" % (threepidCreds,))
|
||||||
threepid = yield identity_handler.threepid_from_creds(threepidCreds)
|
threepid = yield identity_handler.threepid_from_creds(threepidCreds)
|
||||||
|
|
||||||
|
threepid['threepidCreds'] = authdict['threepidCreds']
|
||||||
|
|
||||||
defer.returnValue(threepid)
|
defer.returnValue(threepid)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -63,4 +63,27 @@ class IdentityHandler(BaseHandler):
|
||||||
|
|
||||||
if 'medium' in data:
|
if 'medium' in data:
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def bind_threepid(self, creds, mxid):
|
||||||
|
yield run_on_reactor()
|
||||||
|
logger.debug("binding threepid %r to %s", creds, mxid)
|
||||||
|
http_client = SimpleHttpClient(self.hs)
|
||||||
|
data = None
|
||||||
|
try:
|
||||||
|
data = yield http_client.post_urlencoded_get_json(
|
||||||
|
# XXX: Change when ID servers are all HTTPS
|
||||||
|
"http://%s%s" % (
|
||||||
|
creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
|
||||||
|
),
|
||||||
|
{
|
||||||
|
'sid': creds['sid'],
|
||||||
|
'clientSecret': creds['clientSecret'],
|
||||||
|
'mxid': mxid,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
logger.debug("bound threepid %r to %s", creds, mxid)
|
||||||
|
except CodeMessageException as e:
|
||||||
|
data = json.loads(e.msg)
|
||||||
|
defer.returnValue(data)
|
|
@ -44,6 +44,36 @@ class RegistrationHandler(BaseHandler):
|
||||||
self.distributor = hs.get_distributor()
|
self.distributor = hs.get_distributor()
|
||||||
self.distributor.declare("registered_user")
|
self.distributor.declare("registered_user")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def check_username(self, localpart):
|
||||||
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
print "checking username %s" % (localpart)
|
||||||
|
|
||||||
|
if urllib.quote(localpart) != localpart:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"User ID must only contain characters which do not"
|
||||||
|
" require URL encoding."
|
||||||
|
)
|
||||||
|
|
||||||
|
user = UserID(localpart, self.hs.hostname)
|
||||||
|
user_id = user.to_string()
|
||||||
|
|
||||||
|
yield self.check_user_id_is_valid(user_id)
|
||||||
|
|
||||||
|
print "is valid"
|
||||||
|
|
||||||
|
u = yield self.store.get_user_by_id(user_id)
|
||||||
|
print "user is: "
|
||||||
|
print u
|
||||||
|
if u:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"User ID already taken.",
|
||||||
|
errcode=Codes.USER_IN_USE,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def register(self, localpart=None, password=None):
|
def register(self, localpart=None, password=None):
|
||||||
"""Registers a new client on the server.
|
"""Registers a new client on the server.
|
||||||
|
@ -64,18 +94,11 @@ class RegistrationHandler(BaseHandler):
|
||||||
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
|
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
|
||||||
|
|
||||||
if localpart:
|
if localpart:
|
||||||
if localpart and urllib.quote(localpart) != localpart:
|
self.check_username(localpart)
|
||||||
raise SynapseError(
|
|
||||||
400,
|
|
||||||
"User ID must only contain characters which do not"
|
|
||||||
" require URL encoding."
|
|
||||||
)
|
|
||||||
|
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
yield self.check_user_id_is_valid(user_id)
|
|
||||||
|
|
||||||
token = self._generate_token(user_id)
|
token = self._generate_token(user_id)
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -190,7 +213,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
logger.info("validating theeepidcred sid %s on id server %s",
|
logger.info("validating theeepidcred sid %s on id server %s",
|
||||||
c['sid'], c['idServer'])
|
c['sid'], c['idServer'])
|
||||||
try:
|
try:
|
||||||
threepid = yield self._threepid_from_creds(c)
|
identity_handler = self.hs.get_handlers().identity_handler
|
||||||
|
threepid = yield identity_handler.threepid_from_creds(c)
|
||||||
except:
|
except:
|
||||||
logger.exception("Couldn't validate 3pid")
|
logger.exception("Couldn't validate 3pid")
|
||||||
raise RegistrationError(400, "Couldn't validate 3pid")
|
raise RegistrationError(400, "Couldn't validate 3pid")
|
||||||
|
@ -202,12 +226,16 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def bind_emails(self, user_id, threepidCreds):
|
def bind_emails(self, user_id, threepidCreds):
|
||||||
"""Links emails with a user ID and informs an identity server."""
|
"""Links emails with a user ID and informs an identity server.
|
||||||
|
|
||||||
|
Used only by c/s api v1
|
||||||
|
"""
|
||||||
|
|
||||||
# Now we have a matrix ID, bind it to the threepids we were given
|
# Now we have a matrix ID, bind it to the threepids we were given
|
||||||
for c in threepidCreds:
|
for c in threepidCreds:
|
||||||
|
identity_handler = self.hs.get_handlers().identity_handler
|
||||||
# XXX: This should be a deferred list, shouldn't it?
|
# XXX: This should be a deferred list, shouldn't it?
|
||||||
yield self._bind_threepid(c, user_id)
|
yield identity_handler.bind_threepid(c, user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_user_id_is_valid(self, user_id):
|
def check_user_id_is_valid(self, user_id):
|
||||||
|
@ -234,58 +262,6 @@ class RegistrationHandler(BaseHandler):
|
||||||
def _generate_user_id(self):
|
def _generate_user_id(self):
|
||||||
return "-" + stringutils.random_string(18)
|
return "-" + stringutils.random_string(18)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _threepid_from_creds(self, creds):
|
|
||||||
# TODO: get this from the homeserver rather than creating a new one for
|
|
||||||
# each request
|
|
||||||
http_client = SimpleHttpClient(self.hs)
|
|
||||||
# XXX: make this configurable!
|
|
||||||
trustedIdServers = ['matrix.org:8090', 'matrix.org']
|
|
||||||
if not creds['idServer'] in trustedIdServers:
|
|
||||||
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
|
|
||||||
'credentials', creds['idServer'])
|
|
||||||
defer.returnValue(None)
|
|
||||||
|
|
||||||
data = {}
|
|
||||||
try:
|
|
||||||
data = yield http_client.get_json(
|
|
||||||
# XXX: This should be HTTPS
|
|
||||||
"http://%s%s" % (
|
|
||||||
creds['idServer'],
|
|
||||||
"/_matrix/identity/api/v1/3pid/getValidated3pid"
|
|
||||||
),
|
|
||||||
{'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
|
|
||||||
)
|
|
||||||
except CodeMessageException as e:
|
|
||||||
data = json.loads(e.msg)
|
|
||||||
|
|
||||||
if 'medium' in data:
|
|
||||||
defer.returnValue(data)
|
|
||||||
defer.returnValue(None)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _bind_threepid(self, creds, mxid):
|
|
||||||
yield
|
|
||||||
logger.debug("binding threepid")
|
|
||||||
http_client = SimpleHttpClient(self.hs)
|
|
||||||
data = None
|
|
||||||
try:
|
|
||||||
data = yield http_client.post_urlencoded_get_json(
|
|
||||||
# XXX: Change when ID servers are all HTTPS
|
|
||||||
"http://%s%s" % (
|
|
||||||
creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
|
|
||||||
),
|
|
||||||
{
|
|
||||||
'sid': creds['sid'],
|
|
||||||
'clientSecret': creds['clientSecret'],
|
|
||||||
'mxid': mxid,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
logger.debug("bound threepid")
|
|
||||||
except CodeMessageException as e:
|
|
||||||
data = json.loads(e.msg)
|
|
||||||
defer.returnValue(data)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _validate_captcha(self, ip_addr, private_key, challenge, response):
|
def _validate_captcha(self, ip_addr, private_key, challenge, response):
|
||||||
"""Validates the captcha provided.
|
"""Validates the captcha provided.
|
||||||
|
|
|
@ -49,12 +49,20 @@ class RegisterRestServlet(RestServlet):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_handlers().auth_handler
|
self.auth_handler = hs.get_handlers().auth_handler
|
||||||
self.registration_handler = hs.get_handlers().registration_handler
|
self.registration_handler = hs.get_handlers().registration_handler
|
||||||
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
body = parse_request_allow_empty(request)
|
body = parse_request_allow_empty(request)
|
||||||
|
if 'password' not in body:
|
||||||
|
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
if 'username' in body:
|
||||||
|
desired_username = body['username']
|
||||||
|
print "username in body"
|
||||||
|
yield self.registration_handler.check_username(desired_username)
|
||||||
|
|
||||||
is_using_shared_secret = False
|
is_using_shared_secret = False
|
||||||
is_application_server = False
|
is_application_server = False
|
||||||
|
@ -100,15 +108,28 @@ class RegisterRestServlet(RestServlet):
|
||||||
if not can_register:
|
if not can_register:
|
||||||
raise SynapseError(403, "Registration has been disabled")
|
raise SynapseError(403, "Registration has been disabled")
|
||||||
|
|
||||||
if 'username' not in params or 'password' not in params:
|
if 'password' not in params:
|
||||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||||
desired_username = params['username']
|
desired_username = params['username'] if 'username' in params else None
|
||||||
new_password = params['password']
|
new_password = params['password']
|
||||||
|
|
||||||
(user_id, token) = yield self.registration_handler.register(
|
(user_id, token) = yield self.registration_handler.register(
|
||||||
localpart=desired_username,
|
localpart=desired_username,
|
||||||
password=new_password
|
password=new_password
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if 'bind_email' in params and params['bind_email']:
|
||||||
|
logger.info("bind_email specified: binding")
|
||||||
|
|
||||||
|
emailThreepid = result[LoginType.EMAIL_IDENTITY]
|
||||||
|
threepidCreds = emailThreepid['threepidCreds']
|
||||||
|
logger.debug("Binding emails %s to %s" % (
|
||||||
|
emailThreepid, user_id
|
||||||
|
))
|
||||||
|
yield self.identity_handler.bind_threepid(threepidCreds, user_id)
|
||||||
|
else:
|
||||||
|
logger.info("bind_email not specified: not binding email")
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": token,
|
"access_token": token,
|
||||||
|
|
Loading…
Reference in New Issue