Merge pull request #928 from matrix-org/rav/refactor_login
Refactor login flow
This commit is contained in:
commit
93efcb8526
|
@ -230,7 +230,6 @@ class AuthHandler(BaseHandler):
|
|||
sess = self._get_session_info(session_id)
|
||||
return sess.setdefault('serverdict', {}).get(key, default)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_password_auth(self, authdict, _):
|
||||
if "user" not in authdict or "password" not in authdict:
|
||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||
|
@ -240,11 +239,7 @@ class AuthHandler(BaseHandler):
|
|||
if not user_id.startswith('@'):
|
||||
user_id = UserID.create(user_id, self.hs.hostname).to_string()
|
||||
|
||||
if not (yield self._check_password(user_id, password)):
|
||||
logger.warn("Failed password login for user %s", user_id)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
defer.returnValue(user_id)
|
||||
return self._check_password(user_id, password)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_recaptcha(self, authdict, clientip):
|
||||
|
@ -348,67 +343,66 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
return self.sessions[session_id]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def login_with_password(self, user_id, password):
|
||||
def validate_password_login(self, user_id, password):
|
||||
"""
|
||||
Authenticates the user with their username and password.
|
||||
|
||||
Used only by the v1 login API.
|
||||
|
||||
Args:
|
||||
user_id (str): User ID
|
||||
user_id (str): complete @user:id
|
||||
password (str): Password
|
||||
Returns:
|
||||
A tuple of:
|
||||
The user's ID.
|
||||
The access token for the user's session.
|
||||
The refresh token for the user's session.
|
||||
defer.Deferred: (str) canonical user id
|
||||
Raises:
|
||||
StoreError if there was a problem storing the token.
|
||||
StoreError if there was a problem accessing the database
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
|
||||
if not (yield self._check_password(user_id, password)):
|
||||
logger.warn("Failed password login for user %s", user_id)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
logger.info("Logging in user %s", user_id)
|
||||
access_token = yield self.issue_access_token(user_id)
|
||||
refresh_token = yield self.issue_refresh_token(user_id)
|
||||
defer.returnValue((user_id, access_token, refresh_token))
|
||||
return self._check_password(user_id, password)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_login_tuple_for_user_id(self, user_id):
|
||||
"""
|
||||
Gets login tuple for the user with the given user ID.
|
||||
|
||||
Creates a new access/refresh token for the user.
|
||||
|
||||
The user is assumed to have been authenticated by some other
|
||||
machanism (e.g. CAS)
|
||||
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
||||
|
||||
Args:
|
||||
user_id (str): User ID
|
||||
user_id (str): canonical User ID
|
||||
Returns:
|
||||
A tuple of:
|
||||
The user's ID.
|
||||
The access token for the user's session.
|
||||
The refresh token for the user's session.
|
||||
Raises:
|
||||
StoreError if there was a problem storing the token.
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
|
||||
|
||||
logger.info("Logging in user %s", user_id)
|
||||
access_token = yield self.issue_access_token(user_id)
|
||||
refresh_token = yield self.issue_refresh_token(user_id)
|
||||
defer.returnValue((user_id, access_token, refresh_token))
|
||||
defer.returnValue((access_token, refresh_token))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def does_user_exist(self, user_id):
|
||||
def check_user_exists(self, user_id):
|
||||
"""
|
||||
Checks to see if a user with the given id exists. Will check case
|
||||
insensitively, but return None if there are multiple inexact matches.
|
||||
|
||||
Args:
|
||||
(str) user_id: complete @user:id
|
||||
|
||||
Returns:
|
||||
defer.Deferred: (str) canonical_user_id, or None if zero or
|
||||
multiple matches
|
||||
"""
|
||||
try:
|
||||
yield self._find_user_id_and_pwd_hash(user_id)
|
||||
defer.returnValue(True)
|
||||
res = yield self._find_user_id_and_pwd_hash(user_id)
|
||||
defer.returnValue(res[0])
|
||||
except LoginError:
|
||||
defer.returnValue(False)
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _find_user_id_and_pwd_hash(self, user_id):
|
||||
|
@ -438,27 +432,45 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _check_password(self, user_id, password):
|
||||
"""
|
||||
"""Authenticate a user against the LDAP and local databases.
|
||||
|
||||
user_id is checked case insensitively against the local database, but
|
||||
will throw if there are multiple inexact matches.
|
||||
|
||||
Args:
|
||||
user_id (str): complete @user:id
|
||||
Returns:
|
||||
True if the user_id successfully authenticated
|
||||
(str) the canonical_user_id
|
||||
Raises:
|
||||
LoginError if the password was incorrect
|
||||
"""
|
||||
valid_ldap = yield self._check_ldap_password(user_id, password)
|
||||
if valid_ldap:
|
||||
defer.returnValue(True)
|
||||
defer.returnValue(user_id)
|
||||
|
||||
valid_local_password = yield self._check_local_password(user_id, password)
|
||||
if valid_local_password:
|
||||
defer.returnValue(True)
|
||||
|
||||
defer.returnValue(False)
|
||||
result = yield self._check_local_password(user_id, password)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_local_password(self, user_id, password):
|
||||
try:
|
||||
"""Authenticate a user against the local password database.
|
||||
|
||||
user_id is checked case insensitively, but will throw if there are
|
||||
multiple inexact matches.
|
||||
|
||||
Args:
|
||||
user_id (str): complete @user:id
|
||||
Returns:
|
||||
(str) the canonical_user_id
|
||||
Raises:
|
||||
LoginError if the password was incorrect
|
||||
"""
|
||||
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
||||
defer.returnValue(self.validate_hash(password, password_hash))
|
||||
except LoginError:
|
||||
defer.returnValue(False)
|
||||
result = self.validate_hash(password, password_hash)
|
||||
if not result:
|
||||
logger.warn("Failed password login for user %s", user_id)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
defer.returnValue(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_ldap_password(self, user_id, password):
|
||||
|
@ -570,7 +582,7 @@ class AuthHandler(BaseHandler):
|
|||
)
|
||||
|
||||
# check for existing account, if none exists, create one
|
||||
if not (yield self.does_user_exist(user_id)):
|
||||
if not (yield self.check_user_exists(user_id)):
|
||||
# query user metadata for account creation
|
||||
query = "({prop}={value})".format(
|
||||
prop=self.ldap_attributes['uid'],
|
||||
|
|
|
@ -145,10 +145,13 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
).to_string()
|
||||
|
||||
auth_handler = self.auth_handler
|
||||
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
|
||||
user_id = yield auth_handler.validate_password_login(
|
||||
user_id=user_id,
|
||||
password=login_submission["password"])
|
||||
|
||||
password=login_submission["password"],
|
||||
)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
|
@ -165,7 +168,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
user_id = (
|
||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||
)
|
||||
user_id, access_token, refresh_token = (
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||
)
|
||||
result = {
|
||||
|
@ -196,13 +199,15 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if user_exists:
|
||||
user_id, access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||
if registered_user_id:
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
registered_user_id
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"user_id": registered_user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
|
@ -245,13 +250,13 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if user_exists:
|
||||
user_id, access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||
if registered_user_id:
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(registered_user_id)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"user_id": registered_user_id,
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
|
@ -414,13 +419,13 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if not user_exists:
|
||||
user_id, _ = (
|
||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||
if not registered_user_id:
|
||||
registered_user_id, _ = (
|
||||
yield self.handlers.registration_handler.register(localpart=user)
|
||||
)
|
||||
|
||||
login_token = auth_handler.generate_short_term_login_token(user_id)
|
||||
login_token = auth_handler.generate_short_term_login_token(registered_user_id)
|
||||
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
|
||||
login_token)
|
||||
request.redirect(redirect_url)
|
||||
|
|
Loading…
Reference in New Issue