Handle duplicate OTK uploads racing (#17241)

Currently this causes one of then to 500.
This commit is contained in:
Erik Johnston 2024-05-29 11:16:00 +01:00 committed by GitHub
parent bb5a692946
commit 94ef2f4f5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 33 deletions

1
changelog.d/17241.bugfix Normal file
View File

@ -0,0 +1 @@
Fix handling of duplicate concurrent uploading of device one-time-keys.

View File

@ -53,6 +53,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ONE_TIME_KEY_UPLOAD = "one_time_key_upload_lock"
class E2eKeysHandler: class E2eKeysHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.config = hs.config self.config = hs.config
@ -62,6 +65,7 @@ class E2eKeysHandler:
self._appservice_handler = hs.get_application_service_handler() self._appservice_handler = hs.get_application_service_handler()
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._worker_lock_handler = hs.get_worker_locks_handler()
federation_registry = hs.get_federation_registry() federation_registry = hs.get_federation_registry()
@ -855,45 +859,53 @@ class E2eKeysHandler:
async def _upload_one_time_keys_for_user( async def _upload_one_time_keys_for_user(
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
) -> None: ) -> None:
logger.info( # We take out a lock so that we don't have to worry about a client
"Adding one_time_keys %r for device %r for user %r at %d", # sending duplicate requests.
one_time_keys.keys(), lock_key = f"{user_id}_{device_id}"
device_id, async with self._worker_lock_handler.acquire_lock(
user_id, ONE_TIME_KEY_UPLOAD, lock_key
time_now, ):
) logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(),
device_id,
user_id,
time_now,
)
# make a list of (alg, id, key) tuples # make a list of (alg, id, key) tuples
key_list = [] key_list = []
for key_id, key_obj in one_time_keys.items(): for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":") algorithm, key_id = key_id.split(":")
key_list.append((algorithm, key_id, key_obj)) key_list.append((algorithm, key_id, key_obj))
# First we check if we have already persisted any of the keys. # First we check if we have already persisted any of the keys.
existing_key_map = await self.store.get_e2e_one_time_keys( existing_key_map = await self.store.get_e2e_one_time_keys(
user_id, device_id, [k_id for _, k_id, _ in key_list] user_id, device_id, [k_id for _, k_id, _ in key_list]
) )
new_keys = [] # Keys that we need to insert. (alg, id, json) tuples. new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
for algorithm, key_id, key in key_list: for algorithm, key_id, key in key_list:
ex_json = existing_key_map.get((algorithm, key_id), None) ex_json = existing_key_map.get((algorithm, key_id), None)
if ex_json: if ex_json:
if not _one_time_keys_match(ex_json, key): if not _one_time_keys_match(ex_json, key):
raise SynapseError( raise SynapseError(
400, 400,
( (
"One time key %s:%s already exists. " "One time key %s:%s already exists. "
"Old key: %s; new key: %r" "Old key: %s; new key: %r"
)
% (algorithm, key_id, ex_json, key),
) )
% (algorithm, key_id, ex_json, key), else:
new_keys.append(
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
) )
else:
new_keys.append(
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
)
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) await self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, new_keys
)
async def upload_signing_keys_for_user( async def upload_signing_keys_for_user(
self, user_id: str, keys: JsonDict self, user_id: str, keys: JsonDict