Merge erikj/user_dedup to develop

This commit is contained in:
Daniel Wagner-Hall 2015-08-26 13:42:45 +01:00
parent a2355fae7e
commit d3c0e48859
4 changed files with 50 additions and 12 deletions

View File

@ -163,7 +163,8 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'): if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string() user_id = UserID.create(user_id, self.hs.hostname).to_string()
yield self._check_password(user_id, password) user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -280,27 +281,49 @@ class AuthHandler(BaseHandler):
password (str): Password password (str): Password
Returns: Returns:
A tuple of: A tuple of:
The user's ID.
The access token for the user's session. The access token for the user's session.
The refresh token for the user's session. The refresh token for the user's session.
Raises: Raises:
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
yield self._check_password(user_id, password) user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
logger.info("Logging in user %s", user_id) logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id) access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id) refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((access_token, refresh_token)) defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password(self, user_id, password): def _find_user_id_and_pwd_hash(self, user_id):
"""Checks that user_id has passed password, raises LoginError if not.""" """Checks to see if a user with the given id exists. Will check case
user_info = yield self.store.get_user_by_id(user_id=user_id) insensitively, but will throw if there are multiple inexact matches.
if not user_info:
Returns:
tuple: A 2-tuple of `(canonical_user_id, password_hash)`
"""
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
if not user_infos:
logger.warn("Attempted to login as %s but they do not exist", user_id) logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info["password_hash"] if len(user_infos) > 1:
if user_id not in user_infos:
logger.warn(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
user_id, user_infos.keys()
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue((user_id, user_infos[user_id]))
else:
defer.returnValue(user_infos.popitem())
def _check_password(self, user_id, password, stored_hash):
"""Checks that user_id has passed password, raises LoginError if not."""
if not bcrypt.checkpw(password, stored_hash): if not bcrypt.checkpw(password, stored_hash):
logger.warn("Failed password login for user %s", user_id) logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)

View File

@ -56,8 +56,8 @@ class RegistrationHandler(BaseHandler):
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
u = yield self.store.get_user_by_id(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id)
if u: if users:
raise SynapseError( raise SynapseError(
400, 400,
"User ID already taken.", "User ID already taken.",

View File

@ -83,10 +83,11 @@ class LoginRestServlet(ClientV1RestServlet):
if not user_id.startswith('@'): if not user_id.startswith('@'):
user_id = UserID.create( user_id = UserID.create(
user_id, self.hs.hostname).to_string() user_id, self.hs.hostname
).to_string()
auth_handler = self.handlers.auth_handler auth_handler = self.handlers.auth_handler
access_token, refresh_token = yield auth_handler.login_with_password( user_id, access_token, refresh_token = yield auth_handler.login_with_password(
user_id=user_id, user_id=user_id,
password=login_submission["password"]) password=login_submission["password"])

View File

@ -120,6 +120,20 @@ class RegistrationStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
def f(txn):
sql = (
"SELECT name, password_hash FROM users"
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn.fetchall())
return self.runInteraction("get_users_by_id_case_insensitive", f)
@defer.inlineCallbacks @defer.inlineCallbacks
def user_set_password_hash(self, user_id, password_hash): def user_set_password_hash(self, user_id, password_hash):
""" """