Limit the number of in-flight /keys/query requests from a single device. (#10144)

This commit is contained in:
Patrick Cloke 2021-06-09 07:05:32 -04:00 committed by GitHub
parent 1bf83a191b
commit 11846dff8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 198 additions and 175 deletions

1
changelog.d/10144.misc Normal file
View File

@ -0,0 +1 @@
Limit the number of in-flight `/keys/query` requests from a single device.

View File

@ -79,9 +79,15 @@ class E2eKeysHandler:
"client_keys", self.on_federation_query_client_keys "client_keys", self.on_federation_query_client_keys
) )
# Limit the number of in-flight requests from a single device.
self._query_devices_linearizer = Linearizer(
name="query_devices",
max_count=10,
)
@trace @trace
async def query_devices( async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
) -> JsonDict: ) -> JsonDict:
"""Handle a device key query from a client """Handle a device key query from a client
@ -105,8 +111,10 @@ class E2eKeysHandler:
from_user_id: the user making the query. This is used when from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users adding cross-signing signatures to limit what signatures users
can see. can see.
from_device_id: the device making the query. This is used to limit
the number of in-flight queries at a time.
""" """
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query = query_body.get( device_keys_query = query_body.get(
"device_keys", {} "device_keys", {}
) # type: Dict[str, Iterable[str]] ) # type: Dict[str, Iterable[str]]
@ -143,12 +151,16 @@ class E2eKeysHandler:
# Now attempt to get any remote devices from our local cache. # Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs. # A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]] remote_queries_not_in_cache = (
{}
) # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries: if remote_queries:
query_list = [] # type: List[Tuple[str, Optional[str]]] query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items(): for user_id, device_ids in remote_queries.items():
if device_ids: if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids) query_list.extend(
(user_id, device_id) for device_id in device_ids
)
else: else:
query_list.append((user_id, None)) query_list.append((user_id, None))

View File

@ -160,9 +160,12 @@ class KeyQueryServlet(RestServlet):
async def on_POST(self, request): async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
device_id = requester.device_id
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.query_devices(body, timeout, user_id) result = await self.e2e_keys_handler.query_devices(
body, timeout, user_id, device_id
)
return 200, result return 200, result

View File

@ -257,7 +257,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2)) self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
devices = self.get_success( devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) self.handler.query_devices(
{"device_keys": {local_user: []}}, 0, local_user, "device123"
)
) )
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
@ -357,7 +359,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
devices = self.get_success( devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) self.handler.query_devices(
{"device_keys": {local_user: []}}, 0, local_user, "device123"
)
) )
del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["abc"]["unsigned"]
del devices["device_keys"][local_user]["def"]["unsigned"] del devices["device_keys"][local_user]["def"]["unsigned"]
@ -591,7 +595,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# fetch the signed keys/devices and make sure that the signatures are there # fetch the signed keys/devices and make sure that the signatures are there
ret = self.get_success( ret = self.get_success(
self.handler.query_devices( self.handler.query_devices(
{"device_keys": {local_user: [], other_user: []}}, 0, local_user {"device_keys": {local_user: [], other_user: []}},
0,
local_user,
"device123",
) )
) )