From 2ee0b6ef4b78bada535beb30301cf0e01cbb7d81 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 19 Jul 2022 13:25:29 +0200 Subject: [PATCH] Safe async event cache (#13308) Fix race conditions in the async cache invalidation logic, by separating the async & local invalidation calls and ensuring any async call i executed first. Signed off by Nick @ Beeper (@Fizzadar). --- changelog.d/13308.misc | 1 + synapse/storage/_base.py | 9 +++- synapse/storage/database.py | 54 ++++++++++++++++--- .../storage/databases/main/censor_events.py | 2 +- synapse/storage/databases/main/events.py | 6 +-- .../storage/databases/main/events_worker.py | 48 +++++++++++++---- .../databases/main/monthly_active_users.py | 1 + .../storage/databases/main/purge_events.py | 2 +- 8 files changed, 102 insertions(+), 21 deletions(-) create mode 100644 changelog.d/13308.misc diff --git a/changelog.d/13308.misc b/changelog.d/13308.misc new file mode 100644 index 0000000000..7f8ec0815f --- /dev/null +++ b/changelog.d/13308.misc @@ -0,0 +1 @@ +Use an asynchronous cache wrapper for the get event cache. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b8c8dcd76b..a2f8310388 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -96,6 +96,10 @@ class SQLBaseStore(metaclass=ABCMeta): cache doesn't exist. Mainly used for invalidating caches on workers, where they may not have the cache. + Note that this function does not invalidate any remote caches, only the + local in-memory ones. Any remote invalidation must be performed before + calling this. + Args: cache_name key: Entry to invalidate. If None then invalidates the entire @@ -112,7 +116,10 @@ class SQLBaseStore(metaclass=ABCMeta): if key is None: cache.invalidate_all() else: - cache.invalidate(tuple(key)) + # Prefer any local-only invalidation method. Invalidating any non-local + # cache must be be done before this. + invalidate_method = getattr(cache, "invalidate_local", cache.invalidate) + invalidate_method(tuple(key)) def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 6a6d0dcd73..ea672ff89e 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -23,6 +23,7 @@ from time import monotonic as monotonic_time from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Collection, Dict, @@ -57,7 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor -from synapse.util.async_helpers import delay_cancellation, maybe_awaitable +from synapse.util.async_helpers import delay_cancellation from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -168,6 +169,7 @@ class LoggingDatabaseConnection: *, txn_name: Optional[str] = None, after_callbacks: Optional[List["_CallbackListEntry"]] = None, + async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None, exception_callbacks: Optional[List["_CallbackListEntry"]] = None, ) -> "LoggingTransaction": if not txn_name: @@ -178,6 +180,7 @@ class LoggingDatabaseConnection: name=txn_name, database_engine=self.engine, after_callbacks=after_callbacks, + async_after_callbacks=async_after_callbacks, exception_callbacks=exception_callbacks, ) @@ -209,6 +212,9 @@ class LoggingDatabaseConnection: # The type of entry which goes on our after_callbacks and exception_callbacks lists. _CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]] +_AsyncCallbackListEntry = Tuple[ + Callable[..., Awaitable], Tuple[object, ...], Dict[str, object] +] P = ParamSpec("P") R = TypeVar("R") @@ -227,6 +233,10 @@ class LoggingTransaction: that have been added by `call_after` which should be run on successful completion of the transaction. None indicates that no callbacks should be allowed to be scheduled to run. + async_after_callbacks: A list that asynchronous callbacks will be appended + to by `async_call_after` which should run, before after_callbacks, on + successful completion of the transaction. None indicates that no + callbacks should be allowed to be scheduled to run. exception_callbacks: A list that callbacks will be appended to that have been added by `call_on_exception` which should be run if transaction ends with an error. None indicates that no callbacks @@ -238,6 +248,7 @@ class LoggingTransaction: "name", "database_engine", "after_callbacks", + "async_after_callbacks", "exception_callbacks", ] @@ -247,12 +258,14 @@ class LoggingTransaction: name: str, database_engine: BaseDatabaseEngine, after_callbacks: Optional[List[_CallbackListEntry]] = None, + async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None, exception_callbacks: Optional[List[_CallbackListEntry]] = None, ): self.txn = txn self.name = name self.database_engine = database_engine self.after_callbacks = after_callbacks + self.async_after_callbacks = async_after_callbacks self.exception_callbacks = exception_callbacks def call_after( @@ -277,6 +290,28 @@ class LoggingTransaction: # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + def async_call_after( + self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs + ) -> None: + """Call the given asynchronous callback on the main twisted thread after + the transaction has finished (but before those added in `call_after`). + + Mostly used to invalidate remote caches after transactions. + + Note that transactions may be retried a few times if they encounter database + errors such as serialization failures. Callbacks given to `async_call_after` + will accumulate across transaction attempts and will _all_ be called once a + transaction attempt succeeds, regardless of whether previous transaction + attempts failed. Otherwise, if all transaction attempts fail, all + `call_on_exception` callbacks will be run instead. + """ + # if self.async_after_callbacks is None, that means that whatever constructed the + # LoggingTransaction isn't expecting there to be any callbacks; assert that + # is not the case. + assert self.async_after_callbacks is not None + # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 + self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + def call_on_exception( self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs ) -> None: @@ -574,6 +609,7 @@ class DatabasePool: conn: LoggingDatabaseConnection, desc: str, after_callbacks: List[_CallbackListEntry], + async_after_callbacks: List[_AsyncCallbackListEntry], exception_callbacks: List[_CallbackListEntry], func: Callable[Concatenate[LoggingTransaction, P], R], *args: P.args, @@ -597,6 +633,7 @@ class DatabasePool: conn desc after_callbacks + async_after_callbacks exception_callbacks func *args @@ -659,6 +696,7 @@ class DatabasePool: cursor = conn.cursor( txn_name=name, after_callbacks=after_callbacks, + async_after_callbacks=async_after_callbacks, exception_callbacks=exception_callbacks, ) try: @@ -798,6 +836,7 @@ class DatabasePool: async def _runInteraction() -> R: after_callbacks: List[_CallbackListEntry] = [] + async_after_callbacks: List[_AsyncCallbackListEntry] = [] exception_callbacks: List[_CallbackListEntry] = [] if not current_context(): @@ -809,6 +848,7 @@ class DatabasePool: self.new_transaction, desc, after_callbacks, + async_after_callbacks, exception_callbacks, func, *args, @@ -817,15 +857,17 @@ class DatabasePool: **kwargs, ) + # We order these assuming that async functions call out to external + # systems (e.g. to invalidate a cache) and the sync functions make these + # changes on any local in-memory caches/similar, and thus must be second. + for async_callback, async_args, async_kwargs in async_after_callbacks: + await async_callback(*async_args, **async_kwargs) for after_callback, after_args, after_kwargs in after_callbacks: - await maybe_awaitable(after_callback(*after_args, **after_kwargs)) - + after_callback(*after_args, **after_kwargs) return cast(R, result) except Exception: for exception_callback, after_args, after_kwargs in exception_callbacks: - await maybe_awaitable( - exception_callback(*after_args, **after_kwargs) - ) + exception_callback(*after_args, **after_kwargs) raise # To handle cancellation, we ensure that `after_callback`s and diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index fd3fc298b3..58177ecec1 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase # changed its content in the database. We can't call # self._invalidate_cache_and_stream because self.get_event_cache isn't of the # right type. - txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + self.invalidate_get_event_cache_after_txn(txn, event.event_id) # Send that invalidation to replication so that other workers also invalidate # the event cache. self._send_invalidation_to_replication( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index fa2266ba20..156e1bd5ab 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1293,7 +1293,7 @@ class PersistEventsStore: depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids - txn.call_after(self.store._invalidate_get_event_cache, event.event_id) + self.store.invalidate_get_event_cache_after_txn(txn, event.event_id) # Then update the `stream_ordering` position to mark the latest # event as the front of the room. This should not be done for # backfilled events because backfilled events have negative @@ -1675,7 +1675,7 @@ class PersistEventsStore: (cache_entry.event.event_id,), cache_entry ) - txn.call_after(prefill) + txn.async_call_after(prefill) def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: """Invalidate the caches for the redacted event. @@ -1684,7 +1684,7 @@ class PersistEventsStore: _invalidate_caches_for_event. """ assert event.redacts is not None - txn.call_after(self.store._invalidate_get_event_cache, event.redacts) + self.store.invalidate_get_event_cache_after_txn(txn, event.redacts) txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,)) txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,)) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index f3935bfead..4435373146 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -712,17 +712,41 @@ class EventsWorkerStore(SQLBaseStore): return event_entry_map - async def _invalidate_get_event_cache(self, event_id: str) -> None: - # First we invalidate the asynchronous cache instance. This may include - # out-of-process caches such as Redis/memcache. Once complete we can - # invalidate any in memory cache. The ordering is important here to - # ensure we don't pull in any remote invalid value after we invalidate - # the in-memory cache. + def invalidate_get_event_cache_after_txn( + self, txn: LoggingTransaction, event_id: str + ) -> None: + """ + Prepares a database transaction to invalidate the get event cache for a given + event ID when executed successfully. This is achieved by attaching two callbacks + to the transaction, one to invalidate the async cache and one for the in memory + sync cache (importantly called in that order). + + Arguments: + txn: the database transaction to attach the callbacks to + event_id: the event ID to be invalidated from caches + """ + + txn.async_call_after(self._invalidate_async_get_event_cache, event_id) + txn.call_after(self._invalidate_local_get_event_cache, event_id) + + async def _invalidate_async_get_event_cache(self, event_id: str) -> None: + """ + Invalidates an event in the asyncronous get event cache, which may be remote. + + Arguments: + event_id: the event ID to invalidate + """ + await self._get_event_cache.invalidate((event_id,)) - self._event_ref.pop(event_id, None) - self._current_event_fetches.pop(event_id, None) def _invalidate_local_get_event_cache(self, event_id: str) -> None: + """ + Invalidates an event in local in-memory get event caches. + + Arguments: + event_id: the event ID to invalidate + """ + self._get_event_cache.invalidate_local((event_id,)) self._event_ref.pop(event_id, None) self._current_event_fetches.pop(event_id, None) @@ -958,7 +982,13 @@ class EventsWorkerStore(SQLBaseStore): } row_dict = self.db_pool.new_transaction( - conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch + conn, + "do_fetch", + [], + [], + [], + self._fetch_event_rows, + events_to_fetch, ) # We only want to resolve deferreds from the main thread diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 9a63f953fb..efd136a864 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -66,6 +66,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): "initialise_mau_threepids", [], [], + [], self._initialise_reserved_users, hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 6d42276503..f6822707e4 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -304,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): self._invalidate_cache_and_stream( txn, self.have_seen_event, (room_id, event_id) ) - txn.call_after(self._invalidate_get_event_cache, event_id) + self.invalidate_get_event_cache_after_txn(txn, event_id) logger.info("[purge] done")