Factor out an `is_mine_server_name` method (#15542)
Add an `is_mine_server_name` method, similar to `is_mine_id`. Ideally we would use this consistently, instead of sometimes comparing against `hs.hostname` and other times reaching into `hs.config.server.server_name`. Also fix a bug in the tests where `hs.hostname` would sometimes differ from `hs.config.server.server_name`. Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
parent
83e7fa5eee
commit
e46d5f3586
|
@ -0,0 +1 @@
|
||||||
|
Factor out an `is_mine_server_name` method.
|
|
@ -39,7 +39,7 @@ class AuthBlocking:
|
||||||
self._mau_limits_reserved_threepids = (
|
self._mau_limits_reserved_threepids = (
|
||||||
hs.config.server.mau_limits_reserved_threepids
|
hs.config.server.mau_limits_reserved_threepids
|
||||||
)
|
)
|
||||||
self._server_name = hs.hostname
|
self._is_mine_server_name = hs.is_mine_server_name
|
||||||
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
|
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
|
||||||
|
|
||||||
async def check_auth_blocking(
|
async def check_auth_blocking(
|
||||||
|
@ -77,7 +77,7 @@ class AuthBlocking:
|
||||||
if requester:
|
if requester:
|
||||||
if requester.authenticated_entity.startswith("@"):
|
if requester.authenticated_entity.startswith("@"):
|
||||||
user_id = requester.authenticated_entity
|
user_id = requester.authenticated_entity
|
||||||
elif requester.authenticated_entity == self._server_name:
|
elif self._is_mine_server_name(requester.authenticated_entity):
|
||||||
# We never block the server from doing actions on behalf of
|
# We never block the server from doing actions on behalf of
|
||||||
# users.
|
# users.
|
||||||
return
|
return
|
||||||
|
|
|
@ -173,7 +173,7 @@ class Keyring:
|
||||||
process_batch_callback=self._inner_fetch_key_requests,
|
process_batch_callback=self._inner_fetch_key_requests,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._hostname = hs.hostname
|
self._is_mine_server_name = hs.is_mine_server_name
|
||||||
|
|
||||||
# build a FetchKeyResult for each of our own keys, to shortcircuit the
|
# build a FetchKeyResult for each of our own keys, to shortcircuit the
|
||||||
# fetcher.
|
# fetcher.
|
||||||
|
@ -277,7 +277,7 @@ class Keyring:
|
||||||
|
|
||||||
# If we are the originating server, short-circuit the key-fetch for any keys
|
# If we are the originating server, short-circuit the key-fetch for any keys
|
||||||
# we already have
|
# we already have
|
||||||
if verify_request.server_name == self._hostname:
|
if self._is_mine_server_name(verify_request.server_name):
|
||||||
for key_id in verify_request.key_ids:
|
for key_id in verify_request.key_ids:
|
||||||
if key_id in self._local_verify_keys:
|
if key_id in self._local_verify_keys:
|
||||||
found_keys[key_id] = self._local_verify_keys[key_id]
|
found_keys[key_id] = self._local_verify_keys[key_id]
|
||||||
|
|
|
@ -49,7 +49,7 @@ class FederationBase:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
self.server_name = hs.hostname
|
self._is_mine_server_name = hs.is_mine_server_name
|
||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
|
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
|
|
@ -854,7 +854,7 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
# We don't want to ask our own server for information we don't have
|
# We don't want to ask our own server for information we don't have
|
||||||
if destination == self.server_name:
|
if self._is_mine_server_name(destination):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -1536,7 +1536,7 @@ class FederationClient(FederationBase):
|
||||||
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
|
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
|
||||||
) -> None:
|
) -> None:
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
if destination == self.server_name:
|
if self._is_mine_server_name(destination):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -129,6 +129,7 @@ class FederationServer(FederationBase):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
|
self.server_name = hs.hostname
|
||||||
self.handler = hs.get_federation_handler()
|
self.handler = hs.get_federation_handler()
|
||||||
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
|
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
|
||||||
self._federation_event_handler = hs.get_federation_event_handler()
|
self._federation_event_handler = hs.get_federation_event_handler()
|
||||||
|
@ -942,7 +943,7 @@ class FederationServer(FederationBase):
|
||||||
authorising_server = get_domain_from_id(
|
authorising_server = get_domain_from_id(
|
||||||
event.content[EventContentFields.AUTHORISING_USER]
|
event.content[EventContentFields.AUTHORISING_USER]
|
||||||
)
|
)
|
||||||
if authorising_server != self.server_name:
|
if not self._is_mine_server_name(authorising_server):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
f"Cannot authorise request from resident server: {authorising_server}",
|
f"Cannot authorise request from resident server: {authorising_server}",
|
||||||
|
|
|
@ -68,6 +68,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
self.is_mine_server_name = hs.is_mine_server_name
|
||||||
|
|
||||||
# We may have multiple federation sender instances, so we need to track
|
# We may have multiple federation sender instances, so we need to track
|
||||||
# their positions separately.
|
# their positions separately.
|
||||||
|
@ -198,7 +199,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
|
||||||
key: Optional[Hashable] = None,
|
key: Optional[Hashable] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""As per FederationSender"""
|
"""As per FederationSender"""
|
||||||
if destination == self.server_name:
|
if self.is_mine_server_name(destination):
|
||||||
logger.info("Not sending EDU to ourselves")
|
logger.info("Not sending EDU to ourselves")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -362,6 +362,7 @@ class FederationSender(AbstractFederationSender):
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
self.is_mine_server_name = hs.is_mine_server_name
|
||||||
|
|
||||||
self._presence_router: Optional["PresenceRouter"] = None
|
self._presence_router: Optional["PresenceRouter"] = None
|
||||||
self._transaction_manager = TransactionManager(hs)
|
self._transaction_manager = TransactionManager(hs)
|
||||||
|
@ -766,7 +767,7 @@ class FederationSender(AbstractFederationSender):
|
||||||
domains = [
|
domains = [
|
||||||
d
|
d
|
||||||
for d in domains_set
|
for d in domains_set
|
||||||
if d != self.server_name
|
if not self.is_mine_server_name(d)
|
||||||
and self._federation_shard_config.should_handle(self._instance_name, d)
|
and self._federation_shard_config.should_handle(self._instance_name, d)
|
||||||
]
|
]
|
||||||
if not domains:
|
if not domains:
|
||||||
|
@ -832,7 +833,7 @@ class FederationSender(AbstractFederationSender):
|
||||||
assert self.is_mine_id(state.user_id)
|
assert self.is_mine_id(state.user_id)
|
||||||
|
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
if destination == self.server_name:
|
if self.is_mine_server_name(destination):
|
||||||
continue
|
continue
|
||||||
if not self._federation_shard_config.should_handle(
|
if not self._federation_shard_config.should_handle(
|
||||||
self._instance_name, destination
|
self._instance_name, destination
|
||||||
|
@ -860,7 +861,7 @@ class FederationSender(AbstractFederationSender):
|
||||||
content: content of EDU
|
content: content of EDU
|
||||||
key: clobbering key for this edu
|
key: clobbering key for this edu
|
||||||
"""
|
"""
|
||||||
if destination == self.server_name:
|
if self.is_mine_server_name(destination):
|
||||||
logger.info("Not sending EDU to ourselves")
|
logger.info("Not sending EDU to ourselves")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -897,7 +898,7 @@ class FederationSender(AbstractFederationSender):
|
||||||
queue.send_edu(edu)
|
queue.send_edu(edu)
|
||||||
|
|
||||||
def send_device_messages(self, destination: str, immediate: bool = True) -> None:
|
def send_device_messages(self, destination: str, immediate: bool = True) -> None:
|
||||||
if destination == self.server_name:
|
if self.is_mine_server_name(destination):
|
||||||
logger.warning("Not sending device update to ourselves")
|
logger.warning("Not sending device update to ourselves")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -919,7 +920,7 @@ class FederationSender(AbstractFederationSender):
|
||||||
might have come back.
|
might have come back.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if destination == self.server_name:
|
if self.is_mine_server_name(destination):
|
||||||
logger.warning("Not waking up ourselves")
|
logger.warning("Not waking up ourselves")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -58,9 +58,9 @@ class TransportLayerClient:
|
||||||
"""Sends federation HTTP requests to other servers"""
|
"""Sends federation HTTP requests to other servers"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.server_name = hs.hostname
|
|
||||||
self.client = hs.get_federation_http_client()
|
self.client = hs.get_federation_http_client()
|
||||||
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
|
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
|
||||||
|
self._is_mine_server_name = hs.is_mine_server_name
|
||||||
|
|
||||||
async def get_room_state_ids(
|
async def get_room_state_ids(
|
||||||
self, destination: str, room_id: str, event_id: str
|
self, destination: str, room_id: str, event_id: str
|
||||||
|
@ -235,7 +235,7 @@ class TransportLayerClient:
|
||||||
transaction.transaction_id,
|
transaction.transaction_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if transaction.destination == self.server_name:
|
if self._is_mine_server_name(transaction.destination):
|
||||||
raise RuntimeError("Transport layer cannot send to itself!")
|
raise RuntimeError("Transport layer cannot send to itself!")
|
||||||
|
|
||||||
# FIXME: This is only used by the tests. The actual json sent is
|
# FIXME: This is only used by the tests. The actual json sent is
|
||||||
|
|
|
@ -57,6 +57,7 @@ class Authenticator:
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
|
self._is_mine_server_name = hs.is_mine_server_name
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.federation_domain_whitelist = (
|
self.federation_domain_whitelist = (
|
||||||
hs.config.federation.federation_domain_whitelist
|
hs.config.federation.federation_domain_whitelist
|
||||||
|
@ -100,7 +101,9 @@ class Authenticator:
|
||||||
json_request["signatures"].setdefault(origin, {})[key] = sig
|
json_request["signatures"].setdefault(origin, {})[key] = sig
|
||||||
|
|
||||||
# if the origin_server sent a destination along it needs to match our own server_name
|
# if the origin_server sent a destination along it needs to match our own server_name
|
||||||
if destination is not None and destination != self.server_name:
|
if destination is not None and not self._is_mine_server_name(
|
||||||
|
destination
|
||||||
|
):
|
||||||
raise AuthenticationError(
|
raise AuthenticationError(
|
||||||
HTTPStatus.UNAUTHORIZED,
|
HTTPStatus.UNAUTHORIZED,
|
||||||
"Destination mismatch in auth header",
|
"Destination mismatch in auth header",
|
||||||
|
|
|
@ -29,7 +29,7 @@ from synapse.event_auth import (
|
||||||
)
|
)
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.builder import EventBuilder
|
from synapse.events.builder import EventBuilder
|
||||||
from synapse.types import StateMap, StrCollection, get_domain_from_id
|
from synapse.types import StateMap, StrCollection
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -47,6 +47,7 @@ class EventAuthHandler:
|
||||||
self._store = hs.get_datastores().main
|
self._store = hs.get_datastores().main
|
||||||
self._state_storage_controller = hs.get_storage_controllers().state
|
self._state_storage_controller = hs.get_storage_controllers().state
|
||||||
self._server_name = hs.hostname
|
self._server_name = hs.hostname
|
||||||
|
self._is_mine_id = hs.is_mine_id
|
||||||
|
|
||||||
async def check_auth_rules_from_context(
|
async def check_auth_rules_from_context(
|
||||||
self,
|
self,
|
||||||
|
@ -247,7 +248,7 @@ class EventAuthHandler:
|
||||||
if not await self.is_user_in_rooms(allowed_rooms, user_id):
|
if not await self.is_user_in_rooms(allowed_rooms, user_id):
|
||||||
# If this is a remote request, the user might be in an allowed room
|
# If this is a remote request, the user might be in an allowed room
|
||||||
# that we do not know about.
|
# that we do not know about.
|
||||||
if get_domain_from_id(user_id) != self._server_name:
|
if not self._is_mine_id(user_id):
|
||||||
for room_id in allowed_rooms:
|
for room_id in allowed_rooms:
|
||||||
if not await self._store.is_host_joined(room_id, self._server_name):
|
if not await self._store.is_host_joined(room_id, self._server_name):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
|
|
@ -141,6 +141,7 @@ class FederationHandler:
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
self.is_mine_server_name = hs.is_mine_server_name
|
||||||
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
|
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
self.event_builder_factory = hs.get_event_builder_factory()
|
self.event_builder_factory = hs.get_event_builder_factory()
|
||||||
|
@ -453,7 +454,7 @@ class FederationHandler:
|
||||||
|
|
||||||
for dom in domains:
|
for dom in domains:
|
||||||
# We don't want to ask our own server for information we don't have
|
# We don't want to ask our own server for information we don't have
|
||||||
if dom == self.server_name:
|
if self.is_mine_server_name(dom):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -163,6 +163,7 @@ class FederationEventHandler:
|
||||||
self._notifier = hs.get_notifier()
|
self._notifier = hs.get_notifier()
|
||||||
|
|
||||||
self._is_mine_id = hs.is_mine_id
|
self._is_mine_id = hs.is_mine_id
|
||||||
|
self._is_mine_server_name = hs.is_mine_server_name
|
||||||
self._server_name = hs.hostname
|
self._server_name = hs.hostname
|
||||||
self._instance_name = hs.get_instance_name()
|
self._instance_name = hs.get_instance_name()
|
||||||
|
|
||||||
|
@ -688,7 +689,7 @@ class FederationEventHandler:
|
||||||
server from invalid events (there is probably no point in trying to
|
server from invalid events (there is probably no point in trying to
|
||||||
re-fetch invalid events from every other HS in the room.)
|
re-fetch invalid events from every other HS in the room.)
|
||||||
"""
|
"""
|
||||||
if dest == self._server_name:
|
if self._is_mine_server_name(dest):
|
||||||
raise SynapseError(400, "Can't backfill from self.")
|
raise SynapseError(400, "Can't backfill from self.")
|
||||||
|
|
||||||
events = await self._federation_client.backfill(
|
events = await self._federation_client.backfill(
|
||||||
|
|
|
@ -59,7 +59,7 @@ class ProfileHandler:
|
||||||
self.max_avatar_size = hs.config.server.max_avatar_size
|
self.max_avatar_size = hs.config.server.max_avatar_size
|
||||||
self.allowed_avatar_mimetypes = hs.config.server.allowed_avatar_mimetypes
|
self.allowed_avatar_mimetypes = hs.config.server.allowed_avatar_mimetypes
|
||||||
|
|
||||||
self.server_name = hs.config.server.server_name
|
self._is_mine_server_name = hs.is_mine_server_name
|
||||||
|
|
||||||
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
|
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
|
||||||
|
|
||||||
|
@ -309,7 +309,7 @@ class ProfileHandler:
|
||||||
else:
|
else:
|
||||||
server_name = host
|
server_name = host
|
||||||
|
|
||||||
if server_name == self.server_name:
|
if self._is_mine_server_name(server_name):
|
||||||
media_info = await self.store.get_local_media(media_id)
|
media_info = await self.store.get_local_media(media_id)
|
||||||
else:
|
else:
|
||||||
media_info = await self.store.get_cached_remote_media(server_name, media_id)
|
media_info = await self.store.get_cached_remote_media(server_name, media_id)
|
||||||
|
|
|
@ -194,6 +194,7 @@ class SsoHandler:
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self._store = hs.get_datastores().main
|
self._store = hs.get_datastores().main
|
||||||
self._server_name = hs.hostname
|
self._server_name = hs.hostname
|
||||||
|
self._is_mine_server_name = hs.is_mine_server_name
|
||||||
self._registration_handler = hs.get_registration_handler()
|
self._registration_handler = hs.get_registration_handler()
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self._device_handler = hs.get_device_handler()
|
self._device_handler = hs.get_device_handler()
|
||||||
|
@ -802,7 +803,7 @@ class SsoHandler:
|
||||||
if profile["avatar_url"] is not None:
|
if profile["avatar_url"] is not None:
|
||||||
server_name = profile["avatar_url"].split("/")[-2]
|
server_name = profile["avatar_url"].split("/")[-2]
|
||||||
media_id = profile["avatar_url"].split("/")[-1]
|
media_id = profile["avatar_url"].split("/")[-1]
|
||||||
if server_name == self._server_name:
|
if self._is_mine_server_name(server_name):
|
||||||
media = await self._media_repo.store.get_local_media(media_id)
|
media = await self._media_repo.store.get_local_media(media_id)
|
||||||
if media is not None and upload_name == media["upload_name"]:
|
if media is not None and upload_name == media["upload_name"]:
|
||||||
logger.info("skipping saving the user avatar")
|
logger.info("skipping saving the user avatar")
|
||||||
|
|
|
@ -68,6 +68,7 @@ class FollowerTypingHandler:
|
||||||
self.server_name = hs.config.server.server_name
|
self.server_name = hs.config.server.server_name
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
self.is_mine_server_name = hs.is_mine_server_name
|
||||||
|
|
||||||
self.federation = None
|
self.federation = None
|
||||||
if hs.should_send_federation():
|
if hs.should_send_federation():
|
||||||
|
@ -153,7 +154,7 @@ class FollowerTypingHandler:
|
||||||
member.room_id
|
member.room_id
|
||||||
)
|
)
|
||||||
for domain in hosts:
|
for domain in hosts:
|
||||||
if domain != self.server_name:
|
if not self.is_mine_server_name(domain):
|
||||||
logger.debug("sending typing update to %s", domain)
|
logger.debug("sending typing update to %s", domain)
|
||||||
self.federation.build_and_send_edu(
|
self.federation.build_and_send_edu(
|
||||||
destination=domain,
|
destination=domain,
|
||||||
|
|
|
@ -258,7 +258,7 @@ class DeleteMediaByID(RestServlet):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.server_name = hs.hostname
|
self._is_mine_server_name = hs.is_mine_server_name
|
||||||
self.media_repository = hs.get_media_repository()
|
self.media_repository = hs.get_media_repository()
|
||||||
|
|
||||||
async def on_DELETE(
|
async def on_DELETE(
|
||||||
|
@ -266,7 +266,7 @@ class DeleteMediaByID(RestServlet):
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
if self.server_name != server_name:
|
if not self._is_mine_server_name(server_name):
|
||||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
|
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
|
||||||
|
|
||||||
if await self.store.get_local_media(media_id) is None:
|
if await self.store.get_local_media(media_id) is None:
|
||||||
|
|
|
@ -501,7 +501,7 @@ class PublicRoomListRestServlet(RestServlet):
|
||||||
limit = None
|
limit = None
|
||||||
|
|
||||||
handler = self.hs.get_room_list_handler()
|
handler = self.hs.get_room_list_handler()
|
||||||
if server and server != self.hs.config.server.server_name:
|
if server and not self.hs.is_mine_server_name(server):
|
||||||
# Ensure the server is valid.
|
# Ensure the server is valid.
|
||||||
try:
|
try:
|
||||||
parse_and_validate_server_name(server)
|
parse_and_validate_server_name(server)
|
||||||
|
@ -551,7 +551,7 @@ class PublicRoomListRestServlet(RestServlet):
|
||||||
limit = None
|
limit = None
|
||||||
|
|
||||||
handler = self.hs.get_room_list_handler()
|
handler = self.hs.get_room_list_handler()
|
||||||
if server and server != self.hs.config.server.server_name:
|
if server and not self.hs.is_mine_server_name(server):
|
||||||
# Ensure the server is valid.
|
# Ensure the server is valid.
|
||||||
try:
|
try:
|
||||||
parse_and_validate_server_name(server)
|
parse_and_validate_server_name(server)
|
||||||
|
|
|
@ -37,7 +37,7 @@ class DownloadResource(DirectServeJsonResource):
|
||||||
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
|
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
self.server_name = hs.hostname
|
self._is_mine_server_name = hs.is_mine_server_name
|
||||||
|
|
||||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||||
set_cors_headers(request)
|
set_cors_headers(request)
|
||||||
|
@ -59,7 +59,7 @@ class DownloadResource(DirectServeJsonResource):
|
||||||
b"no-referrer",
|
b"no-referrer",
|
||||||
)
|
)
|
||||||
server_name, media_id, name = parse_media_id(request)
|
server_name, media_id, name = parse_media_id(request)
|
||||||
if server_name == self.server_name:
|
if self._is_mine_server_name(server_name):
|
||||||
await self.media_repo.get_local_media(request, media_id, name)
|
await self.media_repo.get_local_media(request, media_id, name)
|
||||||
else:
|
else:
|
||||||
allow_remote = parse_boolean(request, "allow_remote", default=True)
|
allow_remote = parse_boolean(request, "allow_remote", default=True)
|
||||||
|
|
|
@ -59,7 +59,7 @@ class ThumbnailResource(DirectServeJsonResource):
|
||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
self.media_storage = media_storage
|
self.media_storage = media_storage
|
||||||
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
|
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
|
||||||
self.server_name = hs.hostname
|
self._is_mine_server_name = hs.is_mine_server_name
|
||||||
|
|
||||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||||
set_cors_headers(request)
|
set_cors_headers(request)
|
||||||
|
@ -71,7 +71,7 @@ class ThumbnailResource(DirectServeJsonResource):
|
||||||
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
|
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
|
||||||
m_type = "image/png"
|
m_type = "image/png"
|
||||||
|
|
||||||
if server_name == self.server_name:
|
if self._is_mine_server_name(server_name):
|
||||||
if self.dynamic_thumbnails:
|
if self.dynamic_thumbnails:
|
||||||
await self._select_or_generate_local_thumbnail(
|
await self._select_or_generate_local_thumbnail(
|
||||||
request, media_id, width, height, method, m_type
|
request, media_id, width, height, method, m_type
|
||||||
|
|
|
@ -377,6 +377,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
return False
|
return False
|
||||||
return localpart_hostname[1] == self.hostname
|
return localpart_hostname[1] == self.hostname
|
||||||
|
|
||||||
|
def is_mine_server_name(self, server_name: str) -> bool:
|
||||||
|
"""Determines whether a server name refers to this homeserver."""
|
||||||
|
return server_name == self.hostname
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_clock(self) -> Clock:
|
def get_clock(self) -> Clock:
|
||||||
return Clock(self._reactor)
|
return Clock(self._reactor)
|
||||||
|
|
|
@ -996,7 +996,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||||
If it is `None` media will be removed from quarantine
|
If it is `None` media will be removed from quarantine
|
||||||
"""
|
"""
|
||||||
logger.info("Quarantining media: %s/%s", server_name, media_id)
|
logger.info("Quarantining media: %s/%s", server_name, media_id)
|
||||||
is_local = server_name == self.config.server.server_name
|
is_local = self.hs.is_mine_server_name(server_name)
|
||||||
|
|
||||||
def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
|
def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
|
||||||
local_mxcs = [media_id] if is_local else []
|
local_mxcs = [media_id] if is_local else []
|
||||||
|
|
|
@ -566,7 +566,9 @@ class HomeserverTestCase(TestCase):
|
||||||
client_ip,
|
client_ip,
|
||||||
)
|
)
|
||||||
|
|
||||||
def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer:
|
def setup_test_homeserver(
|
||||||
|
self, name: Optional[str] = None, **kwargs: Any
|
||||||
|
) -> HomeServer:
|
||||||
"""
|
"""
|
||||||
Set up the test homeserver, meant to be called by the overridable
|
Set up the test homeserver, meant to be called by the overridable
|
||||||
make_homeserver. It automatically passes through the test class's
|
make_homeserver. It automatically passes through the test class's
|
||||||
|
@ -585,15 +587,25 @@ class HomeserverTestCase(TestCase):
|
||||||
else:
|
else:
|
||||||
config = kwargs["config"]
|
config = kwargs["config"]
|
||||||
|
|
||||||
|
# The server name can be specified using either the `name` argument or a config
|
||||||
|
# override. The `name` argument takes precedence over any config overrides.
|
||||||
|
if name is not None:
|
||||||
|
config["server_name"] = name
|
||||||
|
|
||||||
# Parse the config from a config dict into a HomeServerConfig
|
# Parse the config from a config dict into a HomeServerConfig
|
||||||
config_obj = make_homeserver_config_obj(config)
|
config_obj = make_homeserver_config_obj(config)
|
||||||
kwargs["config"] = config_obj
|
kwargs["config"] = config_obj
|
||||||
|
|
||||||
|
# The server name in the config is now `name`, if provided, or the `server_name`
|
||||||
|
# from a config override, or the default of "test". Whichever it is, we
|
||||||
|
# construct a homeserver with a matching name.
|
||||||
|
kwargs["name"] = config_obj.server.server_name
|
||||||
|
|
||||||
async def run_bg_updates() -> None:
|
async def run_bg_updates() -> None:
|
||||||
with LoggingContext("run_bg_updates"):
|
with LoggingContext("run_bg_updates"):
|
||||||
self.get_success(stor.db_pool.updates.run_background_updates(False))
|
self.get_success(stor.db_pool.updates.run_background_updates(False))
|
||||||
|
|
||||||
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
|
hs = setup_test_homeserver(self.addCleanup, **kwargs)
|
||||||
stor = hs.get_datastores().main
|
stor = hs.get_datastores().main
|
||||||
|
|
||||||
# Run the database background updates, when running against "master".
|
# Run the database background updates, when running against "master".
|
||||||
|
|
Loading…
Reference in New Issue