Limit the number of in-flight /keys/query requests from a single device. (#10144)
This commit is contained in:
parent
1bf83a191b
commit
11846dff8c
|
@ -0,0 +1 @@
|
||||||
|
Limit the number of in-flight `/keys/query` requests from a single device.
|
|
@ -79,9 +79,15 @@ class E2eKeysHandler:
|
||||||
"client_keys", self.on_federation_query_client_keys
|
"client_keys", self.on_federation_query_client_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Limit the number of in-flight requests from a single device.
|
||||||
|
self._query_devices_linearizer = Linearizer(
|
||||||
|
name="query_devices",
|
||||||
|
max_count=10,
|
||||||
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def query_devices(
|
async def query_devices(
|
||||||
self, query_body: JsonDict, timeout: int, from_user_id: str
|
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Handle a device key query from a client
|
"""Handle a device key query from a client
|
||||||
|
|
||||||
|
@ -105,191 +111,197 @@ class E2eKeysHandler:
|
||||||
from_user_id: the user making the query. This is used when
|
from_user_id: the user making the query. This is used when
|
||||||
adding cross-signing signatures to limit what signatures users
|
adding cross-signing signatures to limit what signatures users
|
||||||
can see.
|
can see.
|
||||||
|
from_device_id: the device making the query. This is used to limit
|
||||||
|
the number of in-flight queries at a time.
|
||||||
"""
|
"""
|
||||||
|
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
|
||||||
|
device_keys_query = query_body.get(
|
||||||
|
"device_keys", {}
|
||||||
|
) # type: Dict[str, Iterable[str]]
|
||||||
|
|
||||||
device_keys_query = query_body.get(
|
# separate users by domain.
|
||||||
"device_keys", {}
|
# make a map from domain to user_id to device_ids
|
||||||
) # type: Dict[str, Iterable[str]]
|
local_query = {}
|
||||||
|
remote_queries = {}
|
||||||
|
|
||||||
# separate users by domain.
|
for user_id, device_ids in device_keys_query.items():
|
||||||
# make a map from domain to user_id to device_ids
|
# we use UserID.from_string to catch invalid user ids
|
||||||
local_query = {}
|
if self.is_mine(UserID.from_string(user_id)):
|
||||||
remote_queries = {}
|
local_query[user_id] = device_ids
|
||||||
|
|
||||||
for user_id, device_ids in device_keys_query.items():
|
|
||||||
# we use UserID.from_string to catch invalid user ids
|
|
||||||
if self.is_mine(UserID.from_string(user_id)):
|
|
||||||
local_query[user_id] = device_ids
|
|
||||||
else:
|
|
||||||
remote_queries[user_id] = device_ids
|
|
||||||
|
|
||||||
set_tag("local_key_query", local_query)
|
|
||||||
set_tag("remote_key_query", remote_queries)
|
|
||||||
|
|
||||||
# First get local devices.
|
|
||||||
# A map of destination -> failure response.
|
|
||||||
failures = {} # type: Dict[str, JsonDict]
|
|
||||||
results = {}
|
|
||||||
if local_query:
|
|
||||||
local_result = await self.query_local_devices(local_query)
|
|
||||||
for user_id, keys in local_result.items():
|
|
||||||
if user_id in local_query:
|
|
||||||
results[user_id] = keys
|
|
||||||
|
|
||||||
# Get cached cross-signing keys
|
|
||||||
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
|
|
||||||
device_keys_query, from_user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now attempt to get any remote devices from our local cache.
|
|
||||||
# A map of destination -> user ID -> device IDs.
|
|
||||||
remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
|
|
||||||
if remote_queries:
|
|
||||||
query_list = [] # type: List[Tuple[str, Optional[str]]]
|
|
||||||
for user_id, device_ids in remote_queries.items():
|
|
||||||
if device_ids:
|
|
||||||
query_list.extend((user_id, device_id) for device_id in device_ids)
|
|
||||||
else:
|
else:
|
||||||
query_list.append((user_id, None))
|
remote_queries[user_id] = device_ids
|
||||||
|
|
||||||
(
|
set_tag("local_key_query", local_query)
|
||||||
user_ids_not_in_cache,
|
set_tag("remote_key_query", remote_queries)
|
||||||
remote_results,
|
|
||||||
) = await self.store.get_user_devices_from_cache(query_list)
|
|
||||||
for user_id, devices in remote_results.items():
|
|
||||||
user_devices = results.setdefault(user_id, {})
|
|
||||||
for device_id, device in devices.items():
|
|
||||||
keys = device.get("keys", None)
|
|
||||||
device_display_name = device.get("device_display_name", None)
|
|
||||||
if keys:
|
|
||||||
result = dict(keys)
|
|
||||||
unsigned = result.setdefault("unsigned", {})
|
|
||||||
if device_display_name:
|
|
||||||
unsigned["device_display_name"] = device_display_name
|
|
||||||
user_devices[device_id] = result
|
|
||||||
|
|
||||||
# check for missing cross-signing keys.
|
# First get local devices.
|
||||||
for user_id in remote_queries.keys():
|
# A map of destination -> failure response.
|
||||||
cached_cross_master = user_id in cross_signing_keys["master_keys"]
|
failures = {} # type: Dict[str, JsonDict]
|
||||||
cached_cross_selfsigning = (
|
results = {}
|
||||||
user_id in cross_signing_keys["self_signing_keys"]
|
if local_query:
|
||||||
)
|
local_result = await self.query_local_devices(local_query)
|
||||||
|
for user_id, keys in local_result.items():
|
||||||
# check if we are missing only one of cross-signing master or
|
if user_id in local_query:
|
||||||
# self-signing key, but the other one is cached.
|
|
||||||
# as we need both, this will issue a federation request.
|
|
||||||
# if we don't have any of the keys, either the user doesn't have
|
|
||||||
# cross-signing set up, or the cached device list
|
|
||||||
# is not (yet) updated.
|
|
||||||
if cached_cross_master ^ cached_cross_selfsigning:
|
|
||||||
user_ids_not_in_cache.add(user_id)
|
|
||||||
|
|
||||||
# add those users to the list to fetch over federation.
|
|
||||||
for user_id in user_ids_not_in_cache:
|
|
||||||
domain = get_domain_from_id(user_id)
|
|
||||||
r = remote_queries_not_in_cache.setdefault(domain, {})
|
|
||||||
r[user_id] = remote_queries[user_id]
|
|
||||||
|
|
||||||
# Now fetch any devices that we don't have in our cache
|
|
||||||
@trace
|
|
||||||
async def do_remote_query(destination):
|
|
||||||
"""This is called when we are querying the device list of a user on
|
|
||||||
a remote homeserver and their device list is not in the device list
|
|
||||||
cache. If we share a room with this user and we're not querying for
|
|
||||||
specific user we will update the cache with their device list.
|
|
||||||
"""
|
|
||||||
|
|
||||||
destination_query = remote_queries_not_in_cache[destination]
|
|
||||||
|
|
||||||
# We first consider whether we wish to update the device list cache with
|
|
||||||
# the users device list. We want to track a user's devices when the
|
|
||||||
# authenticated user shares a room with the queried user and the query
|
|
||||||
# has not specified a particular device.
|
|
||||||
# If we update the cache for the queried user we remove them from further
|
|
||||||
# queries. We use the more efficient batched query_client_keys for all
|
|
||||||
# remaining users
|
|
||||||
user_ids_updated = []
|
|
||||||
for (user_id, device_list) in destination_query.items():
|
|
||||||
if user_id in user_ids_updated:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if device_list:
|
|
||||||
continue
|
|
||||||
|
|
||||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
|
||||||
if not room_ids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# We've decided we're sharing a room with this user and should
|
|
||||||
# probably be tracking their device lists. However, we haven't
|
|
||||||
# done an initial sync on the device list so we do it now.
|
|
||||||
try:
|
|
||||||
if self._is_master:
|
|
||||||
user_devices = await self.device_handler.device_list_updater.user_device_resync(
|
|
||||||
user_id
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
user_devices = await self._user_device_resync_client(
|
|
||||||
user_id=user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
user_devices = user_devices["devices"]
|
|
||||||
user_results = results.setdefault(user_id, {})
|
|
||||||
for device in user_devices:
|
|
||||||
user_results[device["device_id"]] = device["keys"]
|
|
||||||
user_ids_updated.append(user_id)
|
|
||||||
except Exception as e:
|
|
||||||
failures[destination] = _exception_to_failure(e)
|
|
||||||
|
|
||||||
if len(destination_query) == len(user_ids_updated):
|
|
||||||
# We've updated all the users in the query and we do not need to
|
|
||||||
# make any further remote calls.
|
|
||||||
return
|
|
||||||
|
|
||||||
# Remove all the users from the query which we have updated
|
|
||||||
for user_id in user_ids_updated:
|
|
||||||
destination_query.pop(user_id)
|
|
||||||
|
|
||||||
try:
|
|
||||||
remote_result = await self.federation.query_client_keys(
|
|
||||||
destination, {"device_keys": destination_query}, timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
for user_id, keys in remote_result["device_keys"].items():
|
|
||||||
if user_id in destination_query:
|
|
||||||
results[user_id] = keys
|
results[user_id] = keys
|
||||||
|
|
||||||
if "master_keys" in remote_result:
|
# Get cached cross-signing keys
|
||||||
for user_id, key in remote_result["master_keys"].items():
|
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
|
||||||
|
device_keys_query, from_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now attempt to get any remote devices from our local cache.
|
||||||
|
# A map of destination -> user ID -> device IDs.
|
||||||
|
remote_queries_not_in_cache = (
|
||||||
|
{}
|
||||||
|
) # type: Dict[str, Dict[str, Iterable[str]]]
|
||||||
|
if remote_queries:
|
||||||
|
query_list = [] # type: List[Tuple[str, Optional[str]]]
|
||||||
|
for user_id, device_ids in remote_queries.items():
|
||||||
|
if device_ids:
|
||||||
|
query_list.extend(
|
||||||
|
(user_id, device_id) for device_id in device_ids
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query_list.append((user_id, None))
|
||||||
|
|
||||||
|
(
|
||||||
|
user_ids_not_in_cache,
|
||||||
|
remote_results,
|
||||||
|
) = await self.store.get_user_devices_from_cache(query_list)
|
||||||
|
for user_id, devices in remote_results.items():
|
||||||
|
user_devices = results.setdefault(user_id, {})
|
||||||
|
for device_id, device in devices.items():
|
||||||
|
keys = device.get("keys", None)
|
||||||
|
device_display_name = device.get("device_display_name", None)
|
||||||
|
if keys:
|
||||||
|
result = dict(keys)
|
||||||
|
unsigned = result.setdefault("unsigned", {})
|
||||||
|
if device_display_name:
|
||||||
|
unsigned["device_display_name"] = device_display_name
|
||||||
|
user_devices[device_id] = result
|
||||||
|
|
||||||
|
# check for missing cross-signing keys.
|
||||||
|
for user_id in remote_queries.keys():
|
||||||
|
cached_cross_master = user_id in cross_signing_keys["master_keys"]
|
||||||
|
cached_cross_selfsigning = (
|
||||||
|
user_id in cross_signing_keys["self_signing_keys"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if we are missing only one of cross-signing master or
|
||||||
|
# self-signing key, but the other one is cached.
|
||||||
|
# as we need both, this will issue a federation request.
|
||||||
|
# if we don't have any of the keys, either the user doesn't have
|
||||||
|
# cross-signing set up, or the cached device list
|
||||||
|
# is not (yet) updated.
|
||||||
|
if cached_cross_master ^ cached_cross_selfsigning:
|
||||||
|
user_ids_not_in_cache.add(user_id)
|
||||||
|
|
||||||
|
# add those users to the list to fetch over federation.
|
||||||
|
for user_id in user_ids_not_in_cache:
|
||||||
|
domain = get_domain_from_id(user_id)
|
||||||
|
r = remote_queries_not_in_cache.setdefault(domain, {})
|
||||||
|
r[user_id] = remote_queries[user_id]
|
||||||
|
|
||||||
|
# Now fetch any devices that we don't have in our cache
|
||||||
|
@trace
|
||||||
|
async def do_remote_query(destination):
|
||||||
|
"""This is called when we are querying the device list of a user on
|
||||||
|
a remote homeserver and their device list is not in the device list
|
||||||
|
cache. If we share a room with this user and we're not querying for
|
||||||
|
specific user we will update the cache with their device list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
destination_query = remote_queries_not_in_cache[destination]
|
||||||
|
|
||||||
|
# We first consider whether we wish to update the device list cache with
|
||||||
|
# the users device list. We want to track a user's devices when the
|
||||||
|
# authenticated user shares a room with the queried user and the query
|
||||||
|
# has not specified a particular device.
|
||||||
|
# If we update the cache for the queried user we remove them from further
|
||||||
|
# queries. We use the more efficient batched query_client_keys for all
|
||||||
|
# remaining users
|
||||||
|
user_ids_updated = []
|
||||||
|
for (user_id, device_list) in destination_query.items():
|
||||||
|
if user_id in user_ids_updated:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if device_list:
|
||||||
|
continue
|
||||||
|
|
||||||
|
room_ids = await self.store.get_rooms_for_user(user_id)
|
||||||
|
if not room_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# We've decided we're sharing a room with this user and should
|
||||||
|
# probably be tracking their device lists. However, we haven't
|
||||||
|
# done an initial sync on the device list so we do it now.
|
||||||
|
try:
|
||||||
|
if self._is_master:
|
||||||
|
user_devices = await self.device_handler.device_list_updater.user_device_resync(
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
user_devices = await self._user_device_resync_client(
|
||||||
|
user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
user_devices = user_devices["devices"]
|
||||||
|
user_results = results.setdefault(user_id, {})
|
||||||
|
for device in user_devices:
|
||||||
|
user_results[device["device_id"]] = device["keys"]
|
||||||
|
user_ids_updated.append(user_id)
|
||||||
|
except Exception as e:
|
||||||
|
failures[destination] = _exception_to_failure(e)
|
||||||
|
|
||||||
|
if len(destination_query) == len(user_ids_updated):
|
||||||
|
# We've updated all the users in the query and we do not need to
|
||||||
|
# make any further remote calls.
|
||||||
|
return
|
||||||
|
|
||||||
|
# Remove all the users from the query which we have updated
|
||||||
|
for user_id in user_ids_updated:
|
||||||
|
destination_query.pop(user_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
remote_result = await self.federation.query_client_keys(
|
||||||
|
destination, {"device_keys": destination_query}, timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
for user_id, keys in remote_result["device_keys"].items():
|
||||||
if user_id in destination_query:
|
if user_id in destination_query:
|
||||||
cross_signing_keys["master_keys"][user_id] = key
|
results[user_id] = keys
|
||||||
|
|
||||||
if "self_signing_keys" in remote_result:
|
if "master_keys" in remote_result:
|
||||||
for user_id, key in remote_result["self_signing_keys"].items():
|
for user_id, key in remote_result["master_keys"].items():
|
||||||
if user_id in destination_query:
|
if user_id in destination_query:
|
||||||
cross_signing_keys["self_signing_keys"][user_id] = key
|
cross_signing_keys["master_keys"][user_id] = key
|
||||||
|
|
||||||
except Exception as e:
|
if "self_signing_keys" in remote_result:
|
||||||
failure = _exception_to_failure(e)
|
for user_id, key in remote_result["self_signing_keys"].items():
|
||||||
failures[destination] = failure
|
if user_id in destination_query:
|
||||||
set_tag("error", True)
|
cross_signing_keys["self_signing_keys"][user_id] = key
|
||||||
set_tag("reason", failure)
|
|
||||||
|
|
||||||
await make_deferred_yieldable(
|
except Exception as e:
|
||||||
defer.gatherResults(
|
failure = _exception_to_failure(e)
|
||||||
[
|
failures[destination] = failure
|
||||||
run_in_background(do_remote_query, destination)
|
set_tag("error", True)
|
||||||
for destination in remote_queries_not_in_cache
|
set_tag("reason", failure)
|
||||||
],
|
|
||||||
consumeErrors=True,
|
|
||||||
).addErrback(unwrapFirstError)
|
|
||||||
)
|
|
||||||
|
|
||||||
ret = {"device_keys": results, "failures": failures}
|
await make_deferred_yieldable(
|
||||||
|
defer.gatherResults(
|
||||||
|
[
|
||||||
|
run_in_background(do_remote_query, destination)
|
||||||
|
for destination in remote_queries_not_in_cache
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
)
|
||||||
|
|
||||||
ret.update(cross_signing_keys)
|
ret = {"device_keys": results, "failures": failures}
|
||||||
|
|
||||||
return ret
|
ret.update(cross_signing_keys)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
async def get_cross_signing_keys_from_cache(
|
async def get_cross_signing_keys_from_cache(
|
||||||
self, query: Iterable[str], from_user_id: Optional[str]
|
self, query: Iterable[str], from_user_id: Optional[str]
|
||||||
|
|
|
@ -160,9 +160,12 @@ class KeyQueryServlet(RestServlet):
|
||||||
async def on_POST(self, request):
|
async def on_POST(self, request):
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
device_id = requester.device_id
|
||||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
|
result = await self.e2e_keys_handler.query_devices(
|
||||||
|
body, timeout, user_id, device_id
|
||||||
|
)
|
||||||
return 200, result
|
return 200, result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -257,7 +257,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
|
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
|
||||||
|
|
||||||
devices = self.get_success(
|
devices = self.get_success(
|
||||||
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
|
self.handler.query_devices(
|
||||||
|
{"device_keys": {local_user: []}}, 0, local_user, "device123"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
|
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
|
||||||
|
|
||||||
|
@ -357,7 +359,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
|
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
|
||||||
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
|
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
|
||||||
devices = self.get_success(
|
devices = self.get_success(
|
||||||
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
|
self.handler.query_devices(
|
||||||
|
{"device_keys": {local_user: []}}, 0, local_user, "device123"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
del devices["device_keys"][local_user]["abc"]["unsigned"]
|
del devices["device_keys"][local_user]["abc"]["unsigned"]
|
||||||
del devices["device_keys"][local_user]["def"]["unsigned"]
|
del devices["device_keys"][local_user]["def"]["unsigned"]
|
||||||
|
@ -591,7 +595,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
# fetch the signed keys/devices and make sure that the signatures are there
|
# fetch the signed keys/devices and make sure that the signatures are there
|
||||||
ret = self.get_success(
|
ret = self.get_success(
|
||||||
self.handler.query_devices(
|
self.handler.query_devices(
|
||||||
{"device_keys": {local_user: [], other_user: []}}, 0, local_user
|
{"device_keys": {local_user: [], other_user: []}},
|
||||||
|
0,
|
||||||
|
local_user,
|
||||||
|
"device123",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue