From e46d5f3586025a491d11a31ce2be4c540c38d404 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 5 May 2023 15:06:22 +0100 Subject: [PATCH] Factor out an `is_mine_server_name` method (#15542) Add an `is_mine_server_name` method, similar to `is_mine_id`. Ideally we would use this consistently, instead of sometimes comparing against `hs.hostname` and other times reaching into `hs.config.server.server_name`. Also fix a bug in the tests where `hs.hostname` would sometimes differ from `hs.config.server.server_name`. Signed-off-by: Sean Quah --- changelog.d/15542.misc | 1 + synapse/api/auth_blocking.py | 4 ++-- synapse/crypto/keyring.py | 4 ++-- synapse/federation/federation_base.py | 2 +- synapse/federation/federation_client.py | 4 ++-- synapse/federation/federation_server.py | 3 ++- synapse/federation/send_queue.py | 3 ++- synapse/federation/sender/__init__.py | 11 ++++++----- synapse/federation/transport/client.py | 4 ++-- synapse/federation/transport/server/_base.py | 5 ++++- synapse/handlers/event_auth.py | 5 +++-- synapse/handlers/federation.py | 3 ++- synapse/handlers/federation_event.py | 3 ++- synapse/handlers/profile.py | 4 ++-- synapse/handlers/sso.py | 3 ++- synapse/handlers/typing.py | 3 ++- synapse/rest/admin/media.py | 4 ++-- synapse/rest/client/room.py | 4 ++-- synapse/rest/media/download_resource.py | 4 ++-- synapse/rest/media/thumbnail_resource.py | 4 ++-- synapse/server.py | 4 ++++ synapse/storage/databases/main/room.py | 2 +- tests/unittest.py | 16 ++++++++++++++-- 23 files changed, 64 insertions(+), 36 deletions(-) create mode 100644 changelog.d/15542.misc diff --git a/changelog.d/15542.misc b/changelog.d/15542.misc new file mode 100644 index 0000000000..32e3d678a1 --- /dev/null +++ b/changelog.d/15542.misc @@ -0,0 +1 @@ +Factor out an `is_mine_server_name` method. diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index 22348d2d86..fcf5b842c6 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -39,7 +39,7 @@ class AuthBlocking: self._mau_limits_reserved_threepids = ( hs.config.server.mau_limits_reserved_threepids ) - self._server_name = hs.hostname + self._is_mine_server_name = hs.is_mine_server_name self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips async def check_auth_blocking( @@ -77,7 +77,7 @@ class AuthBlocking: if requester: if requester.authenticated_entity.startswith("@"): user_id = requester.authenticated_entity - elif requester.authenticated_entity == self._server_name: + elif self._is_mine_server_name(requester.authenticated_entity): # We never block the server from doing actions on behalf of # users. return diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index afdf6863d6..260aab3241 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -173,7 +173,7 @@ class Keyring: process_batch_callback=self._inner_fetch_key_requests, ) - self._hostname = hs.hostname + self._is_mine_server_name = hs.is_mine_server_name # build a FetchKeyResult for each of our own keys, to shortcircuit the # fetcher. @@ -277,7 +277,7 @@ class Keyring: # If we are the originating server, short-circuit the key-fetch for any keys # we already have - if verify_request.server_name == self._hostname: + if self._is_mine_server_name(verify_request.server_name): for key_id in verify_request.key_ids: if key_id in self._local_verify_keys: found_keys[key_id] = self._local_verify_keys[key_id] diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 3df975958d..b77022b406 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -49,7 +49,7 @@ class FederationBase: def __init__(self, hs: "HomeServer"): self.hs = hs - self.server_name = hs.hostname + self._is_mine_server_name = hs.is_mine_server_name self.keyring = hs.get_keyring() self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker self.store = hs.get_datastores().main diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 0b2d1a78f7..076b9287c6 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -854,7 +854,7 @@ class FederationClient(FederationBase): for destination in destinations: # We don't want to ask our own server for information we don't have - if destination == self.server_name: + if self._is_mine_server_name(destination): continue try: @@ -1536,7 +1536,7 @@ class FederationClient(FederationBase): self, destinations: Iterable[str], room_id: str, event_dict: JsonDict ) -> None: for destination in destinations: - if destination == self.server_name: + if self._is_mine_server_name(destination): continue try: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index ca43c7bfc0..c590d8f96f 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -129,6 +129,7 @@ class FederationServer(FederationBase): def __init__(self, hs: "HomeServer"): super().__init__(hs) + self.server_name = hs.hostname self.handler = hs.get_federation_handler() self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker self._federation_event_handler = hs.get_federation_event_handler() @@ -942,7 +943,7 @@ class FederationServer(FederationBase): authorising_server = get_domain_from_id( event.content[EventContentFields.AUTHORISING_USER] ) - if authorising_server != self.server_name: + if not self._is_mine_server_name(authorising_server): raise SynapseError( 400, f"Cannot authorise request from resident server: {authorising_server}", diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 0b7c81677e..fb448f2155 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -68,6 +68,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): self.clock = hs.get_clock() self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id + self.is_mine_server_name = hs.is_mine_server_name # We may have multiple federation sender instances, so we need to track # their positions separately. @@ -198,7 +199,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): key: Optional[Hashable] = None, ) -> None: """As per FederationSender""" - if destination == self.server_name: + if self.is_mine_server_name(destination): logger.info("Not sending EDU to ourselves") return diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index edc4b1768c..f3bdc5a4d2 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -362,6 +362,7 @@ class FederationSender(AbstractFederationSender): self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id + self.is_mine_server_name = hs.is_mine_server_name self._presence_router: Optional["PresenceRouter"] = None self._transaction_manager = TransactionManager(hs) @@ -766,7 +767,7 @@ class FederationSender(AbstractFederationSender): domains = [ d for d in domains_set - if d != self.server_name + if not self.is_mine_server_name(d) and self._federation_shard_config.should_handle(self._instance_name, d) ] if not domains: @@ -832,7 +833,7 @@ class FederationSender(AbstractFederationSender): assert self.is_mine_id(state.user_id) for destination in destinations: - if destination == self.server_name: + if self.is_mine_server_name(destination): continue if not self._federation_shard_config.should_handle( self._instance_name, destination @@ -860,7 +861,7 @@ class FederationSender(AbstractFederationSender): content: content of EDU key: clobbering key for this edu """ - if destination == self.server_name: + if self.is_mine_server_name(destination): logger.info("Not sending EDU to ourselves") return @@ -897,7 +898,7 @@ class FederationSender(AbstractFederationSender): queue.send_edu(edu) def send_device_messages(self, destination: str, immediate: bool = True) -> None: - if destination == self.server_name: + if self.is_mine_server_name(destination): logger.warning("Not sending device update to ourselves") return @@ -919,7 +920,7 @@ class FederationSender(AbstractFederationSender): might have come back. """ - if destination == self.server_name: + if self.is_mine_server_name(destination): logger.warning("Not waking up ourselves") return diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index bc70b94f68..d2fa9976da 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -58,9 +58,9 @@ class TransportLayerClient: """Sends federation HTTP requests to other servers""" def __init__(self, hs: "HomeServer"): - self.server_name = hs.hostname self.client = hs.get_federation_http_client() self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled + self._is_mine_server_name = hs.is_mine_server_name async def get_room_state_ids( self, destination: str, room_id: str, event_id: str @@ -235,7 +235,7 @@ class TransportLayerClient: transaction.transaction_id, ) - if transaction.destination == self.server_name: + if self._is_mine_server_name(transaction.destination): raise RuntimeError("Transport layer cannot send to itself!") # FIXME: This is only used by the tests. The actual json sent is diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index cdaf0d5de7..b6e9c58760 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -57,6 +57,7 @@ class Authenticator: self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname + self._is_mine_server_name = hs.is_mine_server_name self.store = hs.get_datastores().main self.federation_domain_whitelist = ( hs.config.federation.federation_domain_whitelist @@ -100,7 +101,9 @@ class Authenticator: json_request["signatures"].setdefault(origin, {})[key] = sig # if the origin_server sent a destination along it needs to match our own server_name - if destination is not None and destination != self.server_name: + if destination is not None and not self._is_mine_server_name( + destination + ): raise AuthenticationError( HTTPStatus.UNAUTHORIZED, "Destination mismatch in auth header", diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 0db0bd7304..3e37c0cbe2 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -29,7 +29,7 @@ from synapse.event_auth import ( ) from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.types import StateMap, StrCollection, get_domain_from_id +from synapse.types import StateMap, StrCollection if TYPE_CHECKING: from synapse.server import HomeServer @@ -47,6 +47,7 @@ class EventAuthHandler: self._store = hs.get_datastores().main self._state_storage_controller = hs.get_storage_controllers().state self._server_name = hs.hostname + self._is_mine_id = hs.is_mine_id async def check_auth_rules_from_context( self, @@ -247,7 +248,7 @@ class EventAuthHandler: if not await self.is_user_in_rooms(allowed_rooms, user_id): # If this is a remote request, the user might be in an allowed room # that we do not know about. - if get_domain_from_id(user_id) != self._server_name: + if not self._is_mine_id(user_id): for room_id in allowed_rooms: if not await self._store.is_host_joined(room_id, self._server_name): raise SynapseError( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4ad808a5b4..19dec4812f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -141,6 +141,7 @@ class FederationHandler: self.server_name = hs.hostname self.keyring = hs.get_keyring() self.is_mine_id = hs.is_mine_id + self.is_mine_server_name = hs.is_mine_server_name self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker self.event_creation_handler = hs.get_event_creation_handler() self.event_builder_factory = hs.get_event_builder_factory() @@ -453,7 +454,7 @@ class FederationHandler: for dom in domains: # We don't want to ask our own server for information we don't have - if dom == self.server_name: + if self.is_mine_server_name(dom): continue try: diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index fc15024166..06343d40e4 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -163,6 +163,7 @@ class FederationEventHandler: self._notifier = hs.get_notifier() self._is_mine_id = hs.is_mine_id + self._is_mine_server_name = hs.is_mine_server_name self._server_name = hs.hostname self._instance_name = hs.get_instance_name() @@ -688,7 +689,7 @@ class FederationEventHandler: server from invalid events (there is probably no point in trying to re-fetch invalid events from every other HS in the room.) """ - if dest == self._server_name: + if self._is_mine_server_name(dest): raise SynapseError(400, "Can't backfill from self.") events = await self._federation_client.backfill( diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 983b9b66fb..48f9858931 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -59,7 +59,7 @@ class ProfileHandler: self.max_avatar_size = hs.config.server.max_avatar_size self.allowed_avatar_mimetypes = hs.config.server.allowed_avatar_mimetypes - self.server_name = hs.config.server.server_name + self._is_mine_server_name = hs.is_mine_server_name self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules @@ -309,7 +309,7 @@ class ProfileHandler: else: server_name = host - if server_name == self.server_name: + if self._is_mine_server_name(server_name): media_info = await self.store.get_local_media(media_id) else: media_info = await self.store.get_cached_remote_media(server_name, media_id) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index c28325323c..92c3742625 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -194,6 +194,7 @@ class SsoHandler: self._clock = hs.get_clock() self._store = hs.get_datastores().main self._server_name = hs.hostname + self._is_mine_server_name = hs.is_mine_server_name self._registration_handler = hs.get_registration_handler() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() @@ -802,7 +803,7 @@ class SsoHandler: if profile["avatar_url"] is not None: server_name = profile["avatar_url"].split("/")[-2] media_id = profile["avatar_url"].split("/")[-1] - if server_name == self._server_name: + if self._is_mine_server_name(server_name): media = await self._media_repo.store.get_local_media(media_id) if media is not None and upload_name == media["upload_name"]: logger.info("skipping saving the user avatar") diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 39ae44ea95..7aeae5319c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -68,6 +68,7 @@ class FollowerTypingHandler: self.server_name = hs.config.server.server_name self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id + self.is_mine_server_name = hs.is_mine_server_name self.federation = None if hs.should_send_federation(): @@ -153,7 +154,7 @@ class FollowerTypingHandler: member.room_id ) for domain in hosts: - if domain != self.server_name: + if not self.is_mine_server_name(domain): logger.debug("sending typing update to %s", domain) self.federation.build_and_send_edu( destination=domain, diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index c134ccfb3d..b7637dff0b 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -258,7 +258,7 @@ class DeleteMediaByID(RestServlet): def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.auth = hs.get_auth() - self.server_name = hs.hostname + self._is_mine_server_name = hs.is_mine_server_name self.media_repository = hs.get_media_repository() async def on_DELETE( @@ -266,7 +266,7 @@ class DeleteMediaByID(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if self.server_name != server_name: + if not self._is_mine_server_name(server_name): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") if await self.store.get_local_media(media_id) is None: diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 7699cc8d1b..951bd033f5 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -501,7 +501,7 @@ class PublicRoomListRestServlet(RestServlet): limit = None handler = self.hs.get_room_list_handler() - if server and server != self.hs.config.server.server_name: + if server and not self.hs.is_mine_server_name(server): # Ensure the server is valid. try: parse_and_validate_server_name(server) @@ -551,7 +551,7 @@ class PublicRoomListRestServlet(RestServlet): limit = None handler = self.hs.get_room_list_handler() - if server and server != self.hs.config.server.server_name: + if server and not self.hs.is_mine_server_name(server): # Ensure the server is valid. try: parse_and_validate_server_name(server) diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py index 8f270cf4cc..3c618ef60a 100644 --- a/synapse/rest/media/download_resource.py +++ b/synapse/rest/media/download_resource.py @@ -37,7 +37,7 @@ class DownloadResource(DirectServeJsonResource): def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo - self.server_name = hs.hostname + self._is_mine_server_name = hs.is_mine_server_name async def _async_render_GET(self, request: SynapseRequest) -> None: set_cors_headers(request) @@ -59,7 +59,7 @@ class DownloadResource(DirectServeJsonResource): b"no-referrer", ) server_name, media_id, name = parse_media_id(request) - if server_name == self.server_name: + if self._is_mine_server_name(server_name): await self.media_repo.get_local_media(request, media_id, name) else: allow_remote = parse_boolean(request, "allow_remote", default=True) diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py index 4ee2a0dbda..a6396fb05a 100644 --- a/synapse/rest/media/thumbnail_resource.py +++ b/synapse/rest/media/thumbnail_resource.py @@ -59,7 +59,7 @@ class ThumbnailResource(DirectServeJsonResource): self.media_repo = media_repo self.media_storage = media_storage self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails - self.server_name = hs.hostname + self._is_mine_server_name = hs.is_mine_server_name async def _async_render_GET(self, request: SynapseRequest) -> None: set_cors_headers(request) @@ -71,7 +71,7 @@ class ThumbnailResource(DirectServeJsonResource): # TODO Parse the Accept header to get an prioritised list of thumbnail types. m_type = "image/png" - if server_name == self.server_name: + if self._is_mine_server_name(server_name): if self.dynamic_thumbnails: await self._select_or_generate_local_thumbnail( request, media_id, width, height, method, m_type diff --git a/synapse/server.py b/synapse/server.py index c557c60482..fd29c28173 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -377,6 +377,10 @@ class HomeServer(metaclass=abc.ABCMeta): return False return localpart_hostname[1] == self.hostname + def is_mine_server_name(self, server_name: str) -> bool: + """Determines whether a server name refers to this homeserver.""" + return server_name == self.hostname + @cache_in_self def get_clock(self) -> Clock: return Clock(self._reactor) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index dd7dbb6901..ca8be8c80d 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -996,7 +996,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): If it is `None` media will be removed from quarantine """ logger.info("Quarantining media: %s/%s", server_name, media_id) - is_local = server_name == self.config.server.server_name + is_local = self.hs.is_mine_server_name(server_name) def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int: local_mxcs = [media_id] if is_local else [] diff --git a/tests/unittest.py b/tests/unittest.py index ee2f78ab01..b6fdf69635 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -566,7 +566,9 @@ class HomeserverTestCase(TestCase): client_ip, ) - def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer: + def setup_test_homeserver( + self, name: Optional[str] = None, **kwargs: Any + ) -> HomeServer: """ Set up the test homeserver, meant to be called by the overridable make_homeserver. It automatically passes through the test class's @@ -585,15 +587,25 @@ class HomeserverTestCase(TestCase): else: config = kwargs["config"] + # The server name can be specified using either the `name` argument or a config + # override. The `name` argument takes precedence over any config overrides. + if name is not None: + config["server_name"] = name + # Parse the config from a config dict into a HomeServerConfig config_obj = make_homeserver_config_obj(config) kwargs["config"] = config_obj + # The server name in the config is now `name`, if provided, or the `server_name` + # from a config override, or the default of "test". Whichever it is, we + # construct a homeserver with a matching name. + kwargs["name"] = config_obj.server.server_name + async def run_bg_updates() -> None: with LoggingContext("run_bg_updates"): self.get_success(stor.db_pool.updates.run_background_updates(False)) - hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) + hs = setup_test_homeserver(self.addCleanup, **kwargs) stor = hs.get_datastores().main # Run the database background updates, when running against "master".