Limit the number of in-flight /keys/query requests from a single device. (#10144)

This commit is contained in:
Patrick Cloke 2021-06-09 07:05:32 -04:00 committed by GitHub
parent 1bf83a191b
commit 11846dff8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 198 additions and 175 deletions

1
changelog.d/10144.misc Normal file
View File

@ -0,0 +1 @@
Limit the number of in-flight `/keys/query` requests from a single device.

View File

@ -79,9 +79,15 @@ class E2eKeysHandler:
"client_keys", self.on_federation_query_client_keys "client_keys", self.on_federation_query_client_keys
) )
# Limit the number of in-flight requests from a single device.
self._query_devices_linearizer = Linearizer(
name="query_devices",
max_count=10,
)
@trace @trace
async def query_devices( async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
) -> JsonDict: ) -> JsonDict:
"""Handle a device key query from a client """Handle a device key query from a client
@ -105,191 +111,197 @@ class E2eKeysHandler:
from_user_id: the user making the query. This is used when from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users adding cross-signing signatures to limit what signatures users
can see. can see.
from_device_id: the device making the query. This is used to limit
the number of in-flight queries at a time.
""" """
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query = query_body.get(
"device_keys", {}
) # type: Dict[str, Iterable[str]]
device_keys_query = query_body.get( # separate users by domain.
"device_keys", {} # make a map from domain to user_id to device_ids
) # type: Dict[str, Iterable[str]] local_query = {}
remote_queries = {}
# separate users by domain. for user_id, device_ids in device_keys_query.items():
# make a map from domain to user_id to device_ids # we use UserID.from_string to catch invalid user ids
local_query = {} if self.is_mine(UserID.from_string(user_id)):
remote_queries = {} local_query[user_id] = device_ids
for user_id, device_ids in device_keys_query.items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
else:
remote_queries[user_id] = device_ids
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
# First get local devices.
# A map of destination -> failure response.
failures = {} # type: Dict[str, JsonDict]
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
for user_id, keys in local_result.items():
if user_id in local_query:
results[user_id] = keys
# Get cached cross-signing keys
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, from_user_id
)
# Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries:
query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else: else:
query_list.append((user_id, None)) remote_queries[user_id] = device_ids
( set_tag("local_key_query", local_query)
user_ids_not_in_cache, set_tag("remote_key_query", remote_queries)
remote_results,
) = await self.store.get_user_devices_from_cache(query_list)
for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.items():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
result = dict(keys)
unsigned = result.setdefault("unsigned", {})
if device_display_name:
unsigned["device_display_name"] = device_display_name
user_devices[device_id] = result
# check for missing cross-signing keys. # First get local devices.
for user_id in remote_queries.keys(): # A map of destination -> failure response.
cached_cross_master = user_id in cross_signing_keys["master_keys"] failures = {} # type: Dict[str, JsonDict]
cached_cross_selfsigning = ( results = {}
user_id in cross_signing_keys["self_signing_keys"] if local_query:
) local_result = await self.query_local_devices(local_query)
for user_id, keys in local_result.items():
# check if we are missing only one of cross-signing master or if user_id in local_query:
# self-signing key, but the other one is cached.
# as we need both, this will issue a federation request.
# if we don't have any of the keys, either the user doesn't have
# cross-signing set up, or the cached device list
# is not (yet) updated.
if cached_cross_master ^ cached_cross_selfsigning:
user_ids_not_in_cache.add(user_id)
# add those users to the list to fetch over federation.
for user_id in user_ids_not_in_cache:
domain = get_domain_from_id(user_id)
r = remote_queries_not_in_cache.setdefault(domain, {})
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
@trace
async def do_remote_query(destination):
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
specific user we will update the cache with their device list.
"""
destination_query = remote_queries_not_in_cache[destination]
# We first consider whether we wish to update the device list cache with
# the users device list. We want to track a user's devices when the
# authenticated user shares a room with the queried user and the query
# has not specified a particular device.
# If we update the cache for the queried user we remove them from further
# queries. We use the more efficient batched query_client_keys for all
# remaining users
user_ids_updated = []
for (user_id, device_list) in destination_query.items():
if user_id in user_ids_updated:
continue
if device_list:
continue
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
continue
# We've decided we're sharing a room with this user and should
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
if self._is_master:
user_devices = await self.device_handler.device_list_updater.user_device_resync(
user_id
)
else:
user_devices = await self._user_device_resync_client(
user_id=user_id
)
user_devices = user_devices["devices"]
user_results = results.setdefault(user_id, {})
for device in user_devices:
user_results[device["device_id"]] = device["keys"]
user_ids_updated.append(user_id)
except Exception as e:
failures[destination] = _exception_to_failure(e)
if len(destination_query) == len(user_ids_updated):
# We've updated all the users in the query and we do not need to
# make any further remote calls.
return
# Remove all the users from the query which we have updated
for user_id in user_ids_updated:
destination_query.pop(user_id)
try:
remote_result = await self.federation.query_client_keys(
destination, {"device_keys": destination_query}, timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
results[user_id] = keys results[user_id] = keys
if "master_keys" in remote_result: # Get cached cross-signing keys
for user_id, key in remote_result["master_keys"].items(): cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, from_user_id
)
# Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = (
{}
) # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries:
query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend(
(user_id, device_id) for device_id in device_ids
)
else:
query_list.append((user_id, None))
(
user_ids_not_in_cache,
remote_results,
) = await self.store.get_user_devices_from_cache(query_list)
for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.items():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
result = dict(keys)
unsigned = result.setdefault("unsigned", {})
if device_display_name:
unsigned["device_display_name"] = device_display_name
user_devices[device_id] = result
# check for missing cross-signing keys.
for user_id in remote_queries.keys():
cached_cross_master = user_id in cross_signing_keys["master_keys"]
cached_cross_selfsigning = (
user_id in cross_signing_keys["self_signing_keys"]
)
# check if we are missing only one of cross-signing master or
# self-signing key, but the other one is cached.
# as we need both, this will issue a federation request.
# if we don't have any of the keys, either the user doesn't have
# cross-signing set up, or the cached device list
# is not (yet) updated.
if cached_cross_master ^ cached_cross_selfsigning:
user_ids_not_in_cache.add(user_id)
# add those users to the list to fetch over federation.
for user_id in user_ids_not_in_cache:
domain = get_domain_from_id(user_id)
r = remote_queries_not_in_cache.setdefault(domain, {})
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
@trace
async def do_remote_query(destination):
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
specific user we will update the cache with their device list.
"""
destination_query = remote_queries_not_in_cache[destination]
# We first consider whether we wish to update the device list cache with
# the users device list. We want to track a user's devices when the
# authenticated user shares a room with the queried user and the query
# has not specified a particular device.
# If we update the cache for the queried user we remove them from further
# queries. We use the more efficient batched query_client_keys for all
# remaining users
user_ids_updated = []
for (user_id, device_list) in destination_query.items():
if user_id in user_ids_updated:
continue
if device_list:
continue
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
continue
# We've decided we're sharing a room with this user and should
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
if self._is_master:
user_devices = await self.device_handler.device_list_updater.user_device_resync(
user_id
)
else:
user_devices = await self._user_device_resync_client(
user_id=user_id
)
user_devices = user_devices["devices"]
user_results = results.setdefault(user_id, {})
for device in user_devices:
user_results[device["device_id"]] = device["keys"]
user_ids_updated.append(user_id)
except Exception as e:
failures[destination] = _exception_to_failure(e)
if len(destination_query) == len(user_ids_updated):
# We've updated all the users in the query and we do not need to
# make any further remote calls.
return
# Remove all the users from the query which we have updated
for user_id in user_ids_updated:
destination_query.pop(user_id)
try:
remote_result = await self.federation.query_client_keys(
destination, {"device_keys": destination_query}, timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query: if user_id in destination_query:
cross_signing_keys["master_keys"][user_id] = key results[user_id] = keys
if "self_signing_keys" in remote_result: if "master_keys" in remote_result:
for user_id, key in remote_result["self_signing_keys"].items(): for user_id, key in remote_result["master_keys"].items():
if user_id in destination_query: if user_id in destination_query:
cross_signing_keys["self_signing_keys"][user_id] = key cross_signing_keys["master_keys"][user_id] = key
except Exception as e: if "self_signing_keys" in remote_result:
failure = _exception_to_failure(e) for user_id, key in remote_result["self_signing_keys"].items():
failures[destination] = failure if user_id in destination_query:
set_tag("error", True) cross_signing_keys["self_signing_keys"][user_id] = key
set_tag("reason", failure)
await make_deferred_yieldable( except Exception as e:
defer.gatherResults( failure = _exception_to_failure(e)
[ failures[destination] = failure
run_in_background(do_remote_query, destination) set_tag("error", True)
for destination in remote_queries_not_in_cache set_tag("reason", failure)
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
ret = {"device_keys": results, "failures": failures} await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
ret.update(cross_signing_keys) ret = {"device_keys": results, "failures": failures}
return ret ret.update(cross_signing_keys)
return ret
async def get_cross_signing_keys_from_cache( async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str] self, query: Iterable[str], from_user_id: Optional[str]

View File

@ -160,9 +160,12 @@ class KeyQueryServlet(RestServlet):
async def on_POST(self, request): async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
device_id = requester.device_id
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.query_devices(body, timeout, user_id) result = await self.e2e_keys_handler.query_devices(
body, timeout, user_id, device_id
)
return 200, result return 200, result

View File

@ -257,7 +257,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2)) self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
devices = self.get_success( devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) self.handler.query_devices(
{"device_keys": {local_user: []}}, 0, local_user, "device123"
)
) )
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
@ -357,7 +359,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
devices = self.get_success( devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) self.handler.query_devices(
{"device_keys": {local_user: []}}, 0, local_user, "device123"
)
) )
del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["abc"]["unsigned"]
del devices["device_keys"][local_user]["def"]["unsigned"] del devices["device_keys"][local_user]["def"]["unsigned"]
@ -591,7 +595,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# fetch the signed keys/devices and make sure that the signatures are there # fetch the signed keys/devices and make sure that the signatures are there
ret = self.get_success( ret = self.get_success(
self.handler.query_devices( self.handler.query_devices(
{"device_keys": {local_user: [], other_user: []}}, 0, local_user {"device_keys": {local_user: [], other_user: []}},
0,
local_user,
"device123",
) )
) )