Create a `PasswordProvider` wrapper object (#8849)

The idea here is to abstract out all the conditional code which tests which
methods a given password provider has, to provide a consistent interface.
This commit is contained in:
Richard van der Hoff 2020-12-02 10:38:50 +00:00 committed by GitHub
parent edb3d3f827
commit d3ed93504b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 152 additions and 57 deletions

1
changelog.d/8849.misc Normal file
View File

@ -0,0 +1 @@
Refactor `password_auth_provider` support code.

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd # Copyright 2017 Vector Creations Ltd
# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -25,6 +26,7 @@ from typing import (
Dict, Dict,
Iterable, Iterable,
List, List,
Mapping,
Optional, Optional,
Tuple, Tuple,
Union, Union,
@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
# better way to break the loop # better way to break the loop
account_handler = ModuleApi(hs, self) account_handler = ModuleApi(hs, self)
self.password_providers = [] self.password_providers = [
for module, config in hs.config.password_providers: PasswordProvider.load(module, config, account_handler)
try: for module, config in hs.config.password_providers
self.password_providers.append( ]
module(config=config, account_handler=account_handler)
)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
logger.info("Extra password_providers: %r", self.password_providers) logger.info("Extra password_providers: %s", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
@ -853,6 +850,8 @@ class AuthHandler(BaseHandler):
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
login_type = login_submission.get("type") login_type = login_submission.get("type")
if not isinstance(login_type, str):
raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
# ideally, we wouldn't be checking the identifier unless we know we have a login # ideally, we wouldn't be checking the identifier unless we know we have a login
# method which uses it (https://github.com/matrix-org/synapse/issues/8836) # method which uses it (https://github.com/matrix-org/synapse/issues/8836)
@ -998,24 +997,12 @@ class AuthHandler(BaseHandler):
qualified_user_id = UserID(username, self.hs.hostname).to_string() qualified_user_id = UserID(username, self.hs.hostname).to_string()
login_type = login_submission.get("type") login_type = login_submission.get("type")
# we already checked that we have a valid login type
assert isinstance(login_type, str)
known_login_type = False known_login_type = False
for provider in self.password_providers: for provider in self.password_providers:
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True
# we've already checked that there is a (valid) password field
is_valid = await provider.check_password(
qualified_user_id, login_submission["password"]
)
if is_valid:
return qualified_user_id, None
if not hasattr(provider, "get_supported_login_types") or not hasattr(
provider, "check_auth"
):
# this password provider doesn't understand custom login types
continue
supported_login_types = provider.get_supported_login_types() supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types: if login_type not in supported_login_types:
# this password provider doesn't understand this login type # this password provider doesn't understand this login type
@ -1040,8 +1027,6 @@ class AuthHandler(BaseHandler):
result = await provider.check_auth(username, login_type, login_dict) result = await provider.check_auth(username, login_type, login_dict)
if result: if result:
if isinstance(result, str):
result = (result, None)
return result return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
@ -1083,18 +1068,8 @@ class AuthHandler(BaseHandler):
unsuccessful, `user_id` and `callback` are both `None`. unsuccessful, `user_id` and `callback` are both `None`.
""" """
for provider in self.password_providers: for provider in self.password_providers:
if hasattr(provider, "check_3pid_auth"):
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await provider.check_3pid_auth(medium, address, password) result = await provider.check_3pid_auth(medium, address, password)
if result: if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
result = (result, None)
return result return result
return None, None return None, None
@ -1153,16 +1128,11 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this # see if any of our auth providers want to know about this
for provider in self.password_providers: for provider in self.password_providers:
if hasattr(provider, "on_logged_out"): await provider.on_logged_out(
# This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
user_id=user_info.user_id, user_id=user_info.user_id,
device_id=user_info.device_id, device_id=user_info.device_id,
access_token=access_token, access_token=access_token,
) )
if inspect.isawaitable(result):
await result
# delete pushers associated with this access token # delete pushers associated with this access token
if user_info.token_id is not None: if user_info.token_id is not None:
@ -1191,7 +1161,6 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this # see if any of our auth providers want to know about this
for provider in self.password_providers: for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices: for token, token_id, device_id in tokens_and_devices:
await provider.on_logged_out( await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token user_id=user_id, device_id=device_id, access_token=token
@ -1519,3 +1488,127 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon return macaroon
class PasswordProvider:
"""Wrapper for a password auth provider module
This class abstracts out all of the backwards-compatibility hacks for
password providers, to provide a consistent interface.
"""
@classmethod
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
return cls(pp, module_api)
def __init__(self, pp, module_api: ModuleApi):
self._pp = pp
self._module_api = module_api
self._supported_login_types = {}
# grandfather in check_password support
if hasattr(self._pp, "check_password"):
self._supported_login_types[LoginType.PASSWORD] = ("password",)
g = getattr(self._pp, "get_supported_login_types", None)
if g:
self._supported_login_types.update(g())
def __str__(self):
return str(self._pp)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
Returns a map from a login type identifier (such as m.login.password) to an
iterable giving the fields which must be provided by the user in the submission
to the /login API.
This wrapper adds m.login.password to the list if the underlying password
provider supports the check_password() api.
"""
return self._supported_login_types
async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
"""Check if the user has presented valid login credentials
This wrapper also calls check_password() if the underlying password provider
supports the check_password() api and the login type is m.login.password.
Args:
username: user id presented by the client. Either an MXID or an unqualified
username.
login_type: the login type being attempted - one of the types returned by
get_supported_login_types()
login_dict: the dictionary of login secrets passed by the client.
Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
user, and `callback` is an optional callback which will be called with the
result from the /login call (including access_token, device_id, etc.)
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
g = getattr(self._pp, "check_password", None)
if g:
qualified_user_id = self._module_api.get_qualified_user_id(username)
is_valid = await self._pp.check_password(
qualified_user_id, login_dict["password"]
)
if is_valid:
return qualified_user_id, None
g = getattr(self._pp, "check_auth", None)
if not g:
return None
result = await g(username, login_type, login_dict)
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None
return result
async def check_3pid_auth(
self, medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]:
g = getattr(self._pp, "check_3pid_auth", None)
if not g:
return None
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await g(medium, address, password)
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None
return result
async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str
) -> None:
g = getattr(self._pp, "on_logged_out", None)
if not g:
return
# This might return an awaitable, if it does block the log out
# until it completes.
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
if inspect.isawaitable(result):
await result

View File

@ -266,8 +266,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# first delete should give a 401 # first delete should give a 401
channel = self._delete_device(tok1, "dev2") channel = self._delete_device(tok1, "dev2")
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)
# there are no valid flows here! # m.login.password UIA is permitted because the auth provider allows it,
self.assertEqual(channel.json_body["flows"], []) # even though the localdb does not.
self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
session = channel.json_body["session"] session = channel.json_body["session"]
mock_password_provider.check_password.assert_not_called() mock_password_provider.check_password.assert_not_called()