Limit the number of in-flight /keys/query requests from a single device. (#10144)
This commit is contained in:
parent
1bf83a191b
commit
11846dff8c
|
@ -0,0 +1 @@
|
|||
Limit the number of in-flight `/keys/query` requests from a single device.
|
|
@ -79,9 +79,15 @@ class E2eKeysHandler:
|
|||
"client_keys", self.on_federation_query_client_keys
|
||||
)
|
||||
|
||||
# Limit the number of in-flight requests from a single device.
|
||||
self._query_devices_linearizer = Linearizer(
|
||||
name="query_devices",
|
||||
max_count=10,
|
||||
)
|
||||
|
||||
@trace
|
||||
async def query_devices(
|
||||
self, query_body: JsonDict, timeout: int, from_user_id: str
|
||||
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
|
||||
) -> JsonDict:
|
||||
"""Handle a device key query from a client
|
||||
|
||||
|
@ -105,191 +111,197 @@ class E2eKeysHandler:
|
|||
from_user_id: the user making the query. This is used when
|
||||
adding cross-signing signatures to limit what signatures users
|
||||
can see.
|
||||
from_device_id: the device making the query. This is used to limit
|
||||
the number of in-flight queries at a time.
|
||||
"""
|
||||
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
|
||||
device_keys_query = query_body.get(
|
||||
"device_keys", {}
|
||||
) # type: Dict[str, Iterable[str]]
|
||||
|
||||
device_keys_query = query_body.get(
|
||||
"device_keys", {}
|
||||
) # type: Dict[str, Iterable[str]]
|
||||
# separate users by domain.
|
||||
# make a map from domain to user_id to device_ids
|
||||
local_query = {}
|
||||
remote_queries = {}
|
||||
|
||||
# separate users by domain.
|
||||
# make a map from domain to user_id to device_ids
|
||||
local_query = {}
|
||||
remote_queries = {}
|
||||
|
||||
for user_id, device_ids in device_keys_query.items():
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if self.is_mine(UserID.from_string(user_id)):
|
||||
local_query[user_id] = device_ids
|
||||
else:
|
||||
remote_queries[user_id] = device_ids
|
||||
|
||||
set_tag("local_key_query", local_query)
|
||||
set_tag("remote_key_query", remote_queries)
|
||||
|
||||
# First get local devices.
|
||||
# A map of destination -> failure response.
|
||||
failures = {} # type: Dict[str, JsonDict]
|
||||
results = {}
|
||||
if local_query:
|
||||
local_result = await self.query_local_devices(local_query)
|
||||
for user_id, keys in local_result.items():
|
||||
if user_id in local_query:
|
||||
results[user_id] = keys
|
||||
|
||||
# Get cached cross-signing keys
|
||||
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
|
||||
device_keys_query, from_user_id
|
||||
)
|
||||
|
||||
# Now attempt to get any remote devices from our local cache.
|
||||
# A map of destination -> user ID -> device IDs.
|
||||
remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
|
||||
if remote_queries:
|
||||
query_list = [] # type: List[Tuple[str, Optional[str]]]
|
||||
for user_id, device_ids in remote_queries.items():
|
||||
if device_ids:
|
||||
query_list.extend((user_id, device_id) for device_id in device_ids)
|
||||
for user_id, device_ids in device_keys_query.items():
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if self.is_mine(UserID.from_string(user_id)):
|
||||
local_query[user_id] = device_ids
|
||||
else:
|
||||
query_list.append((user_id, None))
|
||||
remote_queries[user_id] = device_ids
|
||||
|
||||
(
|
||||
user_ids_not_in_cache,
|
||||
remote_results,
|
||||
) = await self.store.get_user_devices_from_cache(query_list)
|
||||
for user_id, devices in remote_results.items():
|
||||
user_devices = results.setdefault(user_id, {})
|
||||
for device_id, device in devices.items():
|
||||
keys = device.get("keys", None)
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if keys:
|
||||
result = dict(keys)
|
||||
unsigned = result.setdefault("unsigned", {})
|
||||
if device_display_name:
|
||||
unsigned["device_display_name"] = device_display_name
|
||||
user_devices[device_id] = result
|
||||
set_tag("local_key_query", local_query)
|
||||
set_tag("remote_key_query", remote_queries)
|
||||
|
||||
# check for missing cross-signing keys.
|
||||
for user_id in remote_queries.keys():
|
||||
cached_cross_master = user_id in cross_signing_keys["master_keys"]
|
||||
cached_cross_selfsigning = (
|
||||
user_id in cross_signing_keys["self_signing_keys"]
|
||||
)
|
||||
|
||||
# check if we are missing only one of cross-signing master or
|
||||
# self-signing key, but the other one is cached.
|
||||
# as we need both, this will issue a federation request.
|
||||
# if we don't have any of the keys, either the user doesn't have
|
||||
# cross-signing set up, or the cached device list
|
||||
# is not (yet) updated.
|
||||
if cached_cross_master ^ cached_cross_selfsigning:
|
||||
user_ids_not_in_cache.add(user_id)
|
||||
|
||||
# add those users to the list to fetch over federation.
|
||||
for user_id in user_ids_not_in_cache:
|
||||
domain = get_domain_from_id(user_id)
|
||||
r = remote_queries_not_in_cache.setdefault(domain, {})
|
||||
r[user_id] = remote_queries[user_id]
|
||||
|
||||
# Now fetch any devices that we don't have in our cache
|
||||
@trace
|
||||
async def do_remote_query(destination):
|
||||
"""This is called when we are querying the device list of a user on
|
||||
a remote homeserver and their device list is not in the device list
|
||||
cache. If we share a room with this user and we're not querying for
|
||||
specific user we will update the cache with their device list.
|
||||
"""
|
||||
|
||||
destination_query = remote_queries_not_in_cache[destination]
|
||||
|
||||
# We first consider whether we wish to update the device list cache with
|
||||
# the users device list. We want to track a user's devices when the
|
||||
# authenticated user shares a room with the queried user and the query
|
||||
# has not specified a particular device.
|
||||
# If we update the cache for the queried user we remove them from further
|
||||
# queries. We use the more efficient batched query_client_keys for all
|
||||
# remaining users
|
||||
user_ids_updated = []
|
||||
for (user_id, device_list) in destination_query.items():
|
||||
if user_id in user_ids_updated:
|
||||
continue
|
||||
|
||||
if device_list:
|
||||
continue
|
||||
|
||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
||||
if not room_ids:
|
||||
continue
|
||||
|
||||
# We've decided we're sharing a room with this user and should
|
||||
# probably be tracking their device lists. However, we haven't
|
||||
# done an initial sync on the device list so we do it now.
|
||||
try:
|
||||
if self._is_master:
|
||||
user_devices = await self.device_handler.device_list_updater.user_device_resync(
|
||||
user_id
|
||||
)
|
||||
else:
|
||||
user_devices = await self._user_device_resync_client(
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
user_devices = user_devices["devices"]
|
||||
user_results = results.setdefault(user_id, {})
|
||||
for device in user_devices:
|
||||
user_results[device["device_id"]] = device["keys"]
|
||||
user_ids_updated.append(user_id)
|
||||
except Exception as e:
|
||||
failures[destination] = _exception_to_failure(e)
|
||||
|
||||
if len(destination_query) == len(user_ids_updated):
|
||||
# We've updated all the users in the query and we do not need to
|
||||
# make any further remote calls.
|
||||
return
|
||||
|
||||
# Remove all the users from the query which we have updated
|
||||
for user_id in user_ids_updated:
|
||||
destination_query.pop(user_id)
|
||||
|
||||
try:
|
||||
remote_result = await self.federation.query_client_keys(
|
||||
destination, {"device_keys": destination_query}, timeout=timeout
|
||||
)
|
||||
|
||||
for user_id, keys in remote_result["device_keys"].items():
|
||||
if user_id in destination_query:
|
||||
# First get local devices.
|
||||
# A map of destination -> failure response.
|
||||
failures = {} # type: Dict[str, JsonDict]
|
||||
results = {}
|
||||
if local_query:
|
||||
local_result = await self.query_local_devices(local_query)
|
||||
for user_id, keys in local_result.items():
|
||||
if user_id in local_query:
|
||||
results[user_id] = keys
|
||||
|
||||
if "master_keys" in remote_result:
|
||||
for user_id, key in remote_result["master_keys"].items():
|
||||
# Get cached cross-signing keys
|
||||
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
|
||||
device_keys_query, from_user_id
|
||||
)
|
||||
|
||||
# Now attempt to get any remote devices from our local cache.
|
||||
# A map of destination -> user ID -> device IDs.
|
||||
remote_queries_not_in_cache = (
|
||||
{}
|
||||
) # type: Dict[str, Dict[str, Iterable[str]]]
|
||||
if remote_queries:
|
||||
query_list = [] # type: List[Tuple[str, Optional[str]]]
|
||||
for user_id, device_ids in remote_queries.items():
|
||||
if device_ids:
|
||||
query_list.extend(
|
||||
(user_id, device_id) for device_id in device_ids
|
||||
)
|
||||
else:
|
||||
query_list.append((user_id, None))
|
||||
|
||||
(
|
||||
user_ids_not_in_cache,
|
||||
remote_results,
|
||||
) = await self.store.get_user_devices_from_cache(query_list)
|
||||
for user_id, devices in remote_results.items():
|
||||
user_devices = results.setdefault(user_id, {})
|
||||
for device_id, device in devices.items():
|
||||
keys = device.get("keys", None)
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if keys:
|
||||
result = dict(keys)
|
||||
unsigned = result.setdefault("unsigned", {})
|
||||
if device_display_name:
|
||||
unsigned["device_display_name"] = device_display_name
|
||||
user_devices[device_id] = result
|
||||
|
||||
# check for missing cross-signing keys.
|
||||
for user_id in remote_queries.keys():
|
||||
cached_cross_master = user_id in cross_signing_keys["master_keys"]
|
||||
cached_cross_selfsigning = (
|
||||
user_id in cross_signing_keys["self_signing_keys"]
|
||||
)
|
||||
|
||||
# check if we are missing only one of cross-signing master or
|
||||
# self-signing key, but the other one is cached.
|
||||
# as we need both, this will issue a federation request.
|
||||
# if we don't have any of the keys, either the user doesn't have
|
||||
# cross-signing set up, or the cached device list
|
||||
# is not (yet) updated.
|
||||
if cached_cross_master ^ cached_cross_selfsigning:
|
||||
user_ids_not_in_cache.add(user_id)
|
||||
|
||||
# add those users to the list to fetch over federation.
|
||||
for user_id in user_ids_not_in_cache:
|
||||
domain = get_domain_from_id(user_id)
|
||||
r = remote_queries_not_in_cache.setdefault(domain, {})
|
||||
r[user_id] = remote_queries[user_id]
|
||||
|
||||
# Now fetch any devices that we don't have in our cache
|
||||
@trace
|
||||
async def do_remote_query(destination):
|
||||
"""This is called when we are querying the device list of a user on
|
||||
a remote homeserver and their device list is not in the device list
|
||||
cache. If we share a room with this user and we're not querying for
|
||||
specific user we will update the cache with their device list.
|
||||
"""
|
||||
|
||||
destination_query = remote_queries_not_in_cache[destination]
|
||||
|
||||
# We first consider whether we wish to update the device list cache with
|
||||
# the users device list. We want to track a user's devices when the
|
||||
# authenticated user shares a room with the queried user and the query
|
||||
# has not specified a particular device.
|
||||
# If we update the cache for the queried user we remove them from further
|
||||
# queries. We use the more efficient batched query_client_keys for all
|
||||
# remaining users
|
||||
user_ids_updated = []
|
||||
for (user_id, device_list) in destination_query.items():
|
||||
if user_id in user_ids_updated:
|
||||
continue
|
||||
|
||||
if device_list:
|
||||
continue
|
||||
|
||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
||||
if not room_ids:
|
||||
continue
|
||||
|
||||
# We've decided we're sharing a room with this user and should
|
||||
# probably be tracking their device lists. However, we haven't
|
||||
# done an initial sync on the device list so we do it now.
|
||||
try:
|
||||
if self._is_master:
|
||||
user_devices = await self.device_handler.device_list_updater.user_device_resync(
|
||||
user_id
|
||||
)
|
||||
else:
|
||||
user_devices = await self._user_device_resync_client(
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
user_devices = user_devices["devices"]
|
||||
user_results = results.setdefault(user_id, {})
|
||||
for device in user_devices:
|
||||
user_results[device["device_id"]] = device["keys"]
|
||||
user_ids_updated.append(user_id)
|
||||
except Exception as e:
|
||||
failures[destination] = _exception_to_failure(e)
|
||||
|
||||
if len(destination_query) == len(user_ids_updated):
|
||||
# We've updated all the users in the query and we do not need to
|
||||
# make any further remote calls.
|
||||
return
|
||||
|
||||
# Remove all the users from the query which we have updated
|
||||
for user_id in user_ids_updated:
|
||||
destination_query.pop(user_id)
|
||||
|
||||
try:
|
||||
remote_result = await self.federation.query_client_keys(
|
||||
destination, {"device_keys": destination_query}, timeout=timeout
|
||||
)
|
||||
|
||||
for user_id, keys in remote_result["device_keys"].items():
|
||||
if user_id in destination_query:
|
||||
cross_signing_keys["master_keys"][user_id] = key
|
||||
results[user_id] = keys
|
||||
|
||||
if "self_signing_keys" in remote_result:
|
||||
for user_id, key in remote_result["self_signing_keys"].items():
|
||||
if user_id in destination_query:
|
||||
cross_signing_keys["self_signing_keys"][user_id] = key
|
||||
if "master_keys" in remote_result:
|
||||
for user_id, key in remote_result["master_keys"].items():
|
||||
if user_id in destination_query:
|
||||
cross_signing_keys["master_keys"][user_id] = key
|
||||
|
||||
except Exception as e:
|
||||
failure = _exception_to_failure(e)
|
||||
failures[destination] = failure
|
||||
set_tag("error", True)
|
||||
set_tag("reason", failure)
|
||||
if "self_signing_keys" in remote_result:
|
||||
for user_id, key in remote_result["self_signing_keys"].items():
|
||||
if user_id in destination_query:
|
||||
cross_signing_keys["self_signing_keys"][user_id] = key
|
||||
|
||||
await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(do_remote_query, destination)
|
||||
for destination in remote_queries_not_in_cache
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
)
|
||||
except Exception as e:
|
||||
failure = _exception_to_failure(e)
|
||||
failures[destination] = failure
|
||||
set_tag("error", True)
|
||||
set_tag("reason", failure)
|
||||
|
||||
ret = {"device_keys": results, "failures": failures}
|
||||
await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(do_remote_query, destination)
|
||||
for destination in remote_queries_not_in_cache
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
)
|
||||
|
||||
ret.update(cross_signing_keys)
|
||||
ret = {"device_keys": results, "failures": failures}
|
||||
|
||||
return ret
|
||||
ret.update(cross_signing_keys)
|
||||
|
||||
return ret
|
||||
|
||||
async def get_cross_signing_keys_from_cache(
|
||||
self, query: Iterable[str], from_user_id: Optional[str]
|
||||
|
|
|
@ -160,9 +160,12 @@ class KeyQueryServlet(RestServlet):
|
|||
async def on_POST(self, request):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
user_id = requester.user.to_string()
|
||||
device_id = requester.device_id
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
|
||||
result = await self.e2e_keys_handler.query_devices(
|
||||
body, timeout, user_id, device_id
|
||||
)
|
||||
return 200, result
|
||||
|
||||
|
||||
|
|
|
@ -257,7 +257,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
|
||||
|
||||
devices = self.get_success(
|
||||
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
|
||||
self.handler.query_devices(
|
||||
{"device_keys": {local_user: []}}, 0, local_user, "device123"
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
|
||||
|
||||
|
@ -357,7 +359,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
|
||||
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
|
||||
devices = self.get_success(
|
||||
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
|
||||
self.handler.query_devices(
|
||||
{"device_keys": {local_user: []}}, 0, local_user, "device123"
|
||||
)
|
||||
)
|
||||
del devices["device_keys"][local_user]["abc"]["unsigned"]
|
||||
del devices["device_keys"][local_user]["def"]["unsigned"]
|
||||
|
@ -591,7 +595,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# fetch the signed keys/devices and make sure that the signatures are there
|
||||
ret = self.get_success(
|
||||
self.handler.query_devices(
|
||||
{"device_keys": {local_user: [], other_user: []}}, 0, local_user
|
||||
{"device_keys": {local_user: [], other_user: []}},
|
||||
0,
|
||||
local_user,
|
||||
"device123",
|
||||
)
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue