Various improvements to the federation client. (#9129)
* Type hints for `FederationClient`. * Using `async` functions instead of returning `Awaitable` instances.
This commit is contained in:
parent
a5b9c87ac6
commit
620ecf13b0
|
@ -0,0 +1 @@
|
|||
Various improvements to the federation client.
|
|
@ -18,6 +18,7 @@ import copy
|
|||
import itertools
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
|
@ -26,7 +27,6 @@ from typing import (
|
|||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
|
@ -61,6 +61,9 @@ from synapse.util import unwrapFirstError
|
|||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
|
||||
|
@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError):
|
|||
|
||||
|
||||
class FederationClient(FederationBase):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self.pdu_destination_tried = {}
|
||||
self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]]
|
||||
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
|
||||
self.state = hs.get_state_handler()
|
||||
self.transport_layer = hs.get_federation_transport_client()
|
||||
|
@ -116,33 +119,32 @@ class FederationClient(FederationBase):
|
|||
self.pdu_destination_tried[event_id] = destination_dict
|
||||
|
||||
@log_function
|
||||
def make_query(
|
||||
async def make_query(
|
||||
self,
|
||||
destination,
|
||||
query_type,
|
||||
args,
|
||||
retry_on_dns_fail=False,
|
||||
ignore_backoff=False,
|
||||
):
|
||||
destination: str,
|
||||
query_type: str,
|
||||
args: dict,
|
||||
retry_on_dns_fail: bool = False,
|
||||
ignore_backoff: bool = False,
|
||||
) -> JsonDict:
|
||||
"""Sends a federation Query to a remote homeserver of the given type
|
||||
and arguments.
|
||||
|
||||
Args:
|
||||
destination (str): Domain name of the remote homeserver
|
||||
query_type (str): Category of the query type; should match the
|
||||
destination: Domain name of the remote homeserver
|
||||
query_type: Category of the query type; should match the
|
||||
handler name used in register_query_handler().
|
||||
args (dict): Mapping of strings to strings containing the details
|
||||
args: Mapping of strings to strings containing the details
|
||||
of the query request.
|
||||
ignore_backoff (bool): true to ignore the historical backoff data
|
||||
ignore_backoff: true to ignore the historical backoff data
|
||||
and try the request anyway.
|
||||
|
||||
Returns:
|
||||
a Awaitable which will eventually yield a JSON object from the
|
||||
response
|
||||
The JSON object from the response
|
||||
"""
|
||||
sent_queries_counter.labels(query_type).inc()
|
||||
|
||||
return self.transport_layer.make_query(
|
||||
return await self.transport_layer.make_query(
|
||||
destination,
|
||||
query_type,
|
||||
args,
|
||||
|
@ -151,42 +153,52 @@ class FederationClient(FederationBase):
|
|||
)
|
||||
|
||||
@log_function
|
||||
def query_client_keys(self, destination, content, timeout):
|
||||
async def query_client_keys(
|
||||
self, destination: str, content: JsonDict, timeout: int
|
||||
) -> JsonDict:
|
||||
"""Query device keys for a device hosted on a remote server.
|
||||
|
||||
Args:
|
||||
destination (str): Domain name of the remote homeserver
|
||||
content (dict): The query content.
|
||||
destination: Domain name of the remote homeserver
|
||||
content: The query content.
|
||||
|
||||
Returns:
|
||||
an Awaitable which will eventually yield a JSON object from the
|
||||
response
|
||||
The JSON object from the response
|
||||
"""
|
||||
sent_queries_counter.labels("client_device_keys").inc()
|
||||
return self.transport_layer.query_client_keys(destination, content, timeout)
|
||||
return await self.transport_layer.query_client_keys(
|
||||
destination, content, timeout
|
||||
)
|
||||
|
||||
@log_function
|
||||
def query_user_devices(self, destination, user_id, timeout=30000):
|
||||
async def query_user_devices(
|
||||
self, destination: str, user_id: str, timeout: int = 30000
|
||||
) -> JsonDict:
|
||||
"""Query the device keys for a list of user ids hosted on a remote
|
||||
server.
|
||||
"""
|
||||
sent_queries_counter.labels("user_devices").inc()
|
||||
return self.transport_layer.query_user_devices(destination, user_id, timeout)
|
||||
return await self.transport_layer.query_user_devices(
|
||||
destination, user_id, timeout
|
||||
)
|
||||
|
||||
@log_function
|
||||
def claim_client_keys(self, destination, content, timeout):
|
||||
async def claim_client_keys(
|
||||
self, destination: str, content: JsonDict, timeout: int
|
||||
) -> JsonDict:
|
||||
"""Claims one-time keys for a device hosted on a remote server.
|
||||
|
||||
Args:
|
||||
destination (str): Domain name of the remote homeserver
|
||||
content (dict): The query content.
|
||||
destination: Domain name of the remote homeserver
|
||||
content: The query content.
|
||||
|
||||
Returns:
|
||||
an Awaitable which will eventually yield a JSON object from the
|
||||
response
|
||||
The JSON object from the response
|
||||
"""
|
||||
sent_queries_counter.labels("client_one_time_keys").inc()
|
||||
return self.transport_layer.claim_client_keys(destination, content, timeout)
|
||||
return await self.transport_layer.claim_client_keys(
|
||||
destination, content, timeout
|
||||
)
|
||||
|
||||
async def backfill(
|
||||
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
|
||||
|
@ -195,10 +207,10 @@ class FederationClient(FederationBase):
|
|||
given destination server.
|
||||
|
||||
Args:
|
||||
dest (str): The remote homeserver to ask.
|
||||
room_id (str): The room_id to backfill.
|
||||
limit (int): The maximum number of events to return.
|
||||
extremities (list): our current backwards extremities, to backfill from
|
||||
dest: The remote homeserver to ask.
|
||||
room_id: The room_id to backfill.
|
||||
limit: The maximum number of events to return.
|
||||
extremities: our current backwards extremities, to backfill from
|
||||
"""
|
||||
logger.debug("backfill extrem=%s", extremities)
|
||||
|
||||
|
@ -370,7 +382,7 @@ class FederationClient(FederationBase):
|
|||
for events that have failed their checks
|
||||
|
||||
Returns:
|
||||
Deferred : A list of PDUs that have valid signatures and hashes.
|
||||
A list of PDUs that have valid signatures and hashes.
|
||||
"""
|
||||
deferreds = self._check_sigs_and_hashes(room_version, pdus)
|
||||
|
||||
|
@ -418,7 +430,9 @@ class FederationClient(FederationBase):
|
|||
else:
|
||||
return [p for p in valid_pdus if p]
|
||||
|
||||
async def get_event_auth(self, destination, room_id, event_id):
|
||||
async def get_event_auth(
|
||||
self, destination: str, room_id: str, event_id: str
|
||||
) -> List[EventBase]:
|
||||
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
|
||||
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
@ -700,18 +714,16 @@ class FederationClient(FederationBase):
|
|||
|
||||
return await self._try_destination_list("send_join", destinations, send_request)
|
||||
|
||||
async def _do_send_join(self, destination: str, pdu: EventBase):
|
||||
async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
|
||||
time_now = self._clock.time_msec()
|
||||
|
||||
try:
|
||||
content = await self.transport_layer.send_join_v2(
|
||||
return await self.transport_layer.send_join_v2(
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
content=pdu.get_pdu_json(time_now),
|
||||
)
|
||||
|
||||
return content
|
||||
except HttpResponseException as e:
|
||||
if e.code in [400, 404]:
|
||||
err = e.to_synapse_error()
|
||||
|
@ -769,7 +781,7 @@ class FederationClient(FederationBase):
|
|||
time_now = self._clock.time_msec()
|
||||
|
||||
try:
|
||||
content = await self.transport_layer.send_invite_v2(
|
||||
return await self.transport_layer.send_invite_v2(
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
|
@ -779,7 +791,6 @@ class FederationClient(FederationBase):
|
|||
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
|
||||
},
|
||||
)
|
||||
return content
|
||||
except HttpResponseException as e:
|
||||
if e.code in [400, 404]:
|
||||
err = e.to_synapse_error()
|
||||
|
@ -842,18 +853,16 @@ class FederationClient(FederationBase):
|
|||
"send_leave", destinations, send_request
|
||||
)
|
||||
|
||||
async def _do_send_leave(self, destination, pdu):
|
||||
async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
|
||||
time_now = self._clock.time_msec()
|
||||
|
||||
try:
|
||||
content = await self.transport_layer.send_leave_v2(
|
||||
return await self.transport_layer.send_leave_v2(
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
content=pdu.get_pdu_json(time_now),
|
||||
)
|
||||
|
||||
return content
|
||||
except HttpResponseException as e:
|
||||
if e.code in [400, 404]:
|
||||
err = e.to_synapse_error()
|
||||
|
@ -879,7 +888,7 @@ class FederationClient(FederationBase):
|
|||
# content.
|
||||
return resp[1]
|
||||
|
||||
def get_public_rooms(
|
||||
async def get_public_rooms(
|
||||
self,
|
||||
remote_server: str,
|
||||
limit: Optional[int] = None,
|
||||
|
@ -887,7 +896,7 @@ class FederationClient(FederationBase):
|
|||
search_filter: Optional[Dict] = None,
|
||||
include_all_networks: bool = False,
|
||||
third_party_instance_id: Optional[str] = None,
|
||||
):
|
||||
) -> JsonDict:
|
||||
"""Get the list of public rooms from a remote homeserver
|
||||
|
||||
Args:
|
||||
|
@ -901,8 +910,7 @@ class FederationClient(FederationBase):
|
|||
party instance
|
||||
|
||||
Returns:
|
||||
Awaitable[Dict[str, Any]]: The response from the remote server, or None if
|
||||
`remote_server` is the same as the local server_name
|
||||
The response from the remote server.
|
||||
|
||||
Raises:
|
||||
HttpResponseException: There was an exception returned from the remote server
|
||||
|
@ -910,7 +918,7 @@ class FederationClient(FederationBase):
|
|||
requests over federation
|
||||
|
||||
"""
|
||||
return self.transport_layer.get_public_rooms(
|
||||
return await self.transport_layer.get_public_rooms(
|
||||
remote_server,
|
||||
limit,
|
||||
since_token,
|
||||
|
@ -923,7 +931,7 @@ class FederationClient(FederationBase):
|
|||
self,
|
||||
destination: str,
|
||||
room_id: str,
|
||||
earliest_events_ids: Sequence[str],
|
||||
earliest_events_ids: Iterable[str],
|
||||
latest_events: Iterable[EventBase],
|
||||
limit: int,
|
||||
min_depth: int,
|
||||
|
@ -974,7 +982,9 @@ class FederationClient(FederationBase):
|
|||
|
||||
return signed_events
|
||||
|
||||
async def forward_third_party_invite(self, destinations, room_id, event_dict):
|
||||
async def forward_third_party_invite(
|
||||
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
|
||||
) -> None:
|
||||
for destination in destinations:
|
||||
if destination == self.server_name:
|
||||
continue
|
||||
|
@ -983,7 +993,7 @@ class FederationClient(FederationBase):
|
|||
await self.transport_layer.exchange_third_party_invite(
|
||||
destination=destination, room_id=room_id, event_dict=event_dict
|
||||
)
|
||||
return None
|
||||
return
|
||||
except CodeMessageException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
@ -995,7 +1005,7 @@ class FederationClient(FederationBase):
|
|||
|
||||
async def get_room_complexity(
|
||||
self, destination: str, room_id: str
|
||||
) -> Optional[dict]:
|
||||
) -> Optional[JsonDict]:
|
||||
"""
|
||||
Fetch the complexity of a remote room from another server.
|
||||
|
||||
|
@ -1008,10 +1018,9 @@ class FederationClient(FederationBase):
|
|||
could not fetch the complexity.
|
||||
"""
|
||||
try:
|
||||
complexity = await self.transport_layer.get_room_complexity(
|
||||
return await self.transport_layer.get_room_complexity(
|
||||
destination=destination, room_id=room_id
|
||||
)
|
||||
return complexity
|
||||
except CodeMessageException as e:
|
||||
# We didn't manage to get it -- probably a 404. We are okay if other
|
||||
# servers don't give it to us.
|
||||
|
|
Loading…
Reference in New Issue