Create a `PasswordProvider` wrapper object (#8849)
The idea here is to abstract out all the conditional code which tests which methods a given password provider has, to provide a consistent interface.
This commit is contained in:
parent
edb3d3f827
commit
d3ed93504b
|
@ -0,0 +1 @@
|
|||
Refactor `password_auth_provider` support code.
|
|
@ -1,6 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -25,6 +26,7 @@ from typing import (
|
|||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
|
@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
|
|||
# better way to break the loop
|
||||
account_handler = ModuleApi(hs, self)
|
||||
|
||||
self.password_providers = []
|
||||
for module, config in hs.config.password_providers:
|
||||
try:
|
||||
self.password_providers.append(
|
||||
module(config=config, account_handler=account_handler)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error while initializing %r: %s", module, e)
|
||||
raise
|
||||
self.password_providers = [
|
||||
PasswordProvider.load(module, config, account_handler)
|
||||
for module, config in hs.config.password_providers
|
||||
]
|
||||
|
||||
logger.info("Extra password_providers: %r", self.password_providers)
|
||||
logger.info("Extra password_providers: %s", self.password_providers)
|
||||
|
||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
|
@ -853,6 +850,8 @@ class AuthHandler(BaseHandler):
|
|||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
login_type = login_submission.get("type")
|
||||
if not isinstance(login_type, str):
|
||||
raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
|
||||
|
||||
# ideally, we wouldn't be checking the identifier unless we know we have a login
|
||||
# method which uses it (https://github.com/matrix-org/synapse/issues/8836)
|
||||
|
@ -998,24 +997,12 @@ class AuthHandler(BaseHandler):
|
|||
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
||||
|
||||
login_type = login_submission.get("type")
|
||||
# we already checked that we have a valid login type
|
||||
assert isinstance(login_type, str)
|
||||
|
||||
known_login_type = False
|
||||
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
|
||||
known_login_type = True
|
||||
# we've already checked that there is a (valid) password field
|
||||
is_valid = await provider.check_password(
|
||||
qualified_user_id, login_submission["password"]
|
||||
)
|
||||
if is_valid:
|
||||
return qualified_user_id, None
|
||||
|
||||
if not hasattr(provider, "get_supported_login_types") or not hasattr(
|
||||
provider, "check_auth"
|
||||
):
|
||||
# this password provider doesn't understand custom login types
|
||||
continue
|
||||
|
||||
supported_login_types = provider.get_supported_login_types()
|
||||
if login_type not in supported_login_types:
|
||||
# this password provider doesn't understand this login type
|
||||
|
@ -1040,8 +1027,6 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
result = await provider.check_auth(username, login_type, login_dict)
|
||||
if result:
|
||||
if isinstance(result, str):
|
||||
result = (result, None)
|
||||
return result
|
||||
|
||||
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
|
||||
|
@ -1083,18 +1068,8 @@ class AuthHandler(BaseHandler):
|
|||
unsuccessful, `user_id` and `callback` are both `None`.
|
||||
"""
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "check_3pid_auth"):
|
||||
# This function is able to return a deferred that either
|
||||
# resolves None, meaning authentication failure, or upon
|
||||
# success, to a str (which is the user_id) or a tuple of
|
||||
# (user_id, callback_func), where callback_func should be run
|
||||
# after we've finished everything else
|
||||
result = await provider.check_3pid_auth(medium, address, password)
|
||||
if result:
|
||||
# Check if the return value is a str or a tuple
|
||||
if isinstance(result, str):
|
||||
# If it's a str, set callback function to None
|
||||
result = (result, None)
|
||||
return result
|
||||
|
||||
return None, None
|
||||
|
@ -1153,16 +1128,11 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
# see if any of our auth providers want to know about this
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "on_logged_out"):
|
||||
# This might return an awaitable, if it does block the log out
|
||||
# until it completes.
|
||||
result = provider.on_logged_out(
|
||||
await provider.on_logged_out(
|
||||
user_id=user_info.user_id,
|
||||
device_id=user_info.device_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
|
||||
# delete pushers associated with this access token
|
||||
if user_info.token_id is not None:
|
||||
|
@ -1191,7 +1161,6 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
# see if any of our auth providers want to know about this
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "on_logged_out"):
|
||||
for token, token_id, device_id in tokens_and_devices:
|
||||
await provider.on_logged_out(
|
||||
user_id=user_id, device_id=device_id, access_token=token
|
||||
|
@ -1519,3 +1488,127 @@ class MacaroonGenerator:
|
|||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||
return macaroon
|
||||
|
||||
|
||||
class PasswordProvider:
|
||||
"""Wrapper for a password auth provider module
|
||||
|
||||
This class abstracts out all of the backwards-compatibility hacks for
|
||||
password providers, to provide a consistent interface.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
|
||||
try:
|
||||
pp = module(config=config, account_handler=module_api)
|
||||
except Exception as e:
|
||||
logger.error("Error while initializing %r: %s", module, e)
|
||||
raise
|
||||
return cls(pp, module_api)
|
||||
|
||||
def __init__(self, pp, module_api: ModuleApi):
|
||||
self._pp = pp
|
||||
self._module_api = module_api
|
||||
|
||||
self._supported_login_types = {}
|
||||
|
||||
# grandfather in check_password support
|
||||
if hasattr(self._pp, "check_password"):
|
||||
self._supported_login_types[LoginType.PASSWORD] = ("password",)
|
||||
|
||||
g = getattr(self._pp, "get_supported_login_types", None)
|
||||
if g:
|
||||
self._supported_login_types.update(g())
|
||||
|
||||
def __str__(self):
|
||||
return str(self._pp)
|
||||
|
||||
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
||||
"""Get the login types supported by this password provider
|
||||
|
||||
Returns a map from a login type identifier (such as m.login.password) to an
|
||||
iterable giving the fields which must be provided by the user in the submission
|
||||
to the /login API.
|
||||
|
||||
This wrapper adds m.login.password to the list if the underlying password
|
||||
provider supports the check_password() api.
|
||||
"""
|
||||
return self._supported_login_types
|
||||
|
||||
async def check_auth(
|
||||
self, username: str, login_type: str, login_dict: JsonDict
|
||||
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||
"""Check if the user has presented valid login credentials
|
||||
|
||||
This wrapper also calls check_password() if the underlying password provider
|
||||
supports the check_password() api and the login type is m.login.password.
|
||||
|
||||
Args:
|
||||
username: user id presented by the client. Either an MXID or an unqualified
|
||||
username.
|
||||
|
||||
login_type: the login type being attempted - one of the types returned by
|
||||
get_supported_login_types()
|
||||
|
||||
login_dict: the dictionary of login secrets passed by the client.
|
||||
|
||||
Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
|
||||
user, and `callback` is an optional callback which will be called with the
|
||||
result from the /login call (including access_token, device_id, etc.)
|
||||
"""
|
||||
# first grandfather in a call to check_password
|
||||
if login_type == LoginType.PASSWORD:
|
||||
g = getattr(self._pp, "check_password", None)
|
||||
if g:
|
||||
qualified_user_id = self._module_api.get_qualified_user_id(username)
|
||||
is_valid = await self._pp.check_password(
|
||||
qualified_user_id, login_dict["password"]
|
||||
)
|
||||
if is_valid:
|
||||
return qualified_user_id, None
|
||||
|
||||
g = getattr(self._pp, "check_auth", None)
|
||||
if not g:
|
||||
return None
|
||||
result = await g(username, login_type, login_dict)
|
||||
|
||||
# Check if the return value is a str or a tuple
|
||||
if isinstance(result, str):
|
||||
# If it's a str, set callback function to None
|
||||
return result, None
|
||||
|
||||
return result
|
||||
|
||||
async def check_3pid_auth(
|
||||
self, medium: str, address: str, password: str
|
||||
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||
g = getattr(self._pp, "check_3pid_auth", None)
|
||||
if not g:
|
||||
return None
|
||||
|
||||
# This function is able to return a deferred that either
|
||||
# resolves None, meaning authentication failure, or upon
|
||||
# success, to a str (which is the user_id) or a tuple of
|
||||
# (user_id, callback_func), where callback_func should be run
|
||||
# after we've finished everything else
|
||||
result = await g(medium, address, password)
|
||||
|
||||
# Check if the return value is a str or a tuple
|
||||
if isinstance(result, str):
|
||||
# If it's a str, set callback function to None
|
||||
return result, None
|
||||
|
||||
return result
|
||||
|
||||
async def on_logged_out(
|
||||
self, user_id: str, device_id: Optional[str], access_token: str
|
||||
) -> None:
|
||||
g = getattr(self._pp, "on_logged_out", None)
|
||||
if not g:
|
||||
return
|
||||
|
||||
# This might return an awaitable, if it does block the log out
|
||||
# until it completes.
|
||||
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
|
|
|
@ -266,8 +266,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
# first delete should give a 401
|
||||
channel = self._delete_device(tok1, "dev2")
|
||||
self.assertEqual(channel.code, 401)
|
||||
# there are no valid flows here!
|
||||
self.assertEqual(channel.json_body["flows"], [])
|
||||
# m.login.password UIA is permitted because the auth provider allows it,
|
||||
# even though the localdb does not.
|
||||
self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
|
||||
session = channel.json_body["session"]
|
||||
mock_password_provider.check_password.assert_not_called()
|
||||
|
||||
|
|
Loading…
Reference in New Issue