Factor out an `is_mine_server_name` method (#15542)
Add an `is_mine_server_name` method, similar to `is_mine_id`. Ideally we would use this consistently, instead of sometimes comparing against `hs.hostname` and other times reaching into `hs.config.server.server_name`. Also fix a bug in the tests where `hs.hostname` would sometimes differ from `hs.config.server.server_name`. Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
parent
83e7fa5eee
commit
e46d5f3586
|
@ -0,0 +1 @@
|
|||
Factor out an `is_mine_server_name` method.
|
|
@ -39,7 +39,7 @@ class AuthBlocking:
|
|||
self._mau_limits_reserved_threepids = (
|
||||
hs.config.server.mau_limits_reserved_threepids
|
||||
)
|
||||
self._server_name = hs.hostname
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
|
||||
|
||||
async def check_auth_blocking(
|
||||
|
@ -77,7 +77,7 @@ class AuthBlocking:
|
|||
if requester:
|
||||
if requester.authenticated_entity.startswith("@"):
|
||||
user_id = requester.authenticated_entity
|
||||
elif requester.authenticated_entity == self._server_name:
|
||||
elif self._is_mine_server_name(requester.authenticated_entity):
|
||||
# We never block the server from doing actions on behalf of
|
||||
# users.
|
||||
return
|
||||
|
|
|
@ -173,7 +173,7 @@ class Keyring:
|
|||
process_batch_callback=self._inner_fetch_key_requests,
|
||||
)
|
||||
|
||||
self._hostname = hs.hostname
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
|
||||
# build a FetchKeyResult for each of our own keys, to shortcircuit the
|
||||
# fetcher.
|
||||
|
@ -277,7 +277,7 @@ class Keyring:
|
|||
|
||||
# If we are the originating server, short-circuit the key-fetch for any keys
|
||||
# we already have
|
||||
if verify_request.server_name == self._hostname:
|
||||
if self._is_mine_server_name(verify_request.server_name):
|
||||
for key_id in verify_request.key_ids:
|
||||
if key_id in self._local_verify_keys:
|
||||
found_keys[key_id] = self._local_verify_keys[key_id]
|
||||
|
|
|
@ -49,7 +49,7 @@ class FederationBase:
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
|
||||
self.server_name = hs.hostname
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
self.keyring = hs.get_keyring()
|
||||
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
|
||||
self.store = hs.get_datastores().main
|
||||
|
|
|
@ -854,7 +854,7 @@ class FederationClient(FederationBase):
|
|||
|
||||
for destination in destinations:
|
||||
# We don't want to ask our own server for information we don't have
|
||||
if destination == self.server_name:
|
||||
if self._is_mine_server_name(destination):
|
||||
continue
|
||||
|
||||
try:
|
||||
|
@ -1536,7 +1536,7 @@ class FederationClient(FederationBase):
|
|||
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
|
||||
) -> None:
|
||||
for destination in destinations:
|
||||
if destination == self.server_name:
|
||||
if self._is_mine_server_name(destination):
|
||||
continue
|
||||
|
||||
try:
|
||||
|
|
|
@ -129,6 +129,7 @@ class FederationServer(FederationBase):
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self.server_name = hs.hostname
|
||||
self.handler = hs.get_federation_handler()
|
||||
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
|
||||
self._federation_event_handler = hs.get_federation_event_handler()
|
||||
|
@ -942,7 +943,7 @@ class FederationServer(FederationBase):
|
|||
authorising_server = get_domain_from_id(
|
||||
event.content[EventContentFields.AUTHORISING_USER]
|
||||
)
|
||||
if authorising_server != self.server_name:
|
||||
if not self._is_mine_server_name(authorising_server):
|
||||
raise SynapseError(
|
||||
400,
|
||||
f"Cannot authorise request from resident server: {authorising_server}",
|
||||
|
|
|
@ -68,6 +68,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
|
|||
self.clock = hs.get_clock()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.is_mine_server_name = hs.is_mine_server_name
|
||||
|
||||
# We may have multiple federation sender instances, so we need to track
|
||||
# their positions separately.
|
||||
|
@ -198,7 +199,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
|
|||
key: Optional[Hashable] = None,
|
||||
) -> None:
|
||||
"""As per FederationSender"""
|
||||
if destination == self.server_name:
|
||||
if self.is_mine_server_name(destination):
|
||||
logger.info("Not sending EDU to ourselves")
|
||||
return
|
||||
|
||||
|
|
|
@ -362,6 +362,7 @@ class FederationSender(AbstractFederationSender):
|
|||
|
||||
self.clock = hs.get_clock()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.is_mine_server_name = hs.is_mine_server_name
|
||||
|
||||
self._presence_router: Optional["PresenceRouter"] = None
|
||||
self._transaction_manager = TransactionManager(hs)
|
||||
|
@ -766,7 +767,7 @@ class FederationSender(AbstractFederationSender):
|
|||
domains = [
|
||||
d
|
||||
for d in domains_set
|
||||
if d != self.server_name
|
||||
if not self.is_mine_server_name(d)
|
||||
and self._federation_shard_config.should_handle(self._instance_name, d)
|
||||
]
|
||||
if not domains:
|
||||
|
@ -832,7 +833,7 @@ class FederationSender(AbstractFederationSender):
|
|||
assert self.is_mine_id(state.user_id)
|
||||
|
||||
for destination in destinations:
|
||||
if destination == self.server_name:
|
||||
if self.is_mine_server_name(destination):
|
||||
continue
|
||||
if not self._federation_shard_config.should_handle(
|
||||
self._instance_name, destination
|
||||
|
@ -860,7 +861,7 @@ class FederationSender(AbstractFederationSender):
|
|||
content: content of EDU
|
||||
key: clobbering key for this edu
|
||||
"""
|
||||
if destination == self.server_name:
|
||||
if self.is_mine_server_name(destination):
|
||||
logger.info("Not sending EDU to ourselves")
|
||||
return
|
||||
|
||||
|
@ -897,7 +898,7 @@ class FederationSender(AbstractFederationSender):
|
|||
queue.send_edu(edu)
|
||||
|
||||
def send_device_messages(self, destination: str, immediate: bool = True) -> None:
|
||||
if destination == self.server_name:
|
||||
if self.is_mine_server_name(destination):
|
||||
logger.warning("Not sending device update to ourselves")
|
||||
return
|
||||
|
||||
|
@ -919,7 +920,7 @@ class FederationSender(AbstractFederationSender):
|
|||
might have come back.
|
||||
"""
|
||||
|
||||
if destination == self.server_name:
|
||||
if self.is_mine_server_name(destination):
|
||||
logger.warning("Not waking up ourselves")
|
||||
return
|
||||
|
||||
|
|
|
@ -58,9 +58,9 @@ class TransportLayerClient:
|
|||
"""Sends federation HTTP requests to other servers"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.server_name = hs.hostname
|
||||
self.client = hs.get_federation_http_client()
|
||||
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
|
||||
async def get_room_state_ids(
|
||||
self, destination: str, room_id: str, event_id: str
|
||||
|
@ -235,7 +235,7 @@ class TransportLayerClient:
|
|||
transaction.transaction_id,
|
||||
)
|
||||
|
||||
if transaction.destination == self.server_name:
|
||||
if self._is_mine_server_name(transaction.destination):
|
||||
raise RuntimeError("Transport layer cannot send to itself!")
|
||||
|
||||
# FIXME: This is only used by the tests. The actual json sent is
|
||||
|
|
|
@ -57,6 +57,7 @@ class Authenticator:
|
|||
self._clock = hs.get_clock()
|
||||
self.keyring = hs.get_keyring()
|
||||
self.server_name = hs.hostname
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
self.store = hs.get_datastores().main
|
||||
self.federation_domain_whitelist = (
|
||||
hs.config.federation.federation_domain_whitelist
|
||||
|
@ -100,7 +101,9 @@ class Authenticator:
|
|||
json_request["signatures"].setdefault(origin, {})[key] = sig
|
||||
|
||||
# if the origin_server sent a destination along it needs to match our own server_name
|
||||
if destination is not None and destination != self.server_name:
|
||||
if destination is not None and not self._is_mine_server_name(
|
||||
destination
|
||||
):
|
||||
raise AuthenticationError(
|
||||
HTTPStatus.UNAUTHORIZED,
|
||||
"Destination mismatch in auth header",
|
||||
|
|
|
@ -29,7 +29,7 @@ from synapse.event_auth import (
|
|||
)
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.builder import EventBuilder
|
||||
from synapse.types import StateMap, StrCollection, get_domain_from_id
|
||||
from synapse.types import StateMap, StrCollection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -47,6 +47,7 @@ class EventAuthHandler:
|
|||
self._store = hs.get_datastores().main
|
||||
self._state_storage_controller = hs.get_storage_controllers().state
|
||||
self._server_name = hs.hostname
|
||||
self._is_mine_id = hs.is_mine_id
|
||||
|
||||
async def check_auth_rules_from_context(
|
||||
self,
|
||||
|
@ -247,7 +248,7 @@ class EventAuthHandler:
|
|||
if not await self.is_user_in_rooms(allowed_rooms, user_id):
|
||||
# If this is a remote request, the user might be in an allowed room
|
||||
# that we do not know about.
|
||||
if get_domain_from_id(user_id) != self._server_name:
|
||||
if not self._is_mine_id(user_id):
|
||||
for room_id in allowed_rooms:
|
||||
if not await self._store.is_host_joined(room_id, self._server_name):
|
||||
raise SynapseError(
|
||||
|
|
|
@ -141,6 +141,7 @@ class FederationHandler:
|
|||
self.server_name = hs.hostname
|
||||
self.keyring = hs.get_keyring()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.is_mine_server_name = hs.is_mine_server_name
|
||||
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
|
@ -453,7 +454,7 @@ class FederationHandler:
|
|||
|
||||
for dom in domains:
|
||||
# We don't want to ask our own server for information we don't have
|
||||
if dom == self.server_name:
|
||||
if self.is_mine_server_name(dom):
|
||||
continue
|
||||
|
||||
try:
|
||||
|
|
|
@ -163,6 +163,7 @@ class FederationEventHandler:
|
|||
self._notifier = hs.get_notifier()
|
||||
|
||||
self._is_mine_id = hs.is_mine_id
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
self._server_name = hs.hostname
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
|
@ -688,7 +689,7 @@ class FederationEventHandler:
|
|||
server from invalid events (there is probably no point in trying to
|
||||
re-fetch invalid events from every other HS in the room.)
|
||||
"""
|
||||
if dest == self._server_name:
|
||||
if self._is_mine_server_name(dest):
|
||||
raise SynapseError(400, "Can't backfill from self.")
|
||||
|
||||
events = await self._federation_client.backfill(
|
||||
|
|
|
@ -59,7 +59,7 @@ class ProfileHandler:
|
|||
self.max_avatar_size = hs.config.server.max_avatar_size
|
||||
self.allowed_avatar_mimetypes = hs.config.server.allowed_avatar_mimetypes
|
||||
|
||||
self.server_name = hs.config.server.server_name
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
|
||||
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
|
||||
|
||||
|
@ -309,7 +309,7 @@ class ProfileHandler:
|
|||
else:
|
||||
server_name = host
|
||||
|
||||
if server_name == self.server_name:
|
||||
if self._is_mine_server_name(server_name):
|
||||
media_info = await self.store.get_local_media(media_id)
|
||||
else:
|
||||
media_info = await self.store.get_cached_remote_media(server_name, media_id)
|
||||
|
|
|
@ -194,6 +194,7 @@ class SsoHandler:
|
|||
self._clock = hs.get_clock()
|
||||
self._store = hs.get_datastores().main
|
||||
self._server_name = hs.hostname
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
|
@ -802,7 +803,7 @@ class SsoHandler:
|
|||
if profile["avatar_url"] is not None:
|
||||
server_name = profile["avatar_url"].split("/")[-2]
|
||||
media_id = profile["avatar_url"].split("/")[-1]
|
||||
if server_name == self._server_name:
|
||||
if self._is_mine_server_name(server_name):
|
||||
media = await self._media_repo.store.get_local_media(media_id)
|
||||
if media is not None and upload_name == media["upload_name"]:
|
||||
logger.info("skipping saving the user avatar")
|
||||
|
|
|
@ -68,6 +68,7 @@ class FollowerTypingHandler:
|
|||
self.server_name = hs.config.server.server_name
|
||||
self.clock = hs.get_clock()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.is_mine_server_name = hs.is_mine_server_name
|
||||
|
||||
self.federation = None
|
||||
if hs.should_send_federation():
|
||||
|
@ -153,7 +154,7 @@ class FollowerTypingHandler:
|
|||
member.room_id
|
||||
)
|
||||
for domain in hosts:
|
||||
if domain != self.server_name:
|
||||
if not self.is_mine_server_name(domain):
|
||||
logger.debug("sending typing update to %s", domain)
|
||||
self.federation.build_and_send_edu(
|
||||
destination=domain,
|
||||
|
|
|
@ -258,7 +258,7 @@ class DeleteMediaByID(RestServlet):
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self.auth = hs.get_auth()
|
||||
self.server_name = hs.hostname
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
self.media_repository = hs.get_media_repository()
|
||||
|
||||
async def on_DELETE(
|
||||
|
@ -266,7 +266,7 @@ class DeleteMediaByID(RestServlet):
|
|||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if self.server_name != server_name:
|
||||
if not self._is_mine_server_name(server_name):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
|
||||
|
||||
if await self.store.get_local_media(media_id) is None:
|
||||
|
|
|
@ -501,7 +501,7 @@ class PublicRoomListRestServlet(RestServlet):
|
|||
limit = None
|
||||
|
||||
handler = self.hs.get_room_list_handler()
|
||||
if server and server != self.hs.config.server.server_name:
|
||||
if server and not self.hs.is_mine_server_name(server):
|
||||
# Ensure the server is valid.
|
||||
try:
|
||||
parse_and_validate_server_name(server)
|
||||
|
@ -551,7 +551,7 @@ class PublicRoomListRestServlet(RestServlet):
|
|||
limit = None
|
||||
|
||||
handler = self.hs.get_room_list_handler()
|
||||
if server and server != self.hs.config.server.server_name:
|
||||
if server and not self.hs.is_mine_server_name(server):
|
||||
# Ensure the server is valid.
|
||||
try:
|
||||
parse_and_validate_server_name(server)
|
||||
|
|
|
@ -37,7 +37,7 @@ class DownloadResource(DirectServeJsonResource):
|
|||
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
|
||||
super().__init__()
|
||||
self.media_repo = media_repo
|
||||
self.server_name = hs.hostname
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||
set_cors_headers(request)
|
||||
|
@ -59,7 +59,7 @@ class DownloadResource(DirectServeJsonResource):
|
|||
b"no-referrer",
|
||||
)
|
||||
server_name, media_id, name = parse_media_id(request)
|
||||
if server_name == self.server_name:
|
||||
if self._is_mine_server_name(server_name):
|
||||
await self.media_repo.get_local_media(request, media_id, name)
|
||||
else:
|
||||
allow_remote = parse_boolean(request, "allow_remote", default=True)
|
||||
|
|
|
@ -59,7 +59,7 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
self.media_repo = media_repo
|
||||
self.media_storage = media_storage
|
||||
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
|
||||
self.server_name = hs.hostname
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||
set_cors_headers(request)
|
||||
|
@ -71,7 +71,7 @@ class ThumbnailResource(DirectServeJsonResource):
|
|||
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
|
||||
m_type = "image/png"
|
||||
|
||||
if server_name == self.server_name:
|
||||
if self._is_mine_server_name(server_name):
|
||||
if self.dynamic_thumbnails:
|
||||
await self._select_or_generate_local_thumbnail(
|
||||
request, media_id, width, height, method, m_type
|
||||
|
|
|
@ -377,6 +377,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
return False
|
||||
return localpart_hostname[1] == self.hostname
|
||||
|
||||
def is_mine_server_name(self, server_name: str) -> bool:
|
||||
"""Determines whether a server name refers to this homeserver."""
|
||||
return server_name == self.hostname
|
||||
|
||||
@cache_in_self
|
||||
def get_clock(self) -> Clock:
|
||||
return Clock(self._reactor)
|
||||
|
|
|
@ -996,7 +996,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
If it is `None` media will be removed from quarantine
|
||||
"""
|
||||
logger.info("Quarantining media: %s/%s", server_name, media_id)
|
||||
is_local = server_name == self.config.server.server_name
|
||||
is_local = self.hs.is_mine_server_name(server_name)
|
||||
|
||||
def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
|
||||
local_mxcs = [media_id] if is_local else []
|
||||
|
|
|
@ -566,7 +566,9 @@ class HomeserverTestCase(TestCase):
|
|||
client_ip,
|
||||
)
|
||||
|
||||
def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer:
|
||||
def setup_test_homeserver(
|
||||
self, name: Optional[str] = None, **kwargs: Any
|
||||
) -> HomeServer:
|
||||
"""
|
||||
Set up the test homeserver, meant to be called by the overridable
|
||||
make_homeserver. It automatically passes through the test class's
|
||||
|
@ -585,15 +587,25 @@ class HomeserverTestCase(TestCase):
|
|||
else:
|
||||
config = kwargs["config"]
|
||||
|
||||
# The server name can be specified using either the `name` argument or a config
|
||||
# override. The `name` argument takes precedence over any config overrides.
|
||||
if name is not None:
|
||||
config["server_name"] = name
|
||||
|
||||
# Parse the config from a config dict into a HomeServerConfig
|
||||
config_obj = make_homeserver_config_obj(config)
|
||||
kwargs["config"] = config_obj
|
||||
|
||||
# The server name in the config is now `name`, if provided, or the `server_name`
|
||||
# from a config override, or the default of "test". Whichever it is, we
|
||||
# construct a homeserver with a matching name.
|
||||
kwargs["name"] = config_obj.server.server_name
|
||||
|
||||
async def run_bg_updates() -> None:
|
||||
with LoggingContext("run_bg_updates"):
|
||||
self.get_success(stor.db_pool.updates.run_background_updates(False))
|
||||
|
||||
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
|
||||
hs = setup_test_homeserver(self.addCleanup, **kwargs)
|
||||
stor = hs.get_datastores().main
|
||||
|
||||
# Run the database background updates, when running against "master".
|
||||
|
|
Loading…
Reference in New Issue