Batch up replication requests to request the resyncing of remote users's devices. (#14716)
This commit is contained in:
parent
3479599387
commit
ba4ea7d13f
|
@ -0,0 +1 @@
|
||||||
|
Batch up replication requests to request the resyncing of remote users's devices.
|
|
@ -14,6 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
@ -33,6 +34,7 @@ from synapse.api.errors import (
|
||||||
Codes,
|
Codes,
|
||||||
FederationDeniedError,
|
FederationDeniedError,
|
||||||
HttpResponseException,
|
HttpResponseException,
|
||||||
|
InvalidAPICallError,
|
||||||
RequestSendFailed,
|
RequestSendFailed,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
@ -45,6 +47,7 @@ from synapse.types import (
|
||||||
JsonDict,
|
JsonDict,
|
||||||
StreamKeyType,
|
StreamKeyType,
|
||||||
StreamToken,
|
StreamToken,
|
||||||
|
UserID,
|
||||||
get_domain_from_id,
|
get_domain_from_id,
|
||||||
get_verify_key_from_cross_signing_key,
|
get_verify_key_from_cross_signing_key,
|
||||||
)
|
)
|
||||||
|
@ -893,12 +896,47 @@ class DeviceListWorkerUpdater:
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
from synapse.replication.http.devices import (
|
from synapse.replication.http.devices import (
|
||||||
|
ReplicationMultiUserDevicesResyncRestServlet,
|
||||||
ReplicationUserDevicesResyncRestServlet,
|
ReplicationUserDevicesResyncRestServlet,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._user_device_resync_client = (
|
self._user_device_resync_client = (
|
||||||
ReplicationUserDevicesResyncRestServlet.make_client(hs)
|
ReplicationUserDevicesResyncRestServlet.make_client(hs)
|
||||||
)
|
)
|
||||||
|
self._multi_user_device_resync_client = (
|
||||||
|
ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def multi_user_device_resync(
|
||||||
|
self, user_ids: List[str], mark_failed_as_stale: bool = True
|
||||||
|
) -> Dict[str, Optional[JsonDict]]:
|
||||||
|
"""
|
||||||
|
Like `user_device_resync` but operates on multiple users **from the same origin**
|
||||||
|
at once.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict from User ID to the same Dict as `user_device_resync`.
|
||||||
|
"""
|
||||||
|
# mark_failed_as_stale is not sent. Ensure this doesn't break expectations.
|
||||||
|
assert mark_failed_as_stale
|
||||||
|
|
||||||
|
if not user_ids:
|
||||||
|
# Shortcut empty requests
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self._multi_user_device_resync_client(user_ids=user_ids)
|
||||||
|
except SynapseError as err:
|
||||||
|
if not (
|
||||||
|
err.code == HTTPStatus.NOT_FOUND and err.errcode == Codes.UNRECOGNIZED
|
||||||
|
):
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Fall back to single requests
|
||||||
|
result: Dict[str, Optional[JsonDict]] = {}
|
||||||
|
for user_id in user_ids:
|
||||||
|
result[user_id] = await self._user_device_resync_client(user_id=user_id)
|
||||||
|
return result
|
||||||
|
|
||||||
async def user_device_resync(
|
async def user_device_resync(
|
||||||
self, user_id: str, mark_failed_as_stale: bool = True
|
self, user_id: str, mark_failed_as_stale: bool = True
|
||||||
|
@ -913,8 +951,10 @@ class DeviceListWorkerUpdater:
|
||||||
A dict with device info as under the "devices" in the result of this
|
A dict with device info as under the "devices" in the result of this
|
||||||
request:
|
request:
|
||||||
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
|
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
|
||||||
|
None when we weren't able to fetch the device info for some reason,
|
||||||
|
e.g. due to a connection problem.
|
||||||
"""
|
"""
|
||||||
return await self._user_device_resync_client(user_id=user_id)
|
return (await self.multi_user_device_resync([user_id]))[user_id]
|
||||||
|
|
||||||
|
|
||||||
class DeviceListUpdater(DeviceListWorkerUpdater):
|
class DeviceListUpdater(DeviceListWorkerUpdater):
|
||||||
|
@ -1160,19 +1200,66 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
|
||||||
# Allow future calls to retry resyncinc out of sync device lists.
|
# Allow future calls to retry resyncinc out of sync device lists.
|
||||||
self._resync_retry_in_progress = False
|
self._resync_retry_in_progress = False
|
||||||
|
|
||||||
|
async def multi_user_device_resync(
|
||||||
|
self, user_ids: List[str], mark_failed_as_stale: bool = True
|
||||||
|
) -> Dict[str, Optional[JsonDict]]:
|
||||||
|
"""
|
||||||
|
Like `user_device_resync` but operates on multiple users **from the same origin**
|
||||||
|
at once.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict from User ID to the same Dict as `user_device_resync`.
|
||||||
|
"""
|
||||||
|
if not user_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
origins = {UserID.from_string(user_id).domain for user_id in user_ids}
|
||||||
|
|
||||||
|
if len(origins) != 1:
|
||||||
|
raise InvalidAPICallError(f"Only one origin permitted, got {origins!r}")
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
failed = set()
|
||||||
|
# TODO(Perf): Actually batch these up
|
||||||
|
for user_id in user_ids:
|
||||||
|
user_result, user_failed = await self._user_device_resync_returning_failed(
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
result[user_id] = user_result
|
||||||
|
if user_failed:
|
||||||
|
failed.add(user_id)
|
||||||
|
|
||||||
|
if mark_failed_as_stale:
|
||||||
|
await self.store.mark_remote_users_device_caches_as_stale(failed)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def user_device_resync(
|
async def user_device_resync(
|
||||||
self, user_id: str, mark_failed_as_stale: bool = True
|
self, user_id: str, mark_failed_as_stale: bool = True
|
||||||
) -> Optional[JsonDict]:
|
) -> Optional[JsonDict]:
|
||||||
|
result, failed = await self._user_device_resync_returning_failed(user_id)
|
||||||
|
|
||||||
|
if failed and mark_failed_as_stale:
|
||||||
|
# Mark the remote user's device list as stale so we know we need to retry
|
||||||
|
# it later.
|
||||||
|
await self.store.mark_remote_users_device_caches_as_stale((user_id,))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _user_device_resync_returning_failed(
|
||||||
|
self, user_id: str
|
||||||
|
) -> Tuple[Optional[JsonDict], bool]:
|
||||||
"""Fetches all devices for a user and updates the device cache with them.
|
"""Fetches all devices for a user and updates the device cache with them.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: The user's id whose device_list will be updated.
|
user_id: The user's id whose device_list will be updated.
|
||||||
mark_failed_as_stale: Whether to mark the user's device list as stale
|
|
||||||
if the attempt to resync failed.
|
|
||||||
Returns:
|
Returns:
|
||||||
A dict with device info as under the "devices" in the result of this
|
- A dict with device info as under the "devices" in the result of this
|
||||||
request:
|
request:
|
||||||
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
|
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
|
||||||
|
None when we weren't able to fetch the device info for some reason,
|
||||||
|
e.g. due to a connection problem.
|
||||||
|
- True iff the resync failed and the device list should be marked as stale.
|
||||||
"""
|
"""
|
||||||
logger.debug("Attempting to resync the device list for %s", user_id)
|
logger.debug("Attempting to resync the device list for %s", user_id)
|
||||||
log_kv({"message": "Doing resync to update device list."})
|
log_kv({"message": "Doing resync to update device list."})
|
||||||
|
@ -1181,12 +1268,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
|
||||||
try:
|
try:
|
||||||
result = await self.federation.query_user_devices(origin, user_id)
|
result = await self.federation.query_user_devices(origin, user_id)
|
||||||
except NotRetryingDestination:
|
except NotRetryingDestination:
|
||||||
if mark_failed_as_stale:
|
return None, True
|
||||||
# Mark the remote user's device list as stale so we know we need to retry
|
|
||||||
# it later.
|
|
||||||
await self.store.mark_remote_user_device_cache_as_stale(user_id)
|
|
||||||
|
|
||||||
return None
|
|
||||||
except (RequestSendFailed, HttpResponseException) as e:
|
except (RequestSendFailed, HttpResponseException) as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to handle device list update for %s: %s",
|
"Failed to handle device list update for %s: %s",
|
||||||
|
@ -1194,23 +1276,18 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
|
|
||||||
if mark_failed_as_stale:
|
|
||||||
# Mark the remote user's device list as stale so we know we need to retry
|
|
||||||
# it later.
|
|
||||||
await self.store.mark_remote_user_device_cache_as_stale(user_id)
|
|
||||||
|
|
||||||
# We abort on exceptions rather than accepting the update
|
# We abort on exceptions rather than accepting the update
|
||||||
# as otherwise synapse will 'forget' that its device list
|
# as otherwise synapse will 'forget' that its device list
|
||||||
# is out of date. If we bail then we will retry the resync
|
# is out of date. If we bail then we will retry the resync
|
||||||
# next time we get a device list update for this user_id.
|
# next time we get a device list update for this user_id.
|
||||||
# This makes it more likely that the device lists will
|
# This makes it more likely that the device lists will
|
||||||
# eventually become consistent.
|
# eventually become consistent.
|
||||||
return None
|
return None, True
|
||||||
except FederationDeniedError as e:
|
except FederationDeniedError as e:
|
||||||
set_tag("error", True)
|
set_tag("error", True)
|
||||||
log_kv({"reason": "FederationDeniedError"})
|
log_kv({"reason": "FederationDeniedError"})
|
||||||
logger.info(e)
|
logger.info(e)
|
||||||
return None
|
return None, False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
set_tag("error", True)
|
set_tag("error", True)
|
||||||
log_kv(
|
log_kv(
|
||||||
|
@ -1218,12 +1295,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
|
||||||
)
|
)
|
||||||
logger.exception("Failed to handle device list update for %s", user_id)
|
logger.exception("Failed to handle device list update for %s", user_id)
|
||||||
|
|
||||||
if mark_failed_as_stale:
|
return None, True
|
||||||
# Mark the remote user's device list as stale so we know we need to retry
|
|
||||||
# it later.
|
|
||||||
await self.store.mark_remote_user_device_cache_as_stale(user_id)
|
|
||||||
|
|
||||||
return None
|
|
||||||
log_kv({"result": result})
|
log_kv({"result": result})
|
||||||
stream_id = result["stream_id"]
|
stream_id = result["stream_id"]
|
||||||
devices = result["devices"]
|
devices = result["devices"]
|
||||||
|
@ -1305,7 +1377,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
|
||||||
# point.
|
# point.
|
||||||
self._seen_updates[user_id] = {stream_id}
|
self._seen_updates[user_id] = {stream_id}
|
||||||
|
|
||||||
return result
|
return result, False
|
||||||
|
|
||||||
async def process_cross_signing_key_update(
|
async def process_cross_signing_key_update(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -195,7 +195,7 @@ class DeviceMessageHandler:
|
||||||
sender_user_id,
|
sender_user_id,
|
||||||
unknown_devices,
|
unknown_devices,
|
||||||
)
|
)
|
||||||
await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
|
await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,))
|
||||||
|
|
||||||
# Immediately attempt a resync in the background
|
# Immediately attempt a resync in the background
|
||||||
run_in_background(self._user_device_resync, user_id=sender_user_id)
|
run_in_background(self._user_device_resync, user_id=sender_user_id)
|
||||||
|
|
|
@ -36,8 +36,8 @@ from synapse.types import (
|
||||||
get_domain_from_id,
|
get_domain_from_id,
|
||||||
get_verify_key_from_cross_signing_key,
|
get_verify_key_from_cross_signing_key,
|
||||||
)
|
)
|
||||||
from synapse.util import json_decoder, unwrapFirstError
|
from synapse.util import json_decoder
|
||||||
from synapse.util.async_helpers import Linearizer, delay_cancellation
|
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
||||||
from synapse.util.cancellation import cancellable
|
from synapse.util.cancellation import cancellable
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
|
@ -238,24 +238,28 @@ class E2eKeysHandler:
|
||||||
# Now fetch any devices that we don't have in our cache
|
# Now fetch any devices that we don't have in our cache
|
||||||
# TODO It might make sense to propagate cancellations into the
|
# TODO It might make sense to propagate cancellations into the
|
||||||
# deferreds which are querying remote homeservers.
|
# deferreds which are querying remote homeservers.
|
||||||
await make_deferred_yieldable(
|
logger.debug(
|
||||||
delay_cancellation(
|
"%d destinations to query devices for", len(remote_queries_not_in_cache)
|
||||||
defer.gatherResults(
|
)
|
||||||
[
|
|
||||||
run_in_background(
|
async def _query(
|
||||||
self._query_devices_for_destination,
|
destination_queries: Tuple[str, Dict[str, Iterable[str]]]
|
||||||
results,
|
) -> None:
|
||||||
cross_signing_keys,
|
destination, queries = destination_queries
|
||||||
failures,
|
return await self._query_devices_for_destination(
|
||||||
destination,
|
results,
|
||||||
queries,
|
cross_signing_keys,
|
||||||
timeout,
|
failures,
|
||||||
)
|
destination,
|
||||||
for destination, queries in remote_queries_not_in_cache.items()
|
queries,
|
||||||
],
|
timeout,
|
||||||
consumeErrors=True,
|
|
||||||
).addErrback(unwrapFirstError)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await concurrently_execute(
|
||||||
|
_query,
|
||||||
|
remote_queries_not_in_cache.items(),
|
||||||
|
10,
|
||||||
|
delay_cancellation=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
ret = {"device_keys": results, "failures": failures}
|
ret = {"device_keys": results, "failures": failures}
|
||||||
|
@ -300,28 +304,41 @@ class E2eKeysHandler:
|
||||||
# queries. We use the more efficient batched query_client_keys for all
|
# queries. We use the more efficient batched query_client_keys for all
|
||||||
# remaining users
|
# remaining users
|
||||||
user_ids_updated = []
|
user_ids_updated = []
|
||||||
for (user_id, device_list) in destination_query.items():
|
|
||||||
if user_id in user_ids_updated:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if device_list:
|
# Perform a user device resync for each user only once and only as long as:
|
||||||
continue
|
# - they have an empty device_list
|
||||||
|
# - they are in some rooms that this server can see
|
||||||
|
users_to_resync_devices = {
|
||||||
|
user_id
|
||||||
|
for (user_id, device_list) in destination_query.items()
|
||||||
|
if (not device_list) and (await self.store.get_rooms_for_user(user_id))
|
||||||
|
}
|
||||||
|
|
||||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
logger.debug(
|
||||||
if not room_ids:
|
"%d users to resync devices for from destination %s",
|
||||||
continue
|
len(users_to_resync_devices),
|
||||||
|
destination,
|
||||||
|
)
|
||||||
|
|
||||||
# We've decided we're sharing a room with this user and should
|
try:
|
||||||
# probably be tracking their device lists. However, we haven't
|
user_resync_results = (
|
||||||
# done an initial sync on the device list so we do it now.
|
await self.device_handler.device_list_updater.multi_user_device_resync(
|
||||||
try:
|
list(users_to_resync_devices)
|
||||||
resync_results = (
|
|
||||||
await self.device_handler.device_list_updater.user_device_resync(
|
|
||||||
user_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
for user_id in users_to_resync_devices:
|
||||||
|
resync_results = user_resync_results[user_id]
|
||||||
|
|
||||||
if resync_results is None:
|
if resync_results is None:
|
||||||
raise ValueError("Device resync failed")
|
# TODO: It's weird that we'll store a failure against a
|
||||||
|
# destination, yet continue processing users from that
|
||||||
|
# destination.
|
||||||
|
# We might want to consider changing this, but for now
|
||||||
|
# I'm leaving it as I found it.
|
||||||
|
failures[destination] = _exception_to_failure(
|
||||||
|
ValueError(f"Device resync failed for {user_id!r}")
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
# Add the device keys to the results.
|
# Add the device keys to the results.
|
||||||
user_devices = resync_results["devices"]
|
user_devices = resync_results["devices"]
|
||||||
|
@ -339,8 +356,8 @@ class E2eKeysHandler:
|
||||||
|
|
||||||
if self_signing_key:
|
if self_signing_key:
|
||||||
cross_signing_keys["self_signing_keys"][user_id] = self_signing_key
|
cross_signing_keys["self_signing_keys"][user_id] = self_signing_key
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
failures[destination] = _exception_to_failure(e)
|
failures[destination] = _exception_to_failure(e)
|
||||||
|
|
||||||
if len(destination_query) == len(user_ids_updated):
|
if len(destination_query) == len(user_ids_updated):
|
||||||
# We've updated all the users in the query and we do not need to
|
# We've updated all the users in the query and we do not need to
|
||||||
|
|
|
@ -1423,7 +1423,7 @@ class FederationEventHandler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._store.mark_remote_user_device_cache_as_stale(sender)
|
await self._store.mark_remote_users_device_caches_as_stale((sender,))
|
||||||
|
|
||||||
# Immediately attempt a resync in the background
|
# Immediately attempt a resync in the background
|
||||||
if self._config.worker.worker_app:
|
if self._config.worker.worker_app:
|
||||||
|
|
|
@ -13,12 +13,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
from synapse.logging.opentracing import active_span
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
@ -84,6 +85,76 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
|
||||||
return 200, user_devices
|
return 200, user_devices
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
|
||||||
|
"""Ask master to resync the device list for multiple users from the same
|
||||||
|
remote server by contacting their server.
|
||||||
|
|
||||||
|
This must happen on master so that the results can be correctly cached in
|
||||||
|
the database and streamed to workers.
|
||||||
|
|
||||||
|
Request format:
|
||||||
|
|
||||||
|
POST /_synapse/replication/multi_user_device_resync
|
||||||
|
|
||||||
|
{
|
||||||
|
"user_ids": ["@alice:example.org", "@bob:example.org", ...]
|
||||||
|
}
|
||||||
|
|
||||||
|
Response is roughly equivalent to ` /_matrix/federation/v1/user/devices/:user_id`
|
||||||
|
response, but there is a map from user ID to response, e.g.:
|
||||||
|
|
||||||
|
{
|
||||||
|
"@alice:example.org": {
|
||||||
|
"devices": [
|
||||||
|
{
|
||||||
|
"device_id": "JLAFKJWSCS",
|
||||||
|
"keys": { ... },
|
||||||
|
"device_display_name": "Alice's Mobile Phone"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
...
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
NAME = "multi_user_device_resync"
|
||||||
|
PATH_ARGS = ()
|
||||||
|
CACHE = False
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__(hs)
|
||||||
|
|
||||||
|
from synapse.handlers.device import DeviceHandler
|
||||||
|
|
||||||
|
handler = hs.get_device_handler()
|
||||||
|
assert isinstance(handler, DeviceHandler)
|
||||||
|
self.device_list_updater = handler.device_list_updater
|
||||||
|
|
||||||
|
self.store = hs.get_datastores().main
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _serialize_payload(user_ids: List[str]) -> JsonDict: # type: ignore[override]
|
||||||
|
return {"user_ids": user_ids}
|
||||||
|
|
||||||
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request
|
||||||
|
) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
user_ids: List[str] = content["user_ids"]
|
||||||
|
|
||||||
|
logger.info("Resync for %r", user_ids)
|
||||||
|
span = active_span()
|
||||||
|
if span:
|
||||||
|
span.set_tag("user_ids", f"{user_ids!r}")
|
||||||
|
|
||||||
|
multi_user_devices = await self.device_list_updater.multi_user_device_resync(
|
||||||
|
user_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
return 200, multi_user_devices
|
||||||
|
|
||||||
|
|
||||||
class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
|
class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
|
||||||
"""Ask master to upload keys for the user and send them out over federation to
|
"""Ask master to upload keys for the user and send them out over federation to
|
||||||
update other servers.
|
update other servers.
|
||||||
|
@ -151,4 +222,5 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
|
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
|
||||||
|
ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server)
|
||||||
ReplicationUploadKeysForUserRestServlet(hs).register(http_server)
|
ReplicationUploadKeysForUserRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -54,7 +54,7 @@ from synapse.storage.util.id_generators import (
|
||||||
AbstractStreamIdTracker,
|
AbstractStreamIdTracker,
|
||||||
StreamIdGenerator,
|
StreamIdGenerator,
|
||||||
)
|
)
|
||||||
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
|
from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
|
||||||
from synapse.util import json_decoder, json_encoder
|
from synapse.util import json_decoder, json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
@ -1069,16 +1069,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
|
|
||||||
return {row["user_id"] for row in rows}
|
return {row["user_id"] for row in rows}
|
||||||
|
|
||||||
async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
|
async def mark_remote_users_device_caches_as_stale(
|
||||||
|
self, user_ids: StrCollection
|
||||||
|
) -> None:
|
||||||
"""Records that the server has reason to believe the cache of the devices
|
"""Records that the server has reason to believe the cache of the devices
|
||||||
for the remote users is out of date.
|
for the remote users is out of date.
|
||||||
"""
|
"""
|
||||||
await self.db_pool.simple_upsert(
|
|
||||||
table="device_lists_remote_resync",
|
def _mark_remote_users_device_caches_as_stale_txn(
|
||||||
keyvalues={"user_id": user_id},
|
txn: LoggingTransaction,
|
||||||
values={},
|
) -> None:
|
||||||
insertion_values={"added_ts": self._clock.time_msec()},
|
# TODO add insertion_values support to simple_upsert_many and use
|
||||||
desc="mark_remote_user_device_cache_as_stale",
|
# that!
|
||||||
|
for user_id in user_ids:
|
||||||
|
self.db_pool.simple_upsert_txn(
|
||||||
|
txn,
|
||||||
|
table="device_lists_remote_resync",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
values={},
|
||||||
|
insertion_values={"added_ts": self._clock.time_msec()},
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.db_pool.runInteraction(
|
||||||
|
"mark_remote_users_device_caches_as_stale",
|
||||||
|
_mark_remote_users_device_caches_as_stale_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None:
|
async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None:
|
||||||
|
|
|
@ -77,6 +77,10 @@ JsonMapping = Mapping[str, Any]
|
||||||
# A JSON-serialisable object.
|
# A JSON-serialisable object.
|
||||||
JsonSerializable = object
|
JsonSerializable = object
|
||||||
|
|
||||||
|
# Collection[str] that does not include str itself; str being a Sequence[str]
|
||||||
|
# is very misleading and results in bugs.
|
||||||
|
StrCollection = Union[Tuple[str, ...], List[str], Set[str]]
|
||||||
|
|
||||||
|
|
||||||
# Note that this seems to require inheriting *directly* from Interface in order
|
# Note that this seems to require inheriting *directly* from Interface in order
|
||||||
# for mypy-zope to realize it is an interface.
|
# for mypy-zope to realize it is an interface.
|
||||||
|
|
|
@ -205,7 +205,10 @@ T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
async def concurrently_execute(
|
async def concurrently_execute(
|
||||||
func: Callable[[T], Any], args: Iterable[T], limit: int
|
func: Callable[[T], Any],
|
||||||
|
args: Iterable[T],
|
||||||
|
limit: int,
|
||||||
|
delay_cancellation: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Executes the function with each argument concurrently while limiting
|
"""Executes the function with each argument concurrently while limiting
|
||||||
the number of concurrent executions.
|
the number of concurrent executions.
|
||||||
|
@ -215,6 +218,8 @@ async def concurrently_execute(
|
||||||
args: List of arguments to pass to func, each invocation of func
|
args: List of arguments to pass to func, each invocation of func
|
||||||
gets a single argument.
|
gets a single argument.
|
||||||
limit: Maximum number of conccurent executions.
|
limit: Maximum number of conccurent executions.
|
||||||
|
delay_cancellation: Whether to delay cancellation until after the invocations
|
||||||
|
have finished.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None, when all function invocations have finished. The return values
|
None, when all function invocations have finished. The return values
|
||||||
|
@ -233,9 +238,16 @@ async def concurrently_execute(
|
||||||
# We use `itertools.islice` to handle the case where the number of args is
|
# We use `itertools.islice` to handle the case where the number of args is
|
||||||
# less than the limit, avoiding needlessly spawning unnecessary background
|
# less than the limit, avoiding needlessly spawning unnecessary background
|
||||||
# tasks.
|
# tasks.
|
||||||
await yieldable_gather_results(
|
if delay_cancellation:
|
||||||
_concurrently_execute_inner, (value for value in itertools.islice(it, limit))
|
await yieldable_gather_results_delaying_cancellation(
|
||||||
)
|
_concurrently_execute_inner,
|
||||||
|
(value for value in itertools.islice(it, limit)),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await yieldable_gather_results(
|
||||||
|
_concurrently_execute_inner,
|
||||||
|
(value for value in itertools.islice(it, limit)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
|
@ -292,6 +304,41 @@ async def yieldable_gather_results(
|
||||||
raise dfe.subFailure.value from None
|
raise dfe.subFailure.value from None
|
||||||
|
|
||||||
|
|
||||||
|
async def yieldable_gather_results_delaying_cancellation(
|
||||||
|
func: Callable[Concatenate[T, P], Awaitable[R]],
|
||||||
|
iter: Iterable[T],
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> List[R]:
|
||||||
|
"""Executes the function with each argument concurrently.
|
||||||
|
Cancellation is delayed until after all the results have been gathered.
|
||||||
|
|
||||||
|
See `yieldable_gather_results`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: Function to execute that returns a Deferred
|
||||||
|
iter: An iterable that yields items that get passed as the first
|
||||||
|
argument to the function
|
||||||
|
*args: Arguments to be passed to each call to func
|
||||||
|
**kwargs: Keyword arguments to be passed to each call to func
|
||||||
|
|
||||||
|
Returns
|
||||||
|
A list containing the results of the function
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return await make_deferred_yieldable(
|
||||||
|
delay_cancellation(
|
||||||
|
defer.gatherResults(
|
||||||
|
[run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type]
|
||||||
|
consumeErrors=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except defer.FirstError as dfe:
|
||||||
|
assert isinstance(dfe.subFailure.value, BaseException)
|
||||||
|
raise dfe.subFailure.value from None
|
||||||
|
|
||||||
|
|
||||||
T1 = TypeVar("T1")
|
T1 = TypeVar("T1")
|
||||||
T2 = TypeVar("T2")
|
T2 = TypeVar("T2")
|
||||||
T3 = TypeVar("T3")
|
T3 = TypeVar("T3")
|
||||||
|
|
Loading…
Reference in New Issue