diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 4cf5549143..aff69c5f83 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -101,10 +101,10 @@ class Keyring(object): server_name(str): The name of the server to fetch a key for. keys_ids (list of str): The key_ids to check for. """ - cached = yield self.store.get_server_verify_key(server_name, key_ids[0]) + cached = yield self.store.get_server_verify_keys(server_name, key_ids) if cached: - defer.returnValue(cached) + defer.returnValue(cached[0]) return download = self.key_downloads.get(server_name) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 88a5642924..2902e35181 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -71,24 +71,24 @@ class KeyStore(SQLBaseStore): desc="store_server_certificate", ) - @cached(num_args=2) + @cached() @defer.inlineCallbacks - def get_server_verify_key(self, server_name, key_id): - key_bytes = yield self._simple_select_one_onecol( + def get_all_server_verify_keys(self, server_name): + rows = yield self._simple_select_list( table="server_signature_keys", keyvalues={ "server_name": server_name, - "key_id": key_id }, - retcol="verify_key", - desc="get_server_verify_key", - allow_none=True, + retcols=["key_id", "verify_key"], + desc="get_all_server_verify_keys", ) - if key_bytes: - defer.returnValue(decode_verify_key_bytes(key_id, str(key_bytes))) - else: - defer.returnValue(None) + defer.returnValue({ + row["key_id"]: decode_verify_key_bytes( + row["key_id"], str(row["verify_key"]) + ) + for row in rows + }) @defer.inlineCallbacks def get_server_verify_keys(self, server_name, key_ids): @@ -100,23 +100,8 @@ class KeyStore(SQLBaseStore): Returns: (list of VerifyKey): The verification keys. """ - sql = ( - "SELECT key_id, verify_key FROM server_signature_keys" - " WHERE server_name = ?" - " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")" - ) - - rows = yield self._execute_and_decode( - "get_server_verify_keys", sql, server_name, *key_ids - ) - - keys = [] - for row in rows: - key_id = row["key_id"] - key_bytes = row["verify_key"] - key = decode_verify_key_bytes(key_id, str(key_bytes)) - keys.append(key) - defer.returnValue(keys) + keys = yield self.get_all_server_verify_keys(server_name) + defer.returnValue([keys[k] for k in key_ids if k in keys]) @defer.inlineCallbacks def store_server_verify_key(self, server_name, from_server, time_now_ms, @@ -129,12 +114,11 @@ class KeyStore(SQLBaseStore): ts_now_ms (int): The time now in milliseconds verification_key (VerifyKey): The NACL verify key. """ - key_id = "%s:%s" % (verify_key.alg, verify_key.version) yield self._simple_upsert( table="server_signature_keys", keyvalues={ "server_name": server_name, - "key_id": key_id, + "key_id": "%s:%s" % (verify_key.alg, verify_key.version), }, values={ "from_server": from_server, @@ -144,7 +128,7 @@ class KeyStore(SQLBaseStore): desc="store_server_verify_key", ) - self.get_server_verify_key.invalidate(server_name, key_id) + self.get_all_server_verify_keys.invalidate(server_name) def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes):