Merge pull request #928 from matrix-org/rav/refactor_login
Refactor login flow
This commit is contained in:
commit
93efcb8526
|
@ -230,7 +230,6 @@ class AuthHandler(BaseHandler):
|
||||||
sess = self._get_session_info(session_id)
|
sess = self._get_session_info(session_id)
|
||||||
return sess.setdefault('serverdict', {}).get(key, default)
|
return sess.setdefault('serverdict', {}).get(key, default)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _check_password_auth(self, authdict, _):
|
def _check_password_auth(self, authdict, _):
|
||||||
if "user" not in authdict or "password" not in authdict:
|
if "user" not in authdict or "password" not in authdict:
|
||||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||||
|
@ -240,11 +239,7 @@ class AuthHandler(BaseHandler):
|
||||||
if not user_id.startswith('@'):
|
if not user_id.startswith('@'):
|
||||||
user_id = UserID.create(user_id, self.hs.hostname).to_string()
|
user_id = UserID.create(user_id, self.hs.hostname).to_string()
|
||||||
|
|
||||||
if not (yield self._check_password(user_id, password)):
|
return self._check_password(user_id, password)
|
||||||
logger.warn("Failed password login for user %s", user_id)
|
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
|
||||||
|
|
||||||
defer.returnValue(user_id)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_recaptcha(self, authdict, clientip):
|
def _check_recaptcha(self, authdict, clientip):
|
||||||
|
@ -348,67 +343,66 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
return self.sessions[session_id]
|
return self.sessions[session_id]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def validate_password_login(self, user_id, password):
|
||||||
def login_with_password(self, user_id, password):
|
|
||||||
"""
|
"""
|
||||||
Authenticates the user with their username and password.
|
Authenticates the user with their username and password.
|
||||||
|
|
||||||
Used only by the v1 login API.
|
Used only by the v1 login API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): User ID
|
user_id (str): complete @user:id
|
||||||
password (str): Password
|
password (str): Password
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of:
|
defer.Deferred: (str) canonical user id
|
||||||
The user's ID.
|
|
||||||
The access token for the user's session.
|
|
||||||
The refresh token for the user's session.
|
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem storing the token.
|
StoreError if there was a problem accessing the database
|
||||||
LoginError if there was an authentication problem.
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
|
return self._check_password(user_id, password)
|
||||||
if not (yield self._check_password(user_id, password)):
|
|
||||||
logger.warn("Failed password login for user %s", user_id)
|
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
|
||||||
|
|
||||||
logger.info("Logging in user %s", user_id)
|
|
||||||
access_token = yield self.issue_access_token(user_id)
|
|
||||||
refresh_token = yield self.issue_refresh_token(user_id)
|
|
||||||
defer.returnValue((user_id, access_token, refresh_token))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_login_tuple_for_user_id(self, user_id):
|
def get_login_tuple_for_user_id(self, user_id):
|
||||||
"""
|
"""
|
||||||
Gets login tuple for the user with the given user ID.
|
Gets login tuple for the user with the given user ID.
|
||||||
|
|
||||||
|
Creates a new access/refresh token for the user.
|
||||||
|
|
||||||
The user is assumed to have been authenticated by some other
|
The user is assumed to have been authenticated by some other
|
||||||
machanism (e.g. CAS)
|
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): User ID
|
user_id (str): canonical User ID
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of:
|
A tuple of:
|
||||||
The user's ID.
|
|
||||||
The access token for the user's session.
|
The access token for the user's session.
|
||||||
The refresh token for the user's session.
|
The refresh token for the user's session.
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem storing the token.
|
StoreError if there was a problem storing the token.
|
||||||
LoginError if there was an authentication problem.
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
|
|
||||||
|
|
||||||
logger.info("Logging in user %s", user_id)
|
logger.info("Logging in user %s", user_id)
|
||||||
access_token = yield self.issue_access_token(user_id)
|
access_token = yield self.issue_access_token(user_id)
|
||||||
refresh_token = yield self.issue_refresh_token(user_id)
|
refresh_token = yield self.issue_refresh_token(user_id)
|
||||||
defer.returnValue((user_id, access_token, refresh_token))
|
defer.returnValue((access_token, refresh_token))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def does_user_exist(self, user_id):
|
def check_user_exists(self, user_id):
|
||||||
|
"""
|
||||||
|
Checks to see if a user with the given id exists. Will check case
|
||||||
|
insensitively, but return None if there are multiple inexact matches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
(str) user_id: complete @user:id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: (str) canonical_user_id, or None if zero or
|
||||||
|
multiple matches
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
yield self._find_user_id_and_pwd_hash(user_id)
|
res = yield self._find_user_id_and_pwd_hash(user_id)
|
||||||
defer.returnValue(True)
|
defer.returnValue(res[0])
|
||||||
except LoginError:
|
except LoginError:
|
||||||
defer.returnValue(False)
|
defer.returnValue(None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _find_user_id_and_pwd_hash(self, user_id):
|
def _find_user_id_and_pwd_hash(self, user_id):
|
||||||
|
@ -438,27 +432,45 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_password(self, user_id, password):
|
def _check_password(self, user_id, password):
|
||||||
"""
|
"""Authenticate a user against the LDAP and local databases.
|
||||||
|
|
||||||
|
user_id is checked case insensitively against the local database, but
|
||||||
|
will throw if there are multiple inexact matches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): complete @user:id
|
||||||
Returns:
|
Returns:
|
||||||
True if the user_id successfully authenticated
|
(str) the canonical_user_id
|
||||||
|
Raises:
|
||||||
|
LoginError if the password was incorrect
|
||||||
"""
|
"""
|
||||||
valid_ldap = yield self._check_ldap_password(user_id, password)
|
valid_ldap = yield self._check_ldap_password(user_id, password)
|
||||||
if valid_ldap:
|
if valid_ldap:
|
||||||
defer.returnValue(True)
|
defer.returnValue(user_id)
|
||||||
|
|
||||||
valid_local_password = yield self._check_local_password(user_id, password)
|
result = yield self._check_local_password(user_id, password)
|
||||||
if valid_local_password:
|
defer.returnValue(result)
|
||||||
defer.returnValue(True)
|
|
||||||
|
|
||||||
defer.returnValue(False)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_local_password(self, user_id, password):
|
def _check_local_password(self, user_id, password):
|
||||||
try:
|
"""Authenticate a user against the local password database.
|
||||||
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
|
||||||
defer.returnValue(self.validate_hash(password, password_hash))
|
user_id is checked case insensitively, but will throw if there are
|
||||||
except LoginError:
|
multiple inexact matches.
|
||||||
defer.returnValue(False)
|
|
||||||
|
Args:
|
||||||
|
user_id (str): complete @user:id
|
||||||
|
Returns:
|
||||||
|
(str) the canonical_user_id
|
||||||
|
Raises:
|
||||||
|
LoginError if the password was incorrect
|
||||||
|
"""
|
||||||
|
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
||||||
|
result = self.validate_hash(password, password_hash)
|
||||||
|
if not result:
|
||||||
|
logger.warn("Failed password login for user %s", user_id)
|
||||||
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
defer.returnValue(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_ldap_password(self, user_id, password):
|
def _check_ldap_password(self, user_id, password):
|
||||||
|
@ -570,7 +582,7 @@ class AuthHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
# check for existing account, if none exists, create one
|
# check for existing account, if none exists, create one
|
||||||
if not (yield self.does_user_exist(user_id)):
|
if not (yield self.check_user_exists(user_id)):
|
||||||
# query user metadata for account creation
|
# query user metadata for account creation
|
||||||
query = "({prop}={value})".format(
|
query = "({prop}={value})".format(
|
||||||
prop=self.ldap_attributes['uid'],
|
prop=self.ldap_attributes['uid'],
|
||||||
|
|
|
@ -145,10 +145,13 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
).to_string()
|
).to_string()
|
||||||
|
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
|
user_id = yield auth_handler.validate_password_login(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password=login_submission["password"])
|
password=login_submission["password"],
|
||||||
|
)
|
||||||
|
access_token, refresh_token = (
|
||||||
|
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||||
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
|
@ -165,7 +168,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
user_id = (
|
user_id = (
|
||||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||||
)
|
)
|
||||||
user_id, access_token, refresh_token = (
|
access_token, refresh_token = (
|
||||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
|
@ -196,13 +199,15 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||||
if user_exists:
|
if registered_user_id:
|
||||||
user_id, access_token, refresh_token = (
|
access_token, refresh_token = (
|
||||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
yield auth_handler.get_login_tuple_for_user_id(
|
||||||
|
registered_user_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": registered_user_id, # may have changed
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"refresh_token": refresh_token,
|
"refresh_token": refresh_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
|
@ -245,13 +250,13 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||||
if user_exists:
|
if registered_user_id:
|
||||||
user_id, access_token, refresh_token = (
|
access_token, refresh_token = (
|
||||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
yield auth_handler.get_login_tuple_for_user_id(registered_user_id)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": registered_user_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"refresh_token": refresh_token,
|
"refresh_token": refresh_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
|
@ -414,13 +419,13 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||||
if not user_exists:
|
if not registered_user_id:
|
||||||
user_id, _ = (
|
registered_user_id, _ = (
|
||||||
yield self.handlers.registration_handler.register(localpart=user)
|
yield self.handlers.registration_handler.register(localpart=user)
|
||||||
)
|
)
|
||||||
|
|
||||||
login_token = auth_handler.generate_short_term_login_token(user_id)
|
login_token = auth_handler.generate_short_term_login_token(registered_user_id)
|
||||||
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
|
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
|
||||||
login_token)
|
login_token)
|
||||||
request.redirect(redirect_url)
|
request.redirect(redirect_url)
|
||||||
|
|
Loading…
Reference in New Issue