Uniformize spam-checker API, part 5: expand other spam-checker callbacks to return `Tuple[Codes, dict]` (#13044)

Signed-off-by: David Teller <davidt@element.io>
Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
This commit is contained in:
David Teller 2022-07-11 18:52:10 +02:00 committed by GitHub
parent d736d5cfad
commit 11f811470f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 426 additions and 60 deletions

1
changelog.d/13044.misc Normal file
View File

@ -0,0 +1 @@
Support temporary experimental return values for spam checker module callbacks.

View File

@ -297,8 +297,14 @@ class AuthError(SynapseError):
other poorly-defined times. other poorly-defined times.
""" """
def __init__(self, code: int, msg: str, errcode: str = Codes.FORBIDDEN): def __init__(
super().__init__(code, msg, errcode) self,
code: int,
msg: str,
errcode: str = Codes.FORBIDDEN,
additional_fields: Optional[dict] = None,
):
super().__init__(code, msg, errcode, additional_fields)
class InvalidClientCredentialsError(SynapseError): class InvalidClientCredentialsError(SynapseError):

View File

@ -21,7 +21,6 @@ from typing import (
Awaitable, Awaitable,
Callable, Callable,
Collection, Collection,
Dict,
List, List,
Optional, Optional,
Tuple, Tuple,
@ -32,10 +31,11 @@ from typing import (
from typing_extensions import Literal from typing_extensions import Literal
import synapse import synapse
from synapse.api.errors import Codes
from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserProfile from synapse.types import JsonDict, RoomAlias, UserProfile
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -50,12 +50,12 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
Awaitable[ Awaitable[
Union[ Union[
str, str,
"synapse.api.errors.Codes", Codes,
# Highly experimental, not officially part of the spamchecker API, may # Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing # disappear without warning depending on the results of ongoing
# experiments. # experiments.
# Use this to return additional information as part of an error. # Use this to return additional information as part of an error.
Tuple["synapse.api.errors.Codes", Dict], Tuple[Codes, JsonDict],
# Deprecated # Deprecated
bool, bool,
] ]
@ -70,7 +70,12 @@ USER_MAY_JOIN_ROOM_CALLBACK = Callable[
Awaitable[ Awaitable[
Union[ Union[
Literal["NOT_SPAM"], Literal["NOT_SPAM"],
"synapse.api.errors.Codes", Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated # Deprecated
bool, bool,
] ]
@ -81,7 +86,12 @@ USER_MAY_INVITE_CALLBACK = Callable[
Awaitable[ Awaitable[
Union[ Union[
Literal["NOT_SPAM"], Literal["NOT_SPAM"],
"synapse.api.errors.Codes", Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated # Deprecated
bool, bool,
] ]
@ -92,7 +102,12 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[
Awaitable[ Awaitable[
Union[ Union[
Literal["NOT_SPAM"], Literal["NOT_SPAM"],
"synapse.api.errors.Codes", Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated # Deprecated
bool, bool,
] ]
@ -103,7 +118,12 @@ USER_MAY_CREATE_ROOM_CALLBACK = Callable[
Awaitable[ Awaitable[
Union[ Union[
Literal["NOT_SPAM"], Literal["NOT_SPAM"],
"synapse.api.errors.Codes", Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated # Deprecated
bool, bool,
] ]
@ -114,7 +134,12 @@ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[
Awaitable[ Awaitable[
Union[ Union[
Literal["NOT_SPAM"], Literal["NOT_SPAM"],
"synapse.api.errors.Codes", Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated # Deprecated
bool, bool,
] ]
@ -125,7 +150,12 @@ USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[
Awaitable[ Awaitable[
Union[ Union[
Literal["NOT_SPAM"], Literal["NOT_SPAM"],
"synapse.api.errors.Codes", Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated # Deprecated
bool, bool,
] ]
@ -154,7 +184,12 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
Awaitable[ Awaitable[
Union[ Union[
Literal["NOT_SPAM"], Literal["NOT_SPAM"],
"synapse.api.errors.Codes", Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated # Deprecated
bool, bool,
] ]
@ -345,7 +380,7 @@ class SpamChecker:
async def check_event_for_spam( async def check_event_for_spam(
self, event: "synapse.events.EventBase" self, event: "synapse.events.EventBase"
) -> Union[Tuple["synapse.api.errors.Codes", Dict], str]: ) -> Union[Tuple[Codes, JsonDict], str]:
"""Checks if a given event is considered "spammy" by this server. """Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if If the server considers an event spammy, then it will be rejected if
@ -376,7 +411,16 @@ class SpamChecker:
elif res is True: elif res is True:
# This spam-checker rejects the event with deprecated # This spam-checker rejects the event with deprecated
# return value `True` # return value `True`
return (synapse.api.errors.Codes.FORBIDDEN, {}) return synapse.api.errors.Codes.FORBIDDEN, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res
elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif not isinstance(res, str): elif not isinstance(res, str):
# mypy complains that we can't reach this code because of the # mypy complains that we can't reach this code because of the
# return type in CHECK_EVENT_FOR_SPAM_CALLBACK, but we don't know # return type in CHECK_EVENT_FOR_SPAM_CALLBACK, but we don't know
@ -422,7 +466,7 @@ class SpamChecker:
async def user_may_join_room( async def user_may_join_room(
self, user_id: str, room_id: str, is_invited: bool self, user_id: str, room_id: str, is_invited: bool
) -> Union["synapse.api.errors.Codes", Literal["NOT_SPAM"]]: ) -> Union[Tuple[Codes, JsonDict], Literal["NOT_SPAM"]]:
"""Checks if a given users is allowed to join a room. """Checks if a given users is allowed to join a room.
Not called when a user creates a room. Not called when a user creates a room.
@ -432,7 +476,7 @@ class SpamChecker:
is_invited: Whether the user is invited into the room is_invited: Whether the user is invited into the room
Returns: Returns:
NOT_SPAM if the operation is permitted, Codes otherwise. NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise.
""" """
for callback in self._user_may_join_room_callbacks: for callback in self._user_may_join_room_callbacks:
with Measure( with Measure(
@ -443,21 +487,28 @@ class SpamChecker:
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
continue continue
elif res is False: elif res is False:
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
elif isinstance(res, synapse.api.errors.Codes): elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res return res
else: else:
logger.warning( logger.warning(
"Module returned invalid value, rejecting join as spam" "Module returned invalid value, rejecting join as spam"
) )
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
# No spam-checker has rejected the request, let it pass. # No spam-checker has rejected the request, let it pass.
return self.NOT_SPAM return self.NOT_SPAM
async def user_may_invite( async def user_may_invite(
self, inviter_userid: str, invitee_userid: str, room_id: str self, inviter_userid: str, invitee_userid: str, room_id: str
) -> Union["synapse.api.errors.Codes", Literal["NOT_SPAM"]]: ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
"""Checks if a given user may send an invite """Checks if a given user may send an invite
Args: Args:
@ -479,21 +530,28 @@ class SpamChecker:
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
continue continue
elif res is False: elif res is False:
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
elif isinstance(res, synapse.api.errors.Codes): elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res return res
else: else:
logger.warning( logger.warning(
"Module returned invalid value, rejecting invite as spam" "Module returned invalid value, rejecting invite as spam"
) )
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
# No spam-checker has rejected the request, let it pass. # No spam-checker has rejected the request, let it pass.
return self.NOT_SPAM return self.NOT_SPAM
async def user_may_send_3pid_invite( async def user_may_send_3pid_invite(
self, inviter_userid: str, medium: str, address: str, room_id: str self, inviter_userid: str, medium: str, address: str, room_id: str
) -> Union["synapse.api.errors.Codes", Literal["NOT_SPAM"]]: ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
"""Checks if a given user may invite a given threepid into the room """Checks if a given user may invite a given threepid into the room
Note that if the threepid is already associated with a Matrix user ID, Synapse Note that if the threepid is already associated with a Matrix user ID, Synapse
@ -519,20 +577,27 @@ class SpamChecker:
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
continue continue
elif res is False: elif res is False:
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
elif isinstance(res, synapse.api.errors.Codes): elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res return res
else: else:
logger.warning( logger.warning(
"Module returned invalid value, rejecting 3pid invite as spam" "Module returned invalid value, rejecting 3pid invite as spam"
) )
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
return self.NOT_SPAM return self.NOT_SPAM
async def user_may_create_room( async def user_may_create_room(
self, userid: str self, userid: str
) -> Union["synapse.api.errors.Codes", Literal["NOT_SPAM"]]: ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
"""Checks if a given user may create a room """Checks if a given user may create a room
Args: Args:
@ -546,20 +611,27 @@ class SpamChecker:
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
continue continue
elif res is False: elif res is False:
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
elif isinstance(res, synapse.api.errors.Codes): elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res return res
else: else:
logger.warning( logger.warning(
"Module returned invalid value, rejecting room creation as spam" "Module returned invalid value, rejecting room creation as spam"
) )
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
return self.NOT_SPAM return self.NOT_SPAM
async def user_may_create_room_alias( async def user_may_create_room_alias(
self, userid: str, room_alias: RoomAlias self, userid: str, room_alias: RoomAlias
) -> Union["synapse.api.errors.Codes", Literal["NOT_SPAM"]]: ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
"""Checks if a given user may create a room alias """Checks if a given user may create a room alias
Args: Args:
@ -575,20 +647,27 @@ class SpamChecker:
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
continue continue
elif res is False: elif res is False:
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
elif isinstance(res, synapse.api.errors.Codes): elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res return res
else: else:
logger.warning( logger.warning(
"Module returned invalid value, rejecting room create as spam" "Module returned invalid value, rejecting room create as spam"
) )
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
return self.NOT_SPAM return self.NOT_SPAM
async def user_may_publish_room( async def user_may_publish_room(
self, userid: str, room_id: str self, userid: str, room_id: str
) -> Union["synapse.api.errors.Codes", Literal["NOT_SPAM"]]: ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
"""Checks if a given user may publish a room to the directory """Checks if a given user may publish a room to the directory
Args: Args:
@ -603,14 +682,21 @@ class SpamChecker:
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
continue continue
elif res is False: elif res is False:
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
elif isinstance(res, synapse.api.errors.Codes): elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res return res
else: else:
logger.warning( logger.warning(
"Module returned invalid value, rejecting room publication as spam" "Module returned invalid value, rejecting room publication as spam"
) )
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
return self.NOT_SPAM return self.NOT_SPAM
@ -678,7 +764,7 @@ class SpamChecker:
async def check_media_file_for_spam( async def check_media_file_for_spam(
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
) -> Union["synapse.api.errors.Codes", Literal["NOT_SPAM"]]: ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
"""Checks if a piece of newly uploaded media should be blocked. """Checks if a piece of newly uploaded media should be blocked.
This will be called for local uploads, downloads of remote media, each This will be called for local uploads, downloads of remote media, each
@ -715,13 +801,20 @@ class SpamChecker:
if res is False or res is self.NOT_SPAM: if res is False or res is self.NOT_SPAM:
continue continue
elif res is True: elif res is True:
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
elif isinstance(res, synapse.api.errors.Codes): elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res return res
else: else:
logger.warning( logger.warning(
"Module returned invalid value, rejecting media file as spam" "Module returned invalid value, rejecting media file as spam"
) )
return synapse.api.errors.Codes.FORBIDDEN return synapse.api.errors.Codes.FORBIDDEN, {}
return self.NOT_SPAM return self.NOT_SPAM

View File

@ -149,7 +149,8 @@ class DirectoryHandler:
raise AuthError( raise AuthError(
403, 403,
"This user is not permitted to create this alias", "This user is not permitted to create this alias",
spam_check, errcode=spam_check[0],
additional_fields=spam_check[1],
) )
if not self.config.roomdirectory.is_alias_creation_allowed( if not self.config.roomdirectory.is_alias_creation_allowed(
@ -441,7 +442,8 @@ class DirectoryHandler:
raise AuthError( raise AuthError(
403, 403,
"This user is not permitted to publish rooms to the room list", "This user is not permitted to publish rooms to the room list",
spam_check, errcode=spam_check[0],
additional_fields=spam_check[1],
) )
if requester.is_guest: if requester.is_guest:

View File

@ -844,7 +844,8 @@ class FederationHandler:
raise SynapseError( raise SynapseError(
403, 403,
"This user is not permitted to send invites to this server/user", "This user is not permitted to send invites to this server/user",
spam_check, errcode=spam_check[0],
additional_fields=spam_check[1],
) )
membership = event.content.get("membership") membership = event.content.get("membership")

View File

@ -440,7 +440,12 @@ class RoomCreationHandler:
spam_check = await self.spam_checker.user_may_create_room(user_id) spam_check = await self.spam_checker.user_may_create_room(user_id)
if spam_check != NOT_SPAM: if spam_check != NOT_SPAM:
raise SynapseError(403, "You are not permitted to create rooms", spam_check) raise SynapseError(
403,
"You are not permitted to create rooms",
errcode=spam_check[0],
additional_fields=spam_check[1],
)
creation_content: JsonDict = { creation_content: JsonDict = {
"room_version": new_room_version.identifier, "room_version": new_room_version.identifier,
@ -731,7 +736,10 @@ class RoomCreationHandler:
spam_check = await self.spam_checker.user_may_create_room(user_id) spam_check = await self.spam_checker.user_may_create_room(user_id)
if spam_check != NOT_SPAM: if spam_check != NOT_SPAM:
raise SynapseError( raise SynapseError(
403, "You are not permitted to create rooms", spam_check 403,
"You are not permitted to create rooms",
errcode=spam_check[0],
additional_fields=spam_check[1],
) )
if ratelimit: if ratelimit:

View File

@ -685,7 +685,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if target_id == self._server_notices_mxid: if target_id == self._server_notices_mxid:
raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
block_invite_code = None block_invite_result = None
if ( if (
self._server_notices_mxid is not None self._server_notices_mxid is not None
@ -703,18 +703,21 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
"Blocking invite: user is not admin and non-admin " "Blocking invite: user is not admin and non-admin "
"invites disabled" "invites disabled"
) )
block_invite_code = Codes.FORBIDDEN block_invite_result = (Codes.FORBIDDEN, {})
spam_check = await self.spam_checker.user_may_invite( spam_check = await self.spam_checker.user_may_invite(
requester.user.to_string(), target_id, room_id requester.user.to_string(), target_id, room_id
) )
if spam_check != NOT_SPAM: if spam_check != NOT_SPAM:
logger.info("Blocking invite due to spam checker") logger.info("Blocking invite due to spam checker")
block_invite_code = spam_check block_invite_result = spam_check
if block_invite_code is not None: if block_invite_result is not None:
raise SynapseError( raise SynapseError(
403, "Invites have been disabled on this server", block_invite_code 403,
"Invites have been disabled on this server",
errcode=block_invite_result[0],
additional_fields=block_invite_result[1],
) )
# An empty prev_events list is allowed as long as the auth_event_ids are present # An empty prev_events list is allowed as long as the auth_event_ids are present
@ -828,7 +831,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
target.to_string(), room_id, is_invited=inviter is not None target.to_string(), room_id, is_invited=inviter is not None
) )
if spam_check != NOT_SPAM: if spam_check != NOT_SPAM:
raise SynapseError(403, "Not allowed to join this room", spam_check) raise SynapseError(
403,
"Not allowed to join this room",
errcode=spam_check[0],
additional_fields=spam_check[1],
)
# Check if a remote join should be performed. # Check if a remote join should be performed.
remote_join, remote_room_hosts = await self._should_perform_remote_join( remote_join, remote_room_hosts = await self._should_perform_remote_join(
@ -1387,7 +1395,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id=room_id, room_id=room_id,
) )
if spam_check != NOT_SPAM: if spam_check != NOT_SPAM:
raise SynapseError(403, "Cannot send threepid invite", spam_check) raise SynapseError(
403,
"Cannot send threepid invite",
errcode=spam_check[0],
additional_fields=spam_check[1],
)
stream_id = await self._make_and_store_3pid_invite( stream_id = await self._make_and_store_3pid_invite(
requester, requester,

View File

@ -35,6 +35,7 @@ from typing_extensions import ParamSpec
from twisted.internet import defer from twisted.internet import defer
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.api import errors
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.presence_router import ( from synapse.events.presence_router import (

View File

@ -154,7 +154,9 @@ class MediaStorage:
# Note that we'll delete the stored media, due to the # Note that we'll delete the stored media, due to the
# try/except below. The media also won't be stored in # try/except below. The media also won't be stored in
# the DB. # the DB.
raise SpamMediaException(errcode=spam_check) # We currently ignore any additional field returned by
# the spam-check API.
raise SpamMediaException(errcode=spam_check[0])
for provider in self.storage_providers: for provider in self.storage_providers:
await provider.store_file(path, file_info) await provider.store_file(path, file_info)

View File

@ -22,7 +22,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from unittest.mock import Mock, call from unittest.mock import Mock, call
from urllib import parse as urlparse from urllib import parse as urlparse
# `Literal` appears with Python 3.8. from parameterized import param, parameterized
from typing_extensions import Literal from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -815,14 +815,14 @@ class RoomsCreateTestCase(RoomBase):
In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`. In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
""" """
async def user_may_join_room( async def user_may_join_room_codes(
mxid: str, mxid: str,
room_id: str, room_id: str,
is_invite: bool, is_invite: bool,
) -> Codes: ) -> Codes:
return Codes.CONSENT_NOT_GIVEN return Codes.CONSENT_NOT_GIVEN
join_mock = Mock(side_effect=user_may_join_room) join_mock = Mock(side_effect=user_may_join_room_codes)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock) self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
channel = self.make_request( channel = self.make_request(
@ -834,6 +834,25 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(join_mock.call_count, 0) self.assertEqual(join_mock.call_count, 0)
# Now change the return value of the callback to deny any join. Since we're
# creating the room, despite the return value, we should be able to join.
async def user_may_join_room_tuple(
mxid: str,
room_id: str,
is_invite: bool,
) -> Tuple[Codes, dict]:
return Codes.INCOMPATIBLE_ROOM_VERSION, {}
join_mock.side_effect = user_may_join_room_tuple
channel = self.make_request(
"POST",
"/createRoom",
{},
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(join_mock.call_count, 0)
class RoomTopicTestCase(RoomBase): class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events.""" """Tests /rooms/$room_id/topic REST events."""
@ -1113,13 +1132,15 @@ class RoomJoinTestCase(RoomBase):
""" """
# Register a dummy callback. Make it allow all room joins for now. # Register a dummy callback. Make it allow all room joins for now.
return_value: Union[Literal["NOT_SPAM"], Codes] = synapse.module_api.NOT_SPAM return_value: Union[
Literal["NOT_SPAM"], Tuple[Codes, dict], Codes
] = synapse.module_api.NOT_SPAM
async def user_may_join_room( async def user_may_join_room(
userid: str, userid: str,
room_id: str, room_id: str,
is_invited: bool, is_invited: bool,
) -> Union[Literal["NOT_SPAM"], Codes]: ) -> Union[Literal["NOT_SPAM"], Tuple[Codes, dict], Codes]:
return return_value return return_value
# `spec` argument is needed for this function mock to have `__qualname__`, which # `spec` argument is needed for this function mock to have `__qualname__`, which
@ -1163,8 +1184,28 @@ class RoomJoinTestCase(RoomBase):
) )
# Now make the callback deny all room joins, and check that a join actually fails. # Now make the callback deny all room joins, and check that a join actually fails.
# We pick an arbitrary Codes rather than the default `Codes.FORBIDDEN`.
return_value = Codes.CONSENT_NOT_GIVEN return_value = Codes.CONSENT_NOT_GIVEN
self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) self.helper.invite(self.room3, self.user1, self.user2, tok=self.tok1)
self.helper.join(
self.room3,
self.user2,
expect_code=403,
expect_errcode=return_value,
tok=self.tok2,
)
# Now make the callback deny all room joins, and check that a join actually fails.
# As above, with the experimental extension that lets us return dictionaries.
return_value = (Codes.BAD_ALIAS, {"another_field": "12345"})
self.helper.join(
self.room3,
self.user2,
expect_code=403,
expect_errcode=return_value[0],
tok=self.tok2,
expect_additional_fields=return_value[1],
)
class RoomJoinRatelimitTestCase(RoomBase): class RoomJoinRatelimitTestCase(RoomBase):
@ -1314,6 +1355,97 @@ class RoomMessagesTestCase(RoomBase):
channel = self.make_request("PUT", path, content) channel = self.make_request("PUT", path, content)
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(200, channel.code, msg=channel.result["body"])
@parameterized.expand(
[
# Allow
param(
name="NOT_SPAM", value="NOT_SPAM", expected_code=200, expected_fields={}
),
param(name="False", value=False, expected_code=200, expected_fields={}),
# Block
param(
name="scalene string",
value="ANY OTHER STRING",
expected_code=403,
expected_fields={"errcode": "M_FORBIDDEN"},
),
param(
name="True",
value=True,
expected_code=403,
expected_fields={"errcode": "M_FORBIDDEN"},
),
param(
name="Code",
value=Codes.LIMIT_EXCEEDED,
expected_code=403,
expected_fields={"errcode": "M_LIMIT_EXCEEDED"},
),
param(
name="Tuple",
value=(Codes.SERVER_NOT_TRUSTED, {"additional_field": "12345"}),
expected_code=403,
expected_fields={
"errcode": "M_SERVER_NOT_TRUSTED",
"additional_field": "12345",
},
),
]
)
def test_spam_checker_check_event_for_spam(
self,
name: str,
value: Union[str, bool, Codes, Tuple[Codes, JsonDict]],
expected_code: int,
expected_fields: dict,
) -> None:
class SpamCheck:
mock_return_value: Union[
str, bool, Codes, Tuple[Codes, JsonDict], bool
] = "NOT_SPAM"
mock_content: Optional[JsonDict] = None
async def check_event_for_spam(
self,
event: synapse.events.EventBase,
) -> Union[str, Codes, Tuple[Codes, JsonDict], bool]:
self.mock_content = event.content
return self.mock_return_value
spam_checker = SpamCheck()
self.hs.get_spam_checker()._check_event_for_spam_callbacks.append(
spam_checker.check_event_for_spam
)
# Inject `value` as mock_return_value
spam_checker.mock_return_value = value
path = "/rooms/%s/send/m.room.message/check_event_for_spam_%s" % (
urlparse.quote(self.room_id),
urlparse.quote(name),
)
body = "test-%s" % name
content = '{"body":"%s","msgtype":"m.text"}' % body
channel = self.make_request("PUT", path, content)
# Check that the callback has witnessed the correct event.
self.assertIsNotNone(spam_checker.mock_content)
if (
spam_checker.mock_content is not None
): # Checked just above, but mypy doesn't know about that.
self.assertEqual(
spam_checker.mock_content["body"], body, spam_checker.mock_content
)
# Check that we have the correct result.
self.assertEqual(expected_code, channel.code, msg=channel.result["body"])
for expected_key, expected_value in expected_fields.items():
self.assertEqual(
channel.json_body.get(expected_key, None),
expected_value,
"Field %s absent or invalid " % expected_key,
)
class RoomPowerLevelOverridesTestCase(RoomBase): class RoomPowerLevelOverridesTestCase(RoomBase):
"""Tests that the power levels can be overridden with server config.""" """Tests that the power levels can be overridden with server config."""
@ -3235,7 +3367,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
make_invite_mock.assert_called_once() make_invite_mock.assert_called_once()
# Now change the return value of the callback to deny any invite and test that # Now change the return value of the callback to deny any invite and test that
# we can't send the invite. # we can't send the invite. We pick an arbitrary error code to be able to check
# that the same code has been returned
mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN) mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN)
channel = self.make_request( channel = self.make_request(
method="POST", method="POST",
@ -3249,6 +3382,27 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.code, 403) self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.CONSENT_NOT_GIVEN)
# Also check that it stopped before calling _make_and_store_3pid_invite.
make_invite_mock.assert_called_once()
# Run variant with `Tuple[Codes, dict]`.
mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"}))
channel = self.make_request(
method="POST",
path="/rooms/" + self.room_id + "/invite",
content={
"id_server": "example.com",
"id_access_token": "sometoken",
"medium": "email",
"address": email_to_invite,
},
access_token=self.tok,
)
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT)
self.assertEqual(channel.json_body["field"], "value")
# Also check that it stopped before calling _make_and_store_3pid_invite. # Also check that it stopped before calling _make_and_store_3pid_invite.
make_invite_mock.assert_called_once() make_invite_mock.assert_called_once()

View File

@ -41,6 +41,7 @@ from twisted.web.resource import Resource
from twisted.web.server import Site from twisted.web.server import Site
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import Codes
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
@ -171,6 +172,8 @@ class RestHelper:
expect_code: int = HTTPStatus.OK, expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None, tok: Optional[str] = None,
appservice_user_id: Optional[str] = None, appservice_user_id: Optional[str] = None,
expect_errcode: Optional[Codes] = None,
expect_additional_fields: Optional[dict] = None,
) -> None: ) -> None:
self.change_membership( self.change_membership(
room=room, room=room,
@ -180,6 +183,8 @@ class RestHelper:
appservice_user_id=appservice_user_id, appservice_user_id=appservice_user_id,
membership=Membership.JOIN, membership=Membership.JOIN,
expect_code=expect_code, expect_code=expect_code,
expect_errcode=expect_errcode,
expect_additional_fields=expect_additional_fields,
) )
def knock( def knock(
@ -263,6 +268,7 @@ class RestHelper:
appservice_user_id: Optional[str] = None, appservice_user_id: Optional[str] = None,
expect_code: int = HTTPStatus.OK, expect_code: int = HTTPStatus.OK,
expect_errcode: Optional[str] = None, expect_errcode: Optional[str] = None,
expect_additional_fields: Optional[dict] = None,
) -> None: ) -> None:
""" """
Send a membership state event into a room. Send a membership state event into a room.
@ -323,6 +329,21 @@ class RestHelper:
channel.result["body"], channel.result["body"],
) )
if expect_additional_fields is not None:
for expect_key, expect_value in expect_additional_fields.items():
assert expect_key in channel.json_body, "Expected field %s, got %s" % (
expect_key,
channel.json_body,
)
assert (
channel.json_body[expect_key] == expect_value
), "Expected: %s at %s, got: %s, resp: %s" % (
expect_value,
expect_key,
channel.json_body[expect_key],
channel.json_body,
)
self.auth_user_id = temp_id self.auth_user_id = temp_id
def send( def send(

View File

@ -23,11 +23,13 @@ from urllib import parse
import attr import attr
from parameterized import parameterized, parameterized_class from parameterized import parameterized, parameterized_class
from PIL import Image as Image from PIL import Image as Image
from typing_extensions import Literal
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
@ -570,9 +572,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
) )
class TestSpamChecker: class TestSpamCheckerLegacy:
"""A spam checker module that rejects all media that includes the bytes """A spam checker module that rejects all media that includes the bytes
`evil`. `evil`.
Uses the legacy Spam-Checker API.
""" """
def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None: def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
@ -613,7 +617,7 @@ class TestSpamChecker:
return b"evil" in buf.getvalue() return b"evil" in buf.getvalue()
class SpamCheckerTestCase(unittest.HomeserverTestCase): class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
servlets = [ servlets = [
login.register_servlets, login.register_servlets,
admin.register_servlets, admin.register_servlets,
@ -637,7 +641,8 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
{ {
"spam_checker": [ "spam_checker": [
{ {
"module": TestSpamChecker.__module__ + ".TestSpamChecker", "module": TestSpamCheckerLegacy.__module__
+ ".TestSpamCheckerLegacy",
"config": {}, "config": {},
} }
] ]
@ -662,3 +667,62 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
self.helper.upload_media( self.helper.upload_media(
self.upload_resource, data, tok=self.tok, expect_code=400 self.upload_resource, data, tok=self.tok, expect_code=400
) )
EVIL_DATA = b"Some evil data"
EVIL_DATA_EXPERIMENT = b"Some evil data to trigger the experimental tuple API"
class SpamCheckerTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
admin.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass")
# Allow for uploading and downloading to/from the media repo
self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"]
self.upload_resource = self.media_repo.children[b"upload"]
hs.get_module_api().register_spam_checker_callbacks(
check_media_file_for_spam=self.check_media_file_for_spam
)
async def check_media_file_for_spam(
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
) -> Union[Codes, Literal["NOT_SPAM"]]:
buf = BytesIO()
await file_wrapper.write_chunks_to(buf.write)
if buf.getvalue() == EVIL_DATA:
return Codes.FORBIDDEN
elif buf.getvalue() == EVIL_DATA_EXPERIMENT:
return (Codes.FORBIDDEN, {})
else:
return "NOT_SPAM"
def test_upload_innocent(self) -> None:
"""Attempt to upload some innocent data that should be allowed."""
self.helper.upload_media(
self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
)
def test_upload_ban(self) -> None:
"""Attempt to upload some data that includes bytes "evil", which should
get rejected by the spam checker.
"""
self.helper.upload_media(
self.upload_resource, EVIL_DATA, tok=self.tok, expect_code=400
)
self.helper.upload_media(
self.upload_resource,
EVIL_DATA_EXPERIMENT,
tok=self.tok,
expect_code=400,
)