Refactor some logic from LoginRestServlet into AuthHandler
I'm going to need some more flexibility in handling login types in password auth providers, so as a first step, move some stuff from LoginRestServlet into AuthHandler. In particular, we pass everything other than SAML, JWT and token logins down to the AuthHandler, which now has responsibility for checking the login type and fishing the password out of the login dictionary, as well as qualifying the user_id if need be. Ideally SAML, JWT and token would go that way too, but there's no real need for it right now and I'm trying to minimise impact. This commit *should* be non-functional.
This commit is contained in:
parent
e2f4190209
commit
1b65ae00ac
|
@ -77,6 +77,12 @@ class AuthHandler(BaseHandler):
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
self._password_enabled = hs.config.password_enabled
|
||||||
|
|
||||||
|
login_types = set()
|
||||||
|
if self._password_enabled:
|
||||||
|
login_types.add(LoginType.PASSWORD)
|
||||||
|
self._supported_login_types = frozenset(login_types)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_auth(self, flows, clientdict, clientip):
|
def check_auth(self, flows, clientdict, clientip):
|
||||||
|
@ -266,10 +272,11 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
user_id = authdict["user"]
|
user_id = authdict["user"]
|
||||||
password = authdict["password"]
|
password = authdict["password"]
|
||||||
if not user_id.startswith('@'):
|
|
||||||
user_id = UserID(user_id, self.hs.hostname).to_string()
|
|
||||||
|
|
||||||
return self._check_password(user_id, password)
|
return self.validate_login(user_id, {
|
||||||
|
"type": LoginType.PASSWORD,
|
||||||
|
"password": password,
|
||||||
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_recaptcha(self, authdict, clientip):
|
def _check_recaptcha(self, authdict, clientip):
|
||||||
|
@ -398,23 +405,6 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
return self.sessions[session_id]
|
return self.sessions[session_id]
|
||||||
|
|
||||||
def validate_password_login(self, user_id, password):
|
|
||||||
"""
|
|
||||||
Authenticates the user with their username and password.
|
|
||||||
|
|
||||||
Used only by the v1 login API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id (str): complete @user:id
|
|
||||||
password (str): Password
|
|
||||||
Returns:
|
|
||||||
defer.Deferred: (str) canonical user id
|
|
||||||
Raises:
|
|
||||||
StoreError if there was a problem accessing the database
|
|
||||||
LoginError if there was an authentication problem.
|
|
||||||
"""
|
|
||||||
return self._check_password(user_id, password)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_access_token_for_user_id(self, user_id, device_id=None,
|
def get_access_token_for_user_id(self, user_id, device_id=None,
|
||||||
initial_display_name=None):
|
initial_display_name=None):
|
||||||
|
@ -501,26 +491,60 @@ class AuthHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def get_supported_login_types(self):
|
||||||
def _check_password(self, user_id, password):
|
"""Get a the login types supported for the /login API
|
||||||
"""Authenticate a user against the LDAP and local databases.
|
|
||||||
|
|
||||||
user_id is checked case insensitively against the local database, but
|
By default this is just 'm.login.password' (unless password_enabled is
|
||||||
will throw if there are multiple inexact matches.
|
False in the config file), but password auth providers can provide
|
||||||
|
other login types.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterable[str]: login types
|
||||||
|
"""
|
||||||
|
return self._supported_login_types
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def validate_login(self, user_id, login_submission):
|
||||||
|
"""Authenticates the user for the /login API
|
||||||
|
|
||||||
|
Also used by the user-interactive auth flow to validate
|
||||||
|
m.login.password auth types.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): complete @user:id
|
user_id (str): user_id supplied by the user
|
||||||
|
login_submission (dict): the whole of the login submission
|
||||||
|
(including 'type' and other relevant fields)
|
||||||
Returns:
|
Returns:
|
||||||
(str) the canonical_user_id
|
Deferred[str]: canonical user id
|
||||||
Raises:
|
Raises:
|
||||||
LoginError if login fails
|
StoreError if there was a problem accessing the database
|
||||||
|
SynapseError if there was a problem with the request
|
||||||
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not user_id.startswith('@'):
|
||||||
|
user_id = UserID(
|
||||||
|
user_id, self.hs.hostname
|
||||||
|
).to_string()
|
||||||
|
|
||||||
|
login_type = login_submission.get("type")
|
||||||
|
|
||||||
|
if login_type != LoginType.PASSWORD:
|
||||||
|
raise SynapseError(400, "Bad login type.")
|
||||||
|
if not self._password_enabled:
|
||||||
|
raise SynapseError(400, "Password login has been disabled.")
|
||||||
|
if "password" not in login_submission:
|
||||||
|
raise SynapseError(400, "Missing parameter: password")
|
||||||
|
|
||||||
|
password = login_submission["password"]
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
is_valid = yield provider.check_password(user_id, password)
|
is_valid = yield provider.check_password(user_id, password)
|
||||||
if is_valid:
|
if is_valid:
|
||||||
defer.returnValue(user_id)
|
defer.returnValue(user_id)
|
||||||
|
|
||||||
canonical_user_id = yield self._check_local_password(user_id, password)
|
canonical_user_id = yield self._check_local_password(
|
||||||
|
user_id, password,
|
||||||
|
)
|
||||||
|
|
||||||
if canonical_user_id:
|
if canonical_user_id:
|
||||||
defer.returnValue(canonical_user_id)
|
defer.returnValue(canonical_user_id)
|
||||||
|
|
|
@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier):
|
||||||
|
|
||||||
class LoginRestServlet(ClientV1RestServlet):
|
class LoginRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/login$")
|
PATTERNS = client_path_patterns("/login$")
|
||||||
PASS_TYPE = "m.login.password"
|
|
||||||
SAML2_TYPE = "m.login.saml2"
|
SAML2_TYPE = "m.login.saml2"
|
||||||
CAS_TYPE = "m.login.cas"
|
CAS_TYPE = "m.login.cas"
|
||||||
TOKEN_TYPE = "m.login.token"
|
TOKEN_TYPE = "m.login.token"
|
||||||
|
@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(LoginRestServlet, self).__init__(hs)
|
super(LoginRestServlet, self).__init__(hs)
|
||||||
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
|
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
|
||||||
self.password_enabled = hs.config.password_enabled
|
|
||||||
self.saml2_enabled = hs.config.saml2_enabled
|
self.saml2_enabled = hs.config.saml2_enabled
|
||||||
self.jwt_enabled = hs.config.jwt_enabled
|
self.jwt_enabled = hs.config.jwt_enabled
|
||||||
self.jwt_secret = hs.config.jwt_secret
|
self.jwt_secret = hs.config.jwt_secret
|
||||||
|
@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
# fall back to the fallback API if they don't understand one of the
|
# fall back to the fallback API if they don't understand one of the
|
||||||
# login flow types returned.
|
# login flow types returned.
|
||||||
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
|
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
|
||||||
if self.password_enabled:
|
|
||||||
flows.append({"type": LoginRestServlet.PASS_TYPE})
|
flows.extend((
|
||||||
|
{"type": t} for t in self.auth_handler.get_supported_login_types()
|
||||||
|
))
|
||||||
|
|
||||||
return (200, {"flows": flows})
|
return (200, {"flows": flows})
|
||||||
|
|
||||||
|
@ -133,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
login_submission = parse_json_object_from_request(request)
|
login_submission = parse_json_object_from_request(request)
|
||||||
try:
|
try:
|
||||||
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
|
if self.saml2_enabled and (login_submission["type"] ==
|
||||||
if not self.password_enabled:
|
LoginRestServlet.SAML2_TYPE):
|
||||||
raise SynapseError(400, "Password login has been disabled.")
|
|
||||||
|
|
||||||
result = yield self.do_password_login(login_submission)
|
|
||||||
defer.returnValue(result)
|
|
||||||
elif self.saml2_enabled and (login_submission["type"] ==
|
|
||||||
LoginRestServlet.SAML2_TYPE):
|
|
||||||
relay_state = ""
|
relay_state = ""
|
||||||
if "relay_state" in login_submission:
|
if "relay_state" in login_submission:
|
||||||
relay_state = "&RelayState=" + urllib.quote(
|
relay_state = "&RelayState=" + urllib.quote(
|
||||||
|
@ -157,15 +151,21 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
result = yield self.do_token_login(login_submission)
|
result = yield self.do_token_login(login_submission)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
else:
|
else:
|
||||||
raise SynapseError(400, "Bad login type.")
|
result = yield self._do_other_login(login_submission)
|
||||||
|
defer.returnValue(result)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise SynapseError(400, "Missing JSON keys.")
|
raise SynapseError(400, "Missing JSON keys.")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_password_login(self, login_submission):
|
def _do_other_login(self, login_submission):
|
||||||
if "password" not in login_submission:
|
"""Handle non-token/saml/jwt logins
|
||||||
raise SynapseError(400, "Missing parameter: password")
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
login_submission:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(int, object): HTTP code/response
|
||||||
|
"""
|
||||||
login_submission_legacy_convert(login_submission)
|
login_submission_legacy_convert(login_submission)
|
||||||
|
|
||||||
if "identifier" not in login_submission:
|
if "identifier" not in login_submission:
|
||||||
|
@ -208,25 +208,22 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
if "user" not in identifier:
|
if "user" not in identifier:
|
||||||
raise SynapseError(400, "User identifier is missing 'user' key")
|
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||||
|
|
||||||
user_id = identifier["user"]
|
|
||||||
|
|
||||||
if not user_id.startswith('@'):
|
|
||||||
user_id = UserID(
|
|
||||||
user_id, self.hs.hostname
|
|
||||||
).to_string()
|
|
||||||
|
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_id = yield auth_handler.validate_password_login(
|
canonical_user_id = yield auth_handler.validate_login(
|
||||||
user_id=user_id,
|
identifier["user"],
|
||||||
password=login_submission["password"],
|
login_submission,
|
||||||
|
)
|
||||||
|
|
||||||
|
device_id = yield self._register_device(
|
||||||
|
canonical_user_id, login_submission,
|
||||||
)
|
)
|
||||||
device_id = yield self._register_device(user_id, login_submission)
|
|
||||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||||
user_id, device_id,
|
canonical_user_id, device_id,
|
||||||
login_submission.get("initial_device_display_name"),
|
login_submission.get("initial_device_display_name"),
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": canonical_user_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
|
|
Loading…
Reference in New Issue