Allow denying or shadow banning registrations via the spam checker (#8034)
This commit is contained in:
parent
e259d63f73
commit
3f91638da6
|
@ -0,0 +1 @@
|
||||||
|
Add support for shadow-banning users (ignoring any message send requests).
|
|
@ -15,9 +15,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.spam_checker_api import SpamCheckerApi
|
from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi
|
||||||
|
from synapse.types import Collection
|
||||||
|
|
||||||
MYPY = False
|
MYPY = False
|
||||||
if MYPY:
|
if MYPY:
|
||||||
|
@ -160,3 +161,33 @@ class SpamChecker(object):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def check_registration_for_spam(
|
||||||
|
self,
|
||||||
|
email_threepid: Optional[dict],
|
||||||
|
username: Optional[str],
|
||||||
|
request_info: Collection[Tuple[str, str]],
|
||||||
|
) -> RegistrationBehaviour:
|
||||||
|
"""Checks if we should allow the given registration request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
email_threepid: The email threepid used for registering, if any
|
||||||
|
username: The request user name, if any
|
||||||
|
request_info: List of tuples of user agent and IP that
|
||||||
|
were used during the registration process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Enum for how the request should be handled
|
||||||
|
"""
|
||||||
|
|
||||||
|
for spam_checker in self.spam_checkers:
|
||||||
|
# For backwards compatibility, only run if the method exists on the
|
||||||
|
# spam checker
|
||||||
|
checker = getattr(spam_checker, "check_registration_for_spam", None)
|
||||||
|
if checker:
|
||||||
|
behaviour = checker(email_threepid, username, request_info)
|
||||||
|
assert isinstance(behaviour, RegistrationBehaviour)
|
||||||
|
if behaviour != RegistrationBehaviour.ALLOW:
|
||||||
|
return behaviour
|
||||||
|
|
||||||
|
return RegistrationBehaviour.ALLOW
|
||||||
|
|
|
@ -364,6 +364,14 @@ class AuthHandler(BaseHandler):
|
||||||
# authentication flow.
|
# authentication flow.
|
||||||
await self.store.set_ui_auth_clientdict(sid, clientdict)
|
await self.store.set_ui_auth_clientdict(sid, clientdict)
|
||||||
|
|
||||||
|
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
|
||||||
|
0
|
||||||
|
].decode("ascii", "surrogateescape")
|
||||||
|
|
||||||
|
await self.store.add_user_agent_ip_to_ui_auth_session(
|
||||||
|
session.session_id, user_agent, clientip
|
||||||
|
)
|
||||||
|
|
||||||
if not authdict:
|
if not authdict:
|
||||||
raise InteractiveAuthIncompleteError(
|
raise InteractiveAuthIncompleteError(
|
||||||
session.session_id, self._auth_dict_for_flows(flows, session.session_id)
|
session.session_id, self._auth_dict_for_flows(flows, session.session_id)
|
||||||
|
|
|
@ -35,6 +35,7 @@ class CasHandler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
self._hostname = hs.hostname
|
self._hostname = hs.hostname
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self._registration_handler = hs.get_registration_handler()
|
self._registration_handler = hs.get_registration_handler()
|
||||||
|
@ -210,8 +211,16 @@ class CasHandler:
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if not registered_user_id:
|
if not registered_user_id:
|
||||||
|
# Pull out the user-agent and IP from the request.
|
||||||
|
user_agent = request.requestHeaders.getRawHeaders(
|
||||||
|
b"User-Agent", default=[b""]
|
||||||
|
)[0].decode("ascii", "surrogateescape")
|
||||||
|
ip_address = self.hs.get_ip_from_request(request)
|
||||||
|
|
||||||
registered_user_id = await self._registration_handler.register_user(
|
registered_user_id = await self._registration_handler.register_user(
|
||||||
localpart=localpart, default_display_name=user_display_name
|
localpart=localpart,
|
||||||
|
default_display_name=user_display_name,
|
||||||
|
user_agent_ips=(user_agent, ip_address),
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._auth_handler.complete_sso_login(
|
await self._auth_handler.complete_sso_login(
|
||||||
|
|
|
@ -93,6 +93,7 @@ class OidcHandler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self.hs = hs
|
||||||
self._callback_url = hs.config.oidc_callback_url # type: str
|
self._callback_url = hs.config.oidc_callback_url # type: str
|
||||||
self._scopes = hs.config.oidc_scopes # type: List[str]
|
self._scopes = hs.config.oidc_scopes # type: List[str]
|
||||||
self._client_auth = ClientAuth(
|
self._client_auth = ClientAuth(
|
||||||
|
@ -689,9 +690,17 @@ class OidcHandler:
|
||||||
self._render_error(request, "invalid_token", str(e))
|
self._render_error(request, "invalid_token", str(e))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Pull out the user-agent and IP from the request.
|
||||||
|
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
|
||||||
|
0
|
||||||
|
].decode("ascii", "surrogateescape")
|
||||||
|
ip_address = self.hs.get_ip_from_request(request)
|
||||||
|
|
||||||
# Call the mapper to register/login the user
|
# Call the mapper to register/login the user
|
||||||
try:
|
try:
|
||||||
user_id = await self._map_userinfo_to_user(userinfo, token)
|
user_id = await self._map_userinfo_to_user(
|
||||||
|
userinfo, token, user_agent, ip_address
|
||||||
|
)
|
||||||
except MappingException as e:
|
except MappingException as e:
|
||||||
logger.exception("Could not map user")
|
logger.exception("Could not map user")
|
||||||
self._render_error(request, "mapping_error", str(e))
|
self._render_error(request, "mapping_error", str(e))
|
||||||
|
@ -828,7 +837,9 @@ class OidcHandler:
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
return now < expiry
|
return now < expiry
|
||||||
|
|
||||||
async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
|
async def _map_userinfo_to_user(
|
||||||
|
self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
|
||||||
|
) -> str:
|
||||||
"""Maps a UserInfo object to a mxid.
|
"""Maps a UserInfo object to a mxid.
|
||||||
|
|
||||||
UserInfo should have a claim that uniquely identifies users. This claim
|
UserInfo should have a claim that uniquely identifies users. This claim
|
||||||
|
@ -843,6 +854,8 @@ class OidcHandler:
|
||||||
Args:
|
Args:
|
||||||
userinfo: an object representing the user
|
userinfo: an object representing the user
|
||||||
token: a dict with the tokens obtained from the provider
|
token: a dict with the tokens obtained from the provider
|
||||||
|
user_agent: The user agent of the client making the request.
|
||||||
|
ip_address: The IP address of the client making the request.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
MappingException: if there was an error while mapping some properties
|
MappingException: if there was an error while mapping some properties
|
||||||
|
@ -899,7 +912,9 @@ class OidcHandler:
|
||||||
# It's the first time this user is logging in and the mapped mxid was
|
# It's the first time this user is logging in and the mapped mxid was
|
||||||
# not taken, register the user
|
# not taken, register the user
|
||||||
registered_user_id = await self._registration_handler.register_user(
|
registered_user_id = await self._registration_handler.register_user(
|
||||||
localpart=localpart, default_display_name=attributes["display_name"],
|
localpart=localpart,
|
||||||
|
default_display_name=attributes["display_name"],
|
||||||
|
user_agent_ips=(user_agent, ip_address),
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._datastore.record_user_external_id(
|
await self._datastore.record_user_external_id(
|
||||||
|
|
|
@ -26,6 +26,7 @@ from synapse.replication.http.register import (
|
||||||
ReplicationPostRegisterActionsServlet,
|
ReplicationPostRegisterActionsServlet,
|
||||||
ReplicationRegisterServlet,
|
ReplicationRegisterServlet,
|
||||||
)
|
)
|
||||||
|
from synapse.spam_checker_api import RegistrationBehaviour
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import RoomAlias, UserID, create_requester
|
from synapse.types import RoomAlias, UserID, create_requester
|
||||||
|
|
||||||
|
@ -52,6 +53,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
self._server_notices_mxid = hs.config.server_notices_mxid
|
self._server_notices_mxid = hs.config.server_notices_mxid
|
||||||
|
|
||||||
|
self.spam_checker = hs.get_spam_checker()
|
||||||
|
|
||||||
if hs.config.worker_app:
|
if hs.config.worker_app:
|
||||||
self._register_client = ReplicationRegisterServlet.make_client(hs)
|
self._register_client = ReplicationRegisterServlet.make_client(hs)
|
||||||
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
|
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
|
||||||
|
@ -144,7 +147,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
address=None,
|
address=None,
|
||||||
bind_emails=[],
|
bind_emails=[],
|
||||||
by_admin=False,
|
by_admin=False,
|
||||||
shadow_banned=False,
|
user_agent_ips=None,
|
||||||
):
|
):
|
||||||
"""Registers a new client on the server.
|
"""Registers a new client on the server.
|
||||||
|
|
||||||
|
@ -162,7 +165,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
bind_emails (List[str]): list of emails to bind to this account.
|
bind_emails (List[str]): list of emails to bind to this account.
|
||||||
by_admin (bool): True if this registration is being made via the
|
by_admin (bool): True if this registration is being made via the
|
||||||
admin api, otherwise False.
|
admin api, otherwise False.
|
||||||
shadow_banned (bool): Shadow-ban the created user.
|
user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
|
||||||
|
during the registration process.
|
||||||
Returns:
|
Returns:
|
||||||
str: user_id
|
str: user_id
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -170,6 +174,24 @@ class RegistrationHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
self.check_registration_ratelimit(address)
|
self.check_registration_ratelimit(address)
|
||||||
|
|
||||||
|
result = self.spam_checker.check_registration_for_spam(
|
||||||
|
threepid, localpart, user_agent_ips or [],
|
||||||
|
)
|
||||||
|
|
||||||
|
if result == RegistrationBehaviour.DENY:
|
||||||
|
logger.info(
|
||||||
|
"Blocked registration of %r", localpart,
|
||||||
|
)
|
||||||
|
# We return a 429 to make it not obvious that they've been
|
||||||
|
# denied.
|
||||||
|
raise SynapseError(429, "Rate limited")
|
||||||
|
|
||||||
|
shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
|
||||||
|
if shadow_banned:
|
||||||
|
logger.info(
|
||||||
|
"Shadow banning registration of %r", localpart,
|
||||||
|
)
|
||||||
|
|
||||||
# do not check_auth_blocking if the call is coming through the Admin API
|
# do not check_auth_blocking if the call is coming through the Admin API
|
||||||
if not by_admin:
|
if not by_admin:
|
||||||
await self.auth.check_auth_blocking(threepid=threepid)
|
await self.auth.check_auth_blocking(threepid=threepid)
|
||||||
|
|
|
@ -54,6 +54,7 @@ class Saml2SessionData:
|
||||||
|
|
||||||
class SamlHandler:
|
class SamlHandler:
|
||||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||||
|
self.hs = hs
|
||||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
||||||
self._auth = hs.get_auth()
|
self._auth = hs.get_auth()
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
@ -133,8 +134,14 @@ class SamlHandler:
|
||||||
# the dict.
|
# the dict.
|
||||||
self.expire_sessions()
|
self.expire_sessions()
|
||||||
|
|
||||||
|
# Pull out the user-agent and IP from the request.
|
||||||
|
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
|
||||||
|
0
|
||||||
|
].decode("ascii", "surrogateescape")
|
||||||
|
ip_address = self.hs.get_ip_from_request(request)
|
||||||
|
|
||||||
user_id, current_session = await self._map_saml_response_to_user(
|
user_id, current_session = await self._map_saml_response_to_user(
|
||||||
resp_bytes, relay_state
|
resp_bytes, relay_state, user_agent, ip_address
|
||||||
)
|
)
|
||||||
|
|
||||||
# Complete the interactive auth session or the login.
|
# Complete the interactive auth session or the login.
|
||||||
|
@ -147,7 +154,11 @@ class SamlHandler:
|
||||||
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
|
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
|
||||||
|
|
||||||
async def _map_saml_response_to_user(
|
async def _map_saml_response_to_user(
|
||||||
self, resp_bytes: str, client_redirect_url: str
|
self,
|
||||||
|
resp_bytes: str,
|
||||||
|
client_redirect_url: str,
|
||||||
|
user_agent: str,
|
||||||
|
ip_address: str,
|
||||||
) -> Tuple[str, Optional[Saml2SessionData]]:
|
) -> Tuple[str, Optional[Saml2SessionData]]:
|
||||||
"""
|
"""
|
||||||
Given a sample response, retrieve the cached session and user for it.
|
Given a sample response, retrieve the cached session and user for it.
|
||||||
|
@ -155,6 +166,8 @@ class SamlHandler:
|
||||||
Args:
|
Args:
|
||||||
resp_bytes: The SAML response.
|
resp_bytes: The SAML response.
|
||||||
client_redirect_url: The redirect URL passed in by the client.
|
client_redirect_url: The redirect URL passed in by the client.
|
||||||
|
user_agent: The user agent of the client making the request.
|
||||||
|
ip_address: The IP address of the client making the request.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of the user ID and SAML session associated with this response.
|
Tuple of the user ID and SAML session associated with this response.
|
||||||
|
@ -291,6 +304,7 @@ class SamlHandler:
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
default_display_name=displayname,
|
default_display_name=displayname,
|
||||||
bind_emails=emails,
|
bind_emails=emails,
|
||||||
|
user_agent_ips=(user_agent, ip_address),
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._datastore.record_user_external_id(
|
await self._datastore.record_user_external_id(
|
||||||
|
|
|
@ -591,12 +591,17 @@ class RegisterRestServlet(RestServlet):
|
||||||
Codes.THREEPID_IN_USE,
|
Codes.THREEPID_IN_USE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
entries = await self.store.get_user_agents_ips_to_ui_auth_session(
|
||||||
|
session_id
|
||||||
|
)
|
||||||
|
|
||||||
registered_user_id = await self.registration_handler.register_user(
|
registered_user_id = await self.registration_handler.register_user(
|
||||||
localpart=desired_username,
|
localpart=desired_username,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
guest_access_token=guest_access_token,
|
guest_access_token=guest_access_token,
|
||||||
threepid=threepid,
|
threepid=threepid,
|
||||||
address=client_addr,
|
address=client_addr,
|
||||||
|
user_agent_ips=entries,
|
||||||
)
|
)
|
||||||
# Necessary due to auth checks prior to the threepid being
|
# Necessary due to auth checks prior to the threepid being
|
||||||
# written to the db
|
# written to the db
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -25,6 +26,16 @@ if MYPY:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationBehaviour(Enum):
|
||||||
|
"""
|
||||||
|
Enum to define whether a registration request should allowed, denied, or shadow-banned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ALLOW = "allow"
|
||||||
|
SHADOW_BAN = "shadow_ban"
|
||||||
|
DENY = "deny"
|
||||||
|
|
||||||
|
|
||||||
class SpamCheckerApi(object):
|
class SpamCheckerApi(object):
|
||||||
"""A proxy object that gets passed to spam checkers so they can get
|
"""A proxy object that gets passed to spam checkers so they can get
|
||||||
access to rooms and other relevant information.
|
access to rooms and other relevant information.
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- A table of the IP address and user-agent used to complete each step of a
|
||||||
|
-- user-interactive authentication session.
|
||||||
|
CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
ip TEXT NOT NULL,
|
||||||
|
user_agent TEXT NOT NULL,
|
||||||
|
UNIQUE (session_id, ip, user_agent),
|
||||||
|
FOREIGN KEY (session_id)
|
||||||
|
REFERENCES ui_auth_sessions (session_id)
|
||||||
|
);
|
|
@ -12,7 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -260,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return serverdict.get(key, default)
|
return serverdict.get(key, default)
|
||||||
|
|
||||||
|
async def add_user_agent_ip_to_ui_auth_session(
|
||||||
|
self, session_id: str, user_agent: str, ip: str,
|
||||||
|
):
|
||||||
|
"""Add the given user agent / IP to the tracking table
|
||||||
|
"""
|
||||||
|
await self.db_pool.simple_upsert(
|
||||||
|
table="ui_auth_sessions_ips",
|
||||||
|
keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
|
||||||
|
values={},
|
||||||
|
desc="add_user_agent_ip_to_ui_auth_session",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_user_agents_ips_to_ui_auth_session(
|
||||||
|
self, session_id: str,
|
||||||
|
) -> List[Tuple[str, str]]:
|
||||||
|
"""Get the given user agents / IPs used during the ui auth process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of user_agent/ip pairs
|
||||||
|
"""
|
||||||
|
rows = await self.db_pool.simple_select_list(
|
||||||
|
table="ui_auth_sessions_ips",
|
||||||
|
keyvalues={"session_id": session_id},
|
||||||
|
retcols=("user_agent", "ip"),
|
||||||
|
desc="get_user_agents_ips_to_ui_auth_session",
|
||||||
|
)
|
||||||
|
return [(row["user_agent"], row["ip"]) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
class UIAuthStore(UIAuthWorkerStore):
|
class UIAuthStore(UIAuthWorkerStore):
|
||||||
def delete_old_ui_auth_sessions(self, expiration_time: int):
|
def delete_old_ui_auth_sessions(self, expiration_time: int):
|
||||||
|
@ -285,6 +313,15 @@ class UIAuthStore(UIAuthWorkerStore):
|
||||||
txn.execute(sql, [expiration_time])
|
txn.execute(sql, [expiration_time])
|
||||||
session_ids = [r[0] for r in txn.fetchall()]
|
session_ids = [r[0] for r in txn.fetchall()]
|
||||||
|
|
||||||
|
# Delete the corresponding IP/user agents.
|
||||||
|
self.db_pool.simple_delete_many_txn(
|
||||||
|
txn,
|
||||||
|
table="ui_auth_sessions_ips",
|
||||||
|
column="session_id",
|
||||||
|
iterable=session_ids,
|
||||||
|
keyvalues={},
|
||||||
|
)
|
||||||
|
|
||||||
# Delete the corresponding completed credentials.
|
# Delete the corresponding completed credentials.
|
||||||
self.db_pool.simple_delete_many_txn(
|
self.db_pool.simple_delete_many_txn(
|
||||||
txn,
|
txn,
|
||||||
|
|
|
@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||||
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
|
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
|
||||||
self.handler._auth_handler.complete_sso_login = simple_async_mock()
|
self.handler._auth_handler.complete_sso_login = simple_async_mock()
|
||||||
request = Mock(spec=["args", "getCookie", "addCookie"])
|
request = Mock(
|
||||||
|
spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
|
||||||
|
)
|
||||||
|
|
||||||
code = "code"
|
code = "code"
|
||||||
state = "state"
|
state = "state"
|
||||||
nonce = "nonce"
|
nonce = "nonce"
|
||||||
client_redirect_url = "http://client/redirect"
|
client_redirect_url = "http://client/redirect"
|
||||||
|
user_agent = "Browser"
|
||||||
|
ip_address = "10.0.0.1"
|
||||||
session = self.handler._generate_oidc_session_token(
|
session = self.handler._generate_oidc_session_token(
|
||||||
state=state,
|
state=state,
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
|
@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
request.args[b"code"] = [code.encode("utf-8")]
|
request.args[b"code"] = [code.encode("utf-8")]
|
||||||
request.args[b"state"] = [state.encode("utf-8")]
|
request.args[b"state"] = [state.encode("utf-8")]
|
||||||
|
|
||||||
|
request.requestHeaders = Mock(spec=["getRawHeaders"])
|
||||||
|
request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
|
||||||
|
request.getClientIP.return_value = ip_address
|
||||||
|
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
|
@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.handler._exchange_code.assert_called_once_with(code)
|
self.handler._exchange_code.assert_called_once_with(code)
|
||||||
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
||||||
self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
|
self.handler._map_userinfo_to_user.assert_called_once_with(
|
||||||
|
userinfo, token, user_agent, ip_address
|
||||||
|
)
|
||||||
self.handler._fetch_userinfo.assert_not_called()
|
self.handler._fetch_userinfo.assert_not_called()
|
||||||
self.handler._render_error.assert_not_called()
|
self.handler._render_error.assert_not_called()
|
||||||
|
|
||||||
|
@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.handler._exchange_code.assert_called_once_with(code)
|
self.handler._exchange_code.assert_called_once_with(code)
|
||||||
self.handler._parse_id_token.assert_not_called()
|
self.handler._parse_id_token.assert_not_called()
|
||||||
self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
|
self.handler._map_userinfo_to_user.assert_called_once_with(
|
||||||
|
userinfo, token, user_agent, ip_address
|
||||||
|
)
|
||||||
self.handler._fetch_userinfo.assert_called_once_with(token)
|
self.handler._fetch_userinfo.assert_called_once_with(token)
|
||||||
self.handler._render_error.assert_not_called()
|
self.handler._render_error.assert_not_called()
|
||||||
|
|
||||||
|
|
|
@ -17,18 +17,21 @@ from mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.auth import Auth
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
|
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
|
||||||
from synapse.handlers.register import RegistrationHandler
|
from synapse.handlers.register import RegistrationHandler
|
||||||
|
from synapse.spam_checker_api import RegistrationBehaviour
|
||||||
from synapse.types import RoomAlias, UserID, create_requester
|
from synapse.types import RoomAlias, UserID, create_requester
|
||||||
|
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
|
from tests.utils import mock_getRawHeaders
|
||||||
|
|
||||||
from .. import unittest
|
from .. import unittest
|
||||||
|
|
||||||
|
|
||||||
class RegistrationHandlers(object):
|
class RegistrationHandlers:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.registration_handler = RegistrationHandler(hs)
|
self.registration_handler = RegistrationHandler(hs)
|
||||||
|
|
||||||
|
@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
self.handler.register_user(localpart=invalid_user_id), SynapseError
|
self.handler.register_user(localpart=invalid_user_id), SynapseError
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_spam_checker_deny(self):
|
||||||
|
"""A spam checker can deny registration, which results in an error."""
|
||||||
|
|
||||||
|
class DenyAll:
|
||||||
|
def check_registration_for_spam(
|
||||||
|
self, email_threepid, username, request_info
|
||||||
|
):
|
||||||
|
return RegistrationBehaviour.DENY
|
||||||
|
|
||||||
|
# Configure a spam checker that denies all users.
|
||||||
|
spam_checker = self.hs.get_spam_checker()
|
||||||
|
spam_checker.spam_checkers = [DenyAll()]
|
||||||
|
|
||||||
|
self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
|
||||||
|
|
||||||
|
def test_spam_checker_shadow_ban(self):
|
||||||
|
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
|
||||||
|
|
||||||
|
class BanAll:
|
||||||
|
def check_registration_for_spam(
|
||||||
|
self, email_threepid, username, request_info
|
||||||
|
):
|
||||||
|
return RegistrationBehaviour.SHADOW_BAN
|
||||||
|
|
||||||
|
# Configure a spam checker that denies all users.
|
||||||
|
spam_checker = self.hs.get_spam_checker()
|
||||||
|
spam_checker.spam_checkers = [BanAll()]
|
||||||
|
|
||||||
|
user_id = self.get_success(self.handler.register_user(localpart="user"))
|
||||||
|
|
||||||
|
# Get an access token.
|
||||||
|
token = self.macaroon_generator.generate_access_token(user_id)
|
||||||
|
self.get_success(
|
||||||
|
self.store.add_access_token_to_user(
|
||||||
|
user_id=user_id, token=token, device_id=None, valid_until_ms=None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure the user was marked as shadow-banned.
|
||||||
|
request = Mock(args={})
|
||||||
|
request.args[b"access_token"] = [token.encode("ascii")]
|
||||||
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
auth = Auth(self.hs)
|
||||||
|
requester = self.get_success(auth.get_user_by_req(request))
|
||||||
|
|
||||||
|
self.assertTrue(requester.shadow_banned)
|
||||||
|
|
||||||
async def get_or_create_user(
|
async def get_or_create_user(
|
||||||
self, requester, localpart, displayname, password_hash=None
|
self, requester, localpart, displayname, password_hash=None
|
||||||
):
|
):
|
||||||
|
|
|
@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_spam_checker(self):
|
def test_spam_checker(self):
|
||||||
"""
|
"""
|
||||||
A user which fails to the spam checks will not appear in search results.
|
A user which fails the spam checks will not appear in search results.
|
||||||
"""
|
"""
|
||||||
u1 = self.register_user("user1", "pass")
|
u1 = self.register_user("user1", "pass")
|
||||||
u1_token = self.login(u1, "pass")
|
u1_token = self.login(u1, "pass")
|
||||||
|
@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
# Configure a spam checker that does not filter any users.
|
# Configure a spam checker that does not filter any users.
|
||||||
spam_checker = self.hs.get_spam_checker()
|
spam_checker = self.hs.get_spam_checker()
|
||||||
|
|
||||||
class AllowAll(object):
|
class AllowAll:
|
||||||
def check_username_for_spam(self, user_profile):
|
def check_username_for_spam(self, user_profile):
|
||||||
# Allow all users.
|
# Allow all users.
|
||||||
return False
|
return False
|
||||||
|
@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(s["results"]), 1)
|
self.assertEqual(len(s["results"]), 1)
|
||||||
|
|
||||||
# Configure a spam checker that filters all users.
|
# Configure a spam checker that filters all users.
|
||||||
class BlockAll(object):
|
class BlockAll:
|
||||||
def check_username_for_spam(self, user_profile):
|
def check_username_for_spam(self, user_profile):
|
||||||
# All users are spammy.
|
# All users are spammy.
|
||||||
return True
|
return True
|
||||||
|
|
Loading…
Reference in New Issue