When logging in fetch user by user_id case insensitively, *unless* there are multiple case insensitive matches, in which case require the exact user_id
This commit is contained in:
parent
aa3c9c7bd0
commit
42f12ad92f
|
@ -162,7 +162,8 @@ class AuthHandler(BaseHandler):
|
|||
if not user_id.startswith('@'):
|
||||
user_id = UserID.create(user_id, self.hs.hostname).to_string()
|
||||
|
||||
yield self._check_password(user_id, password)
|
||||
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
||||
self._check_password(user_id, password, password_hash)
|
||||
defer.returnValue(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -283,23 +284,37 @@ class AuthHandler(BaseHandler):
|
|||
StoreError if there was a problem storing the token.
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
yield self._check_password(user_id, password)
|
||||
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
||||
self._check_password(user_id, password, password_hash)
|
||||
|
||||
reg_handler = self.hs.get_handlers().registration_handler
|
||||
access_token = reg_handler.generate_token(user_id)
|
||||
logger.info("Logging in user %s", user_id)
|
||||
yield self.store.add_access_token_to_user(user_id, access_token)
|
||||
defer.returnValue(access_token)
|
||||
defer.returnValue((user_id, access_token))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_password(self, user_id, password):
|
||||
"""Checks that user_id has passed password, raises LoginError if not."""
|
||||
user_info = yield self.store.get_user_by_id(user_id=user_id)
|
||||
if not user_info:
|
||||
def _find_user_id_and_pwd_hash(self, user_id):
|
||||
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
|
||||
if not user_infos:
|
||||
logger.warn("Attempted to login as %s but they do not exist", user_id)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
stored_hash = user_info["password_hash"]
|
||||
if len(user_infos) > 1:
|
||||
if user_id not in user_infos:
|
||||
logger.warn(
|
||||
"Attempted to login as %s but it matches more than one user "
|
||||
"inexactly: %r",
|
||||
user_id, user_infos.keys()
|
||||
)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
defer.returnValue((user_id, user_infos[user_id]))
|
||||
else:
|
||||
defer.returnValue(user_infos.popitem())
|
||||
|
||||
def _check_password(self, user_id, password, stored_hash):
|
||||
"""Checks that user_id has passed password, raises LoginError if not."""
|
||||
if not bcrypt.checkpw(password, stored_hash):
|
||||
logger.warn("Failed password login for user %s", user_id)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
|
|
@ -83,9 +83,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
|
||||
if not user_id.startswith('@'):
|
||||
user_id = UserID.create(
|
||||
user_id, self.hs.hostname).to_string()
|
||||
user_id, self.hs.hostname
|
||||
).to_string()
|
||||
|
||||
token = yield self.handlers.auth_handler.login_with_password(
|
||||
user_id, token = yield self.handlers.auth_handler.login_with_password(
|
||||
user_id=user_id,
|
||||
password=login_submission["password"])
|
||||
|
||||
|
|
|
@ -99,13 +99,16 @@ class RegistrationStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
def get_users_by_id_case_insensitive(self, user_id):
|
||||
"""Gets users that match user_id case insensitively.
|
||||
Returns a mapping of user_id -> password_hash.
|
||||
"""
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT name, password_hash FROM users"
|
||||
" WHERE name = lower(?)"
|
||||
" WHERE lower(name) = lower(?)"
|
||||
)
|
||||
txn.execute(sql, (user_id,))
|
||||
return self.cursor_to_dict(txn)
|
||||
return dict(txn.fetchall())
|
||||
|
||||
return self.runInteraction("get_users_by_id_case_insensitive", f)
|
||||
|
||||
|
|
Loading…
Reference in New Issue