Modify StoreKeyFetcher to read from server_keys_json. (#15417)

Before this change:

* `PerspectivesKeyFetcher` and `ServerKeyFetcher` write to `server_keys_json`.
* `PerspectivesKeyFetcher` also writes to `server_signature_keys`.
* `StoreKeyFetcher` reads from `server_signature_keys`.

After this change:

* `PerspectivesKeyFetcher` and `ServerKeyFetcher` write to `server_keys_json`.
* `PerspectivesKeyFetcher` also writes to `server_signature_keys`.
* `StoreKeyFetcher` reads from `server_keys_json`.

This results in `StoreKeyFetcher` now using the results from `ServerKeyFetcher`
in addition to those from `PerspectivesKeyFetcher`, i.e. keys which were fetched
directly from a server will now be pulled from the database instead of being
re-fetched.

An additional minor change is included to avoid creating a `PerspectivesKeyFetcher`
(and checking it) if no `trusted_key_servers` are configured.
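
To make the resulting ordering concrete, here is a minimal standalone sketch (not part of
the commit and not Synapse code; the real change is in the `Keyring.__init__` hunk below,
and the fetcher names here are plain strings standing in for the real classes): the
database is always consulted first, the perspectives step only exists when
`trusted_key_servers` is non-empty, and a direct fetch from the origin server is the
last resort.

from typing import List

def build_key_fetchers(trusted_key_servers: List[str]) -> List[str]:
    # Always read previously stored keys from the database first.
    fetchers = ["StoreKeyFetcher"]
    # Only consult trusted key servers when some are actually configured.
    if trusted_key_servers:
        fetchers.append("PerspectivesKeyFetcher")
    # Fall back to fetching directly from the origin server.
    fetchers.append("ServerKeyFetcher")
    return fetchers

# With no trusted key servers configured, the perspectives step is skipped entirely.
assert build_key_fetchers([]) == ["StoreKeyFetcher", "ServerKeyFetcher"]
assert build_key_fetchers(["matrix.org"]) == [
    "StoreKeyFetcher",
    "PerspectivesKeyFetcher",
    "ServerKeyFetcher",
]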

The overall impact of this should be better usage of cached results:

* If a server has no trusted key servers configured, this change should reduce how often
  keys are fetched.
* If a server's trusted key server does not have a requested server's keys cached, this
  change should reduce how often keys are fetched directly.
Patrick Cloke, 2023-04-20 12:30:32 -04:00, committed by GitHub
parent ae69d69525
commit 5e024a0645
8 changed files with 162 additions and 76 deletions

changelog.d/15417.bugfix (new file)
@@ -0,0 +1 @@
+Fix a long-standing bug where cached key results which were directly fetched would not be properly re-used.


@@ -150,17 +150,18 @@ class Keyring:
     def __init__(
         self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
     ):
-        self.clock = hs.get_clock()
-
         if key_fetchers is None:
-            key_fetchers = (
-                # Fetch keys from the database.
-                StoreKeyFetcher(hs),
-                # Fetch keys from a configured Perspectives server.
-                PerspectivesKeyFetcher(hs),
-                # Fetch keys from the origin server directly.
-                ServerKeyFetcher(hs),
-            )
-        self._key_fetchers = key_fetchers
+            # Always fetch keys from the database.
+            mutable_key_fetchers: List[KeyFetcher] = [StoreKeyFetcher(hs)]
+            # Fetch keys from configured trusted key servers, if any exist.
+            key_servers = hs.config.key.key_servers
+            if key_servers:
+                mutable_key_fetchers.append(PerspectivesKeyFetcher(hs))
+            # Finally, fetch keys from the origin server directly.
+            mutable_key_fetchers.append(ServerKeyFetcher(hs))
+
+            self._key_fetchers: Iterable[KeyFetcher] = tuple(mutable_key_fetchers)
+        else:
+            self._key_fetchers = key_fetchers
 
         self._fetch_keys_queue: BatchingQueue[

@@ -510,7 +511,7 @@ class StoreKeyFetcher(KeyFetcher):
             for key_id in queue_value.key_ids
         )
 
-        res = await self.store.get_server_verify_keys(key_ids_to_fetch)
+        res = await self.store.get_server_keys_json(key_ids_to_fetch)
 
         keys: Dict[str, Dict[str, FetchKeyResult]] = {}
         for (server_name, key_id), key in res.items():
             keys.setdefault(server_name, {})[key_id] = key

@@ -522,7 +523,6 @@ class BaseV2KeyFetcher(KeyFetcher):
         super().__init__(hs)
         self.store = hs.get_datastores().main
-        self.config = hs.config
 
     async def process_v2_response(
         self, from_server: str, response_json: JsonDict, time_added_ms: int

@@ -626,7 +626,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         super().__init__(hs)
         self.clock = hs.get_clock()
         self.client = hs.get_federation_http_client()
-        self.key_servers = self.config.key.key_servers
+        self.key_servers = hs.config.key.key_servers
 
     async def _fetch_keys(
         self, keys_to_fetch: List[_FetchKeyRequest]

@@ -775,7 +775,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
             keys.setdefault(server_name, {}).update(processed_response)
 
-        await self.store.store_server_verify_keys(
+        await self.store.store_server_signature_keys(
             perspective_name, time_now_ms, added_keys
         )


@@ -155,7 +155,7 @@ class RemoteKey(RestServlet):
             for key_id in key_ids:
                 store_queries.append((server_name, key_id, None))
 
-        cached = await self.store.get_server_keys_json(store_queries)
+        cached = await self.store.get_server_keys_json_for_remote(store_queries)
 
         json_results: Set[bytes] = set()


@@ -14,10 +14,12 @@
 # limitations under the License.
 
 import itertools
+import json
 import logging
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 from signedjson.key import decode_verify_key_bytes
+from unpaddedbase64 import decode_base64
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import LoggingTransaction

@@ -36,15 +38,16 @@ class KeyStore(SQLBaseStore):
     """Persistence for signature verification keys"""
 
     @cached()
-    def _get_server_verify_key(
+    def _get_server_signature_key(
         self, server_name_and_key_id: Tuple[str, str]
     ) -> FetchKeyResult:
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
+        cached_method_name="_get_server_signature_key",
+        list_name="server_name_and_key_ids",
     )
-    async def get_server_verify_keys(
+    async def get_server_signature_keys(
         self, server_name_and_key_ids: Iterable[Tuple[str, str]]
     ) -> Dict[Tuple[str, str], FetchKeyResult]:
         """

@@ -62,10 +65,12 @@ class KeyStore(SQLBaseStore):
             """Processes a batch of keys to fetch, and adds the result to `keys`."""
 
             # batch_iter always returns tuples so it's safe to do len(batch)
-            sql = (
-                "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
-                "FROM server_signature_keys WHERE 1=0"
-            ) + " OR (server_name=? AND key_id=?)" * len(batch)
+            sql = """
+                SELECT server_name, key_id, verify_key, ts_valid_until_ms
+                FROM server_signature_keys WHERE 1=0
+            """ + " OR (server_name=? AND key_id=?)" * len(
+                batch
+            )
 
             txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))

@@ -89,9 +94,9 @@ class KeyStore(SQLBaseStore):
                 _get_keys(txn, batch)
             return keys
 
-        return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
+        return await self.db_pool.runInteraction("get_server_signature_keys", _txn)
 
-    async def store_server_verify_keys(
+    async def store_server_signature_keys(
         self,
         from_server: str,
         ts_added_ms: int,

@@ -119,7 +124,7 @@ class KeyStore(SQLBaseStore):
                 )
             )
             # invalidate takes a tuple corresponding to the params of
-            # _get_server_verify_key. _get_server_verify_key only takes one
+            # _get_server_signature_key. _get_server_signature_key only takes one
             # param, which is itself the 2-tuple (server_name, key_id).
             invalidations.append((server_name, key_id))

@@ -134,10 +139,10 @@ class KeyStore(SQLBaseStore):
                 "verify_key",
             ),
             value_values=value_values,
-            desc="store_server_verify_keys",
+            desc="store_server_signature_keys",
         )
 
-        invalidate = self._get_server_verify_key.invalidate
+        invalidate = self._get_server_signature_key.invalidate
         for i in invalidations:
             invalidate((i,))

@@ -180,7 +185,75 @@ class KeyStore(SQLBaseStore):
             desc="store_server_keys_json",
         )
 
+        # invalidate takes a tuple corresponding to the params of
+        # _get_server_keys_json. _get_server_keys_json only takes one
+        # param, which is itself the 2-tuple (server_name, key_id).
+        self._get_server_keys_json.invalidate((((server_name, key_id),)))
+
+    @cached()
+    def _get_server_keys_json(
+        self, server_name_and_key_id: Tuple[str, str]
+    ) -> FetchKeyResult:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_server_keys_json", list_name="server_name_and_key_ids"
+    )
     async def get_server_keys_json(
+        self, server_name_and_key_ids: Iterable[Tuple[str, str]]
+    ) -> Dict[Tuple[str, str], FetchKeyResult]:
+        """
+        Args:
+            server_name_and_key_ids:
+                iterable of (server_name, key-id) tuples to fetch keys for
+
+        Returns:
+            A map from (server_name, key_id) -> FetchKeyResult, or None if the
+            key is unknown
+        """
+        keys = {}
+
+        def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
+            """Processes a batch of keys to fetch, and adds the result to `keys`."""
+
+            # batch_iter always returns tuples so it's safe to do len(batch)
+            sql = """
+                SELECT server_name, key_id, key_json, ts_valid_until_ms
+                FROM server_keys_json WHERE 1=0
+            """ + " OR (server_name=? AND key_id=?)" * len(
+                batch
+            )
+
+            txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
+
+            for server_name, key_id, key_json_bytes, ts_valid_until_ms in txn:
+                if ts_valid_until_ms is None:
+                    # Old keys may be stored with a ts_valid_until_ms of null,
+                    # in which case we treat this as if it was set to `0`, i.e.
+                    # it won't match key requests that define a minimum
+                    # `ts_valid_until_ms`.
+                    ts_valid_until_ms = 0
+
+                # The entire signed JSON response is stored in server_keys_json,
+                # fetch out the bits needed.
+                key_json = json.loads(bytes(key_json_bytes))
+                key_base64 = key_json["verify_keys"][key_id]["key"]
+
+                keys[(server_name, key_id)] = FetchKeyResult(
+                    verify_key=decode_verify_key_bytes(
+                        key_id, decode_base64(key_base64)
+                    ),
+                    valid_until_ts=ts_valid_until_ms,
+                )
+
+        def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
+            for batch in batch_iter(server_name_and_key_ids, 50):
+                _get_keys(txn, batch)
+            return keys
+
+        return await self.db_pool.runInteraction("get_server_keys_json", _txn)
+
+    async def get_server_keys_json_for_remote(
         self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
     ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
         """Retrieve the key json for a list of server_keys and key ids.

@@ -188,8 +261,10 @@ class KeyStore(SQLBaseStore):
         that server, key_id, and source triplet entry will be an empty list.
         The JSON is returned as a byte array so that it can be efficiently
         used in an HTTP response.
+
         Args:
             server_keys: List of (server_name, key_id, source) triplets.
+
         Returns:
             A mapping from (server_name, key_id, source) triplets to a list of dicts
         """

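For reference, `server_keys_json` stores the entire signed key response, and the new
`get_server_keys_json` method pulls the verify key back out of that blob. The following
standalone sketch (not Synapse code; it only assumes the `signedjson`, `canonicaljson`
and `unpaddedbase64` libraries) mirrors that round trip.

import json

import signedjson.key
from canonicaljson import encode_canonical_json
from unpaddedbase64 import decode_base64

# Build the kind of blob PerspectivesKeyFetcher/ServerKeyFetcher store in
# server_keys_json: the response carries a "verify_keys" mapping of
# key id -> unpadded-base64 public key.
signing_key = signedjson.key.generate_signing_key("1")
verify_key = signedjson.key.get_verify_key(signing_key)
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
key_json_bytes = encode_canonical_json(
    {
        "verify_keys": {
            key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
        }
    }
)

# Later, when StoreKeyFetcher reads the row back, only the verify key needs to
# be extracted and decoded again.
key_json = json.loads(key_json_bytes)
key_base64 = key_json["verify_keys"][key_id]["key"]
recovered = signedjson.key.decode_verify_key_bytes(key_id, decode_base64(key_base64))
assert recovered.version == verify_key.version
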

@@ -190,10 +190,23 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         kr = keyring.Keyring(self.hs)
 
         key1 = signedjson.key.generate_signing_key("1")
-        r = self.hs.get_datastores().main.store_server_verify_keys(
+        r = self.hs.get_datastores().main.store_server_keys_json(
             "server9",
-            int(time.time() * 1000),
-            {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), 1000)},
+            get_key_id(key1),
+            from_server="test",
+            ts_now_ms=int(time.time() * 1000),
+            ts_expires_ms=1000,
+            # The entire response gets signed & stored, just include the bits we
+            # care about.
+            key_json_bytes=canonicaljson.encode_canonical_json(
+                {
+                    "verify_keys": {
+                        get_key_id(key1): {
+                            "key": encode_verify_key_base64(get_verify_key(key1))
+                        }
+                    }
+                }
+            ),
         )
         self.get_success(r)

@@ -280,17 +293,13 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         mock_fetcher = Mock()
         mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
 
-        kr = keyring.Keyring(
-            self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
-        )
-
         key1 = signedjson.key.generate_signing_key("1")
-        r = self.hs.get_datastores().main.store_server_verify_keys(
+        r = self.hs.get_datastores().main.store_server_signature_keys(
             "server9",
             int(time.time() * 1000),
             # None is not a valid value in FetchKeyResult, but we're abusing this
             # API to insert null values into the database. The nulls get converted
-            # to 0 when fetched in KeyStore.get_server_verify_keys.
+            # to 0 when fetched in KeyStore.get_server_signature_keys.
             {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)},  # type: ignore[arg-type]
         )
         self.get_success(r)

@@ -298,27 +307,12 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         json1: JsonDict = {}
         signedjson.sign.sign_json(json1, "server9", key1)
 
-        # should fail immediately on an unsigned object
-        d = kr.verify_json_for_server("server9", {}, 0)
-        self.get_failure(d, SynapseError)
-
-        # should fail on a signed object with a non-zero minimum_valid_until_ms,
-        # as it tries to refetch the keys and fails.
-        d = kr.verify_json_for_server("server9", json1, 500)
-        self.get_failure(d, SynapseError)
-
-        # We expect the keyring tried to refetch the key once.
-        mock_fetcher.get_keys.assert_called_once_with(
-            "server9", [get_key_id(key1)], 500
-        )
-
         # should succeed on a signed object with a 0 minimum_valid_until_ms
-        d = kr.verify_json_for_server(
-            "server9",
-            json1,
-            0,
+        d = self.hs.get_datastores().main.get_server_signature_keys(
+            [("server9", get_key_id(key1))]
         )
-        self.get_success(d)
+        result = self.get_success(d)
+        self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0)
 
     def test_verify_json_dedupes_key_requests(self) -> None:
         """Two requests for the same key should be deduped."""

@@ -464,7 +458,9 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         # check that the perspectives store is correctly updated
         lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
-            self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
+            self.hs.get_datastores().main.get_server_keys_json_for_remote(
+                [lookup_triplet]
+            )
         )
         res_keys = key_json[lookup_triplet]
         self.assertEqual(len(res_keys), 1)

@@ -582,7 +578,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         # check that the perspectives store is correctly updated
         lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
-            self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
+            self.hs.get_datastores().main.get_server_keys_json_for_remote(
+                [lookup_triplet]
+            )
         )
         res_keys = key_json[lookup_triplet]
         self.assertEqual(len(res_keys), 1)

@@ -703,7 +701,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         # check that the perspectives store is correctly updated
         lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
-            self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
+            self.hs.get_datastores().main.get_server_keys_json_for_remote(
+                [lookup_triplet]
+            )
         )
         res_keys = key_json[lookup_triplet]
         self.assertEqual(len(res_keys), 1)


@@ -37,13 +37,13 @@ KEY_2 = decode_verify_key_base64(
 class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
-    def test_get_server_verify_keys(self) -> None:
+    def test_get_server_signature_keys(self) -> None:
         store = self.hs.get_datastores().main
 
         key_id_1 = "ed25519:key1"
         key_id_2 = "ed25519:KEY_ID_2"
 
         self.get_success(
-            store.store_server_verify_keys(
+            store.store_server_signature_keys(
                 "from_server",
                 10,
                 {

@@ -54,7 +54,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         )
 
         res = self.get_success(
-            store.get_server_verify_keys(
+            store.get_server_signature_keys(
                 [
                     ("server1", key_id_1),
                     ("server1", key_id_2),

@@ -87,7 +87,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         key_id_2 = "ed25519:key2"
 
         self.get_success(
-            store.store_server_verify_keys(
+            store.store_server_signature_keys(
                 "from_server",
                 0,
                 {

@@ -98,7 +98,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         )
 
         res = self.get_success(
-            store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+            store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
         )
         self.assertEqual(len(res.keys()), 2)

@@ -111,20 +111,20 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         self.assertEqual(res2.valid_until_ts, 200)
 
         # we should be able to look up the same thing again without a db hit
-        res = self.get_success(store.get_server_verify_keys([("srv1", key_id_1)]))
+        res = self.get_success(store.get_server_signature_keys([("srv1", key_id_1)]))
         self.assertEqual(len(res.keys()), 1)
         self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
 
         new_key_2 = signedjson.key.get_verify_key(
             signedjson.key.generate_signing_key("key2")
         )
-        d = store.store_server_verify_keys(
+        d = store.store_server_signature_keys(
             "from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)}
         )
         self.get_success(d)
 
         res = self.get_success(
-            store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+            store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
         )
         self.assertEqual(len(res.keys()), 2)


@@ -69,7 +69,6 @@ from synapse.logging.context import (
 )
 from synapse.rest import RegisterServletsFunc
 from synapse.server import HomeServer
-from synapse.storage.keys import FetchKeyResult
 from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree

@@ -848,15 +847,23 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
         verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
 
         self.get_success(
-            hs.get_datastores().main.store_server_verify_keys(
+            hs.get_datastores().main.store_server_keys_json(
+                self.OTHER_SERVER_NAME,
+                verify_key_id,
                 from_server=self.OTHER_SERVER_NAME,
-                ts_added_ms=clock.time_msec(),
-                verify_keys={
-                    (self.OTHER_SERVER_NAME, verify_key_id): FetchKeyResult(
-                        verify_key=verify_key,
-                        valid_until_ts=clock.time_msec() + 10000,
-                    ),
-                },
+                ts_now_ms=clock.time_msec(),
+                ts_expires_ms=clock.time_msec() + 10000,
+                key_json_bytes=canonicaljson.encode_canonical_json(
+                    {
+                        "verify_keys": {
+                            verify_key_id: {
+                                "key": signedjson.key.encode_verify_key_base64(
+                                    verify_key
+                                )
+                            }
+                        }
+                    }
+                ),
             )
         )


@@ -131,6 +131,9 @@ def default_config(
         # the test signing key is just an arbitrary ed25519 key to keep the config
         # parser happy
         "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",
+        # Disable trusted key servers, otherwise unit tests might try to actually
+        # reach out to matrix.org.
+        "trusted_key_servers": [],
         "event_cache_size": 1,
         "enable_registration": True,
         "enable_registration_captcha": False,