Add a callback to allow modules to deny 3PID (#11854)
Part of the Tchap Synapse mainlining. This allows modules to implement extra logic to figure out whether a given 3PID can be added to the local homeserver. In the Tchap use case, this will allow a Synapse module to interface with the custom endpoint /internal_info.
This commit is contained in:
parent
fef2e792be
commit
0640f8ebaa
|
@ -0,0 +1 @@
|
||||||
|
Add a callback to allow modules to allow or forbid a 3PID (email address, phone number) from being associated to a local account.
|
|
@ -166,6 +166,25 @@ any of the subsequent implementations of this callback. If every callback return
|
||||||
the username provided by the user is used, if any (otherwise one is automatically
|
the username provided by the user is used, if any (otherwise one is automatically
|
||||||
generated).
|
generated).
|
||||||
|
|
||||||
|
## `is_3pid_allowed`
|
||||||
|
|
||||||
|
_First introduced in Synapse v1.53.0_
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def is_3pid_allowed(self, medium: str, address: str, registration: bool) -> bool
|
||||||
|
```
|
||||||
|
|
||||||
|
Called when attempting to bind a third-party identifier (i.e. an email address or a phone
|
||||||
|
number). The module is given the medium of the third-party identifier (which is `email` if
|
||||||
|
the identifier is an email address, or `msisdn` if the identifier is a phone number) and
|
||||||
|
its address, as well as a boolean indicating whether the attempt to bind is happening as
|
||||||
|
part of registering a new user. The module must return a boolean indicating whether the
|
||||||
|
identifier can be allowed to be bound to an account on the local homeserver.
|
||||||
|
|
||||||
|
If multiple modules implement this callback, they will be considered in order. If a
|
||||||
|
callback returns `True`, Synapse falls through to the next one. The value of the first
|
||||||
|
callback that does not return `True` will be used. If this happens, Synapse will not call
|
||||||
|
any of the subsequent implementations of this callback.
|
||||||
|
|
||||||
## Example
|
## Example
|
||||||
|
|
||||||
|
|
|
@ -2064,6 +2064,7 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
|
||||||
[JsonDict, JsonDict],
|
[JsonDict, JsonDict],
|
||||||
Awaitable[Optional[str]],
|
Awaitable[Optional[str]],
|
||||||
]
|
]
|
||||||
|
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
|
||||||
|
|
||||||
|
|
||||||
class PasswordAuthProvider:
|
class PasswordAuthProvider:
|
||||||
|
@ -2079,6 +2080,7 @@ class PasswordAuthProvider:
|
||||||
self.get_username_for_registration_callbacks: List[
|
self.get_username_for_registration_callbacks: List[
|
||||||
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
||||||
] = []
|
] = []
|
||||||
|
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
|
||||||
|
|
||||||
# Mapping from login type to login parameters
|
# Mapping from login type to login parameters
|
||||||
self._supported_login_types: Dict[str, Iterable[str]] = {}
|
self._supported_login_types: Dict[str, Iterable[str]] = {}
|
||||||
|
@ -2090,6 +2092,7 @@ class PasswordAuthProvider:
|
||||||
self,
|
self,
|
||||||
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
|
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
|
||||||
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
|
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
|
||||||
|
is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
|
||||||
auth_checkers: Optional[
|
auth_checkers: Optional[
|
||||||
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
|
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
|
||||||
] = None,
|
] = None,
|
||||||
|
@ -2145,6 +2148,9 @@ class PasswordAuthProvider:
|
||||||
get_username_for_registration,
|
get_username_for_registration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_3pid_allowed is not None:
|
||||||
|
self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
|
||||||
|
|
||||||
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
||||||
"""Get the login types supported by this password provider
|
"""Get the login types supported by this password provider
|
||||||
|
|
||||||
|
@ -2343,3 +2349,41 @@ class PasswordAuthProvider:
|
||||||
raise SynapseError(code=500, msg="Internal Server Error")
|
raise SynapseError(code=500, msg="Internal Server Error")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def is_3pid_allowed(
|
||||||
|
self,
|
||||||
|
medium: str,
|
||||||
|
address: str,
|
||||||
|
registration: bool,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if the user can be allowed to bind a 3PID on this homeserver.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
medium: The medium of the 3PID.
|
||||||
|
address: The address of the 3PID.
|
||||||
|
registration: Whether the 3PID is being bound when registering a new user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether the 3PID is allowed to be bound on this homeserver
|
||||||
|
"""
|
||||||
|
for callback in self.is_3pid_allowed_callbacks:
|
||||||
|
try:
|
||||||
|
res = await callback(medium, address, registration)
|
||||||
|
|
||||||
|
if res is False:
|
||||||
|
return res
|
||||||
|
elif not isinstance(res, bool):
|
||||||
|
# mypy complains that this line is unreachable because it assumes the
|
||||||
|
# data returned by the module fits the expected type. We just want
|
||||||
|
# to make sure this is the case.
|
||||||
|
logger.warning( # type: ignore[unreachable]
|
||||||
|
"Ignoring non-string value returned by"
|
||||||
|
" is_3pid_allowed callback %s: %s",
|
||||||
|
callback,
|
||||||
|
res,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Module raised an exception in is_3pid_allowed: %s", e)
|
||||||
|
raise SynapseError(code=500, msg="Internal Server Error")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
|
@ -72,6 +72,7 @@ from synapse.handlers.auth import (
|
||||||
CHECK_3PID_AUTH_CALLBACK,
|
CHECK_3PID_AUTH_CALLBACK,
|
||||||
CHECK_AUTH_CALLBACK,
|
CHECK_AUTH_CALLBACK,
|
||||||
GET_USERNAME_FOR_REGISTRATION_CALLBACK,
|
GET_USERNAME_FOR_REGISTRATION_CALLBACK,
|
||||||
|
IS_3PID_ALLOWED_CALLBACK,
|
||||||
ON_LOGGED_OUT_CALLBACK,
|
ON_LOGGED_OUT_CALLBACK,
|
||||||
AuthHandler,
|
AuthHandler,
|
||||||
)
|
)
|
||||||
|
@ -312,6 +313,7 @@ class ModuleApi:
|
||||||
auth_checkers: Optional[
|
auth_checkers: Optional[
|
||||||
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
|
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
|
||||||
] = None,
|
] = None,
|
||||||
|
is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
|
||||||
get_username_for_registration: Optional[
|
get_username_for_registration: Optional[
|
||||||
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
||||||
] = None,
|
] = None,
|
||||||
|
@ -323,6 +325,7 @@ class ModuleApi:
|
||||||
return self._password_auth_provider.register_password_auth_provider_callbacks(
|
return self._password_auth_provider.register_password_auth_provider_callbacks(
|
||||||
check_3pid_auth=check_3pid_auth,
|
check_3pid_auth=check_3pid_auth,
|
||||||
on_logged_out=on_logged_out,
|
on_logged_out=on_logged_out,
|
||||||
|
is_3pid_allowed=is_3pid_allowed,
|
||||||
auth_checkers=auth_checkers,
|
auth_checkers=auth_checkers,
|
||||||
get_username_for_registration=get_username_for_registration,
|
get_username_for_registration=get_username_for_registration,
|
||||||
)
|
)
|
||||||
|
|
|
@ -385,7 +385,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||||
send_attempt = body["send_attempt"]
|
send_attempt = body["send_attempt"]
|
||||||
next_link = body.get("next_link") # Optional param
|
next_link = body.get("next_link") # Optional param
|
||||||
|
|
||||||
if not check_3pid_allowed(self.hs, "email", email):
|
if not await check_3pid_allowed(self.hs, "email", email):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403,
|
403,
|
||||||
"Your email domain is not authorized on this server",
|
"Your email domain is not authorized on this server",
|
||||||
|
@ -468,7 +468,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||||
|
|
||||||
msisdn = phone_number_to_msisdn(country, phone_number)
|
msisdn = phone_number_to_msisdn(country, phone_number)
|
||||||
|
|
||||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403,
|
403,
|
||||||
"Account phone numbers are not authorized on this server",
|
"Account phone numbers are not authorized on this server",
|
||||||
|
|
|
@ -112,7 +112,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||||
send_attempt = body["send_attempt"]
|
send_attempt = body["send_attempt"]
|
||||||
next_link = body.get("next_link") # Optional param
|
next_link = body.get("next_link") # Optional param
|
||||||
|
|
||||||
if not check_3pid_allowed(self.hs, "email", email):
|
if not await check_3pid_allowed(self.hs, "email", email, registration=True):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403,
|
403,
|
||||||
"Your email domain is not authorized to register on this server",
|
"Your email domain is not authorized to register on this server",
|
||||||
|
@ -192,7 +192,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
||||||
|
|
||||||
msisdn = phone_number_to_msisdn(country, phone_number)
|
msisdn = phone_number_to_msisdn(country, phone_number)
|
||||||
|
|
||||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
if not await check_3pid_allowed(self.hs, "msisdn", msisdn, registration=True):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403,
|
403,
|
||||||
"Phone numbers are not authorized to register on this server",
|
"Phone numbers are not authorized to register on this server",
|
||||||
|
@ -616,7 +616,9 @@ class RegisterRestServlet(RestServlet):
|
||||||
medium = auth_result[login_type]["medium"]
|
medium = auth_result[login_type]["medium"]
|
||||||
address = auth_result[login_type]["address"]
|
address = auth_result[login_type]["address"]
|
||||||
|
|
||||||
if not check_3pid_allowed(self.hs, medium, address):
|
if not await check_3pid_allowed(
|
||||||
|
self.hs, medium, address, registration=True
|
||||||
|
):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403,
|
403,
|
||||||
"Third party identifiers (email/phone numbers)"
|
"Third party identifiers (email/phone numbers)"
|
||||||
|
|
|
@ -32,7 +32,12 @@ logger = logging.getLogger(__name__)
|
||||||
MAX_EMAIL_ADDRESS_LENGTH = 500
|
MAX_EMAIL_ADDRESS_LENGTH = 500
|
||||||
|
|
||||||
|
|
||||||
def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
|
async def check_3pid_allowed(
|
||||||
|
hs: "HomeServer",
|
||||||
|
medium: str,
|
||||||
|
address: str,
|
||||||
|
registration: bool = False,
|
||||||
|
) -> bool:
|
||||||
"""Checks whether a given format of 3PID is allowed to be used on this HS
|
"""Checks whether a given format of 3PID is allowed to be used on this HS
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -40,9 +45,15 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
|
||||||
medium: 3pid medium - e.g. email, msisdn
|
medium: 3pid medium - e.g. email, msisdn
|
||||||
address: address within that medium (e.g. "wotan@matrix.org")
|
address: address within that medium (e.g. "wotan@matrix.org")
|
||||||
msisdns need to first have been canonicalised
|
msisdns need to first have been canonicalised
|
||||||
|
registration: whether we want to bind the 3PID as part of registering a new user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: whether the 3PID medium/address is allowed to be added to this HS
|
bool: whether the 3PID medium/address is allowed to be added to this HS
|
||||||
"""
|
"""
|
||||||
|
if not await hs.get_password_auth_provider().is_3pid_allowed(
|
||||||
|
medium, address, registration
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
if hs.config.registration.allowed_local_3pids:
|
if hs.config.registration.allowed_local_3pids:
|
||||||
for constraint in hs.config.registration.allowed_local_3pids:
|
for constraint in hs.config.registration.allowed_local_3pids:
|
||||||
|
|
|
@ -21,13 +21,15 @@ from twisted.internet import defer
|
||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
|
from synapse.api.errors import Codes
|
||||||
from synapse.handlers.auth import load_legacy_password_auth_providers
|
from synapse.handlers.auth import load_legacy_password_auth_providers
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.rest.client import devices, login, logout, register
|
from synapse.rest.client import account, devices, login, logout, register
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import FakeChannel
|
from tests.server import FakeChannel
|
||||||
|
from tests.test_utils import make_awaitable
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
|
|
||||||
# (possibly experimental) login flows we expect to appear in the list after the normal
|
# (possibly experimental) login flows we expect to appear in the list after the normal
|
||||||
|
@ -158,6 +160,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
devices.register_servlets,
|
devices.register_servlets,
|
||||||
logout.register_servlets,
|
logout.register_servlets,
|
||||||
register.register_servlets,
|
register.register_servlets,
|
||||||
|
account.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -803,6 +806,77 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
# Check that the callback has been called.
|
# Check that the callback has been called.
|
||||||
m.assert_called_once()
|
m.assert_called_once()
|
||||||
|
|
||||||
|
# Set some email configuration so the test doesn't fail because of its absence.
|
||||||
|
@override_config({"email": {"notif_from": "noreply@test"}})
|
||||||
|
def test_3pid_allowed(self):
|
||||||
|
"""Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
|
||||||
|
to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
|
||||||
|
the 3PID. Also checks that the module is passed a boolean indicating whether the
|
||||||
|
user to bind this 3PID to is currently registering.
|
||||||
|
"""
|
||||||
|
self._test_3pid_allowed("rin", False)
|
||||||
|
self._test_3pid_allowed("kitay", True)
|
||||||
|
|
||||||
|
def _test_3pid_allowed(self, username: str, registration: bool):
|
||||||
|
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
|
||||||
|
either /register or /account URLs depending on the arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username: The username to use for the test.
|
||||||
|
registration: Whether to test with registration URLs.
|
||||||
|
"""
|
||||||
|
self.hs.get_identity_handler().send_threepid_validation = Mock(
|
||||||
|
return_value=make_awaitable(0),
|
||||||
|
)
|
||||||
|
|
||||||
|
m = Mock(return_value=make_awaitable(False))
|
||||||
|
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
|
||||||
|
|
||||||
|
self.register_user(username, "password")
|
||||||
|
tok = self.login(username, "password")
|
||||||
|
|
||||||
|
if registration:
|
||||||
|
url = "/register/email/requestToken"
|
||||||
|
else:
|
||||||
|
url = "/account/3pid/email/requestToken"
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
url,
|
||||||
|
{
|
||||||
|
"client_secret": "foo",
|
||||||
|
"email": "foo@test.com",
|
||||||
|
"send_attempt": 0,
|
||||||
|
},
|
||||||
|
access_token=tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 403, channel.result)
|
||||||
|
self.assertEqual(
|
||||||
|
channel.json_body["errcode"],
|
||||||
|
Codes.THREEPID_DENIED,
|
||||||
|
channel.json_body,
|
||||||
|
)
|
||||||
|
|
||||||
|
m.assert_called_once_with("email", "foo@test.com", registration)
|
||||||
|
|
||||||
|
m = Mock(return_value=make_awaitable(True))
|
||||||
|
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
url,
|
||||||
|
{
|
||||||
|
"client_secret": "foo",
|
||||||
|
"email": "bar@test.com",
|
||||||
|
"send_attempt": 0,
|
||||||
|
},
|
||||||
|
access_token=tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
self.assertIn("sid", channel.json_body)
|
||||||
|
|
||||||
|
m.assert_called_once_with("email", "bar@test.com", registration)
|
||||||
|
|
||||||
def _setup_get_username_for_registration(self) -> Mock:
|
def _setup_get_username_for_registration(self) -> Mock:
|
||||||
"""Registers a get_username_for_registration callback that appends "-foo" to the
|
"""Registers a get_username_for_registration callback that appends "-foo" to the
|
||||||
username the client is trying to register.
|
username the client is trying to register.
|
||||||
|
|
Loading…
Reference in New Issue