Safe async event cache (#13308)
Fix race conditions in the async cache invalidation logic, by separating the async & local invalidation calls and ensuring any async call i executed first. Signed off by Nick @ Beeper (@Fizzadar).
This commit is contained in:
parent
7864f33e28
commit
2ee0b6ef4b
|
@ -0,0 +1 @@
|
|||
Use an asynchronous cache wrapper for the get event cache. Contributed by Nick @ Beeper (@fizzadar).
|
|
@ -96,6 +96,10 @@ class SQLBaseStore(metaclass=ABCMeta):
|
|||
cache doesn't exist. Mainly used for invalidating caches on workers,
|
||||
where they may not have the cache.
|
||||
|
||||
Note that this function does not invalidate any remote caches, only the
|
||||
local in-memory ones. Any remote invalidation must be performed before
|
||||
calling this.
|
||||
|
||||
Args:
|
||||
cache_name
|
||||
key: Entry to invalidate. If None then invalidates the entire
|
||||
|
@ -112,7 +116,10 @@ class SQLBaseStore(metaclass=ABCMeta):
|
|||
if key is None:
|
||||
cache.invalidate_all()
|
||||
else:
|
||||
cache.invalidate(tuple(key))
|
||||
# Prefer any local-only invalidation method. Invalidating any non-local
|
||||
# cache must be be done before this.
|
||||
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
|
||||
invalidate_method(tuple(key))
|
||||
|
||||
|
||||
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
|
||||
|
|
|
@ -23,6 +23,7 @@ from time import monotonic as monotonic_time
|
|||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
|
@ -57,7 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
|||
from synapse.storage.background_updates import BackgroundUpdater
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.types import Connection, Cursor
|
||||
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
|
||||
from synapse.util.async_helpers import delay_cancellation
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -168,6 +169,7 @@ class LoggingDatabaseConnection:
|
|||
*,
|
||||
txn_name: Optional[str] = None,
|
||||
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
|
||||
async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None,
|
||||
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
|
||||
) -> "LoggingTransaction":
|
||||
if not txn_name:
|
||||
|
@ -178,6 +180,7 @@ class LoggingDatabaseConnection:
|
|||
name=txn_name,
|
||||
database_engine=self.engine,
|
||||
after_callbacks=after_callbacks,
|
||||
async_after_callbacks=async_after_callbacks,
|
||||
exception_callbacks=exception_callbacks,
|
||||
)
|
||||
|
||||
|
@ -209,6 +212,9 @@ class LoggingDatabaseConnection:
|
|||
|
||||
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
|
||||
_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
|
||||
_AsyncCallbackListEntry = Tuple[
|
||||
Callable[..., Awaitable], Tuple[object, ...], Dict[str, object]
|
||||
]
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
@ -227,6 +233,10 @@ class LoggingTransaction:
|
|||
that have been added by `call_after` which should be run on
|
||||
successful completion of the transaction. None indicates that no
|
||||
callbacks should be allowed to be scheduled to run.
|
||||
async_after_callbacks: A list that asynchronous callbacks will be appended
|
||||
to by `async_call_after` which should run, before after_callbacks, on
|
||||
successful completion of the transaction. None indicates that no
|
||||
callbacks should be allowed to be scheduled to run.
|
||||
exception_callbacks: A list that callbacks will be appended
|
||||
to that have been added by `call_on_exception` which should be run
|
||||
if transaction ends with an error. None indicates that no callbacks
|
||||
|
@ -238,6 +248,7 @@ class LoggingTransaction:
|
|||
"name",
|
||||
"database_engine",
|
||||
"after_callbacks",
|
||||
"async_after_callbacks",
|
||||
"exception_callbacks",
|
||||
]
|
||||
|
||||
|
@ -247,12 +258,14 @@ class LoggingTransaction:
|
|||
name: str,
|
||||
database_engine: BaseDatabaseEngine,
|
||||
after_callbacks: Optional[List[_CallbackListEntry]] = None,
|
||||
async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None,
|
||||
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
|
||||
):
|
||||
self.txn = txn
|
||||
self.name = name
|
||||
self.database_engine = database_engine
|
||||
self.after_callbacks = after_callbacks
|
||||
self.async_after_callbacks = async_after_callbacks
|
||||
self.exception_callbacks = exception_callbacks
|
||||
|
||||
def call_after(
|
||||
|
@ -277,6 +290,28 @@ class LoggingTransaction:
|
|||
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
|
||||
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
|
||||
|
||||
def async_call_after(
|
||||
self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
|
||||
) -> None:
|
||||
"""Call the given asynchronous callback on the main twisted thread after
|
||||
the transaction has finished (but before those added in `call_after`).
|
||||
|
||||
Mostly used to invalidate remote caches after transactions.
|
||||
|
||||
Note that transactions may be retried a few times if they encounter database
|
||||
errors such as serialization failures. Callbacks given to `async_call_after`
|
||||
will accumulate across transaction attempts and will _all_ be called once a
|
||||
transaction attempt succeeds, regardless of whether previous transaction
|
||||
attempts failed. Otherwise, if all transaction attempts fail, all
|
||||
`call_on_exception` callbacks will be run instead.
|
||||
"""
|
||||
# if self.async_after_callbacks is None, that means that whatever constructed the
|
||||
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||
# is not the case.
|
||||
assert self.async_after_callbacks is not None
|
||||
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
|
||||
self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
|
||||
|
||||
def call_on_exception(
|
||||
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
|
||||
) -> None:
|
||||
|
@ -574,6 +609,7 @@ class DatabasePool:
|
|||
conn: LoggingDatabaseConnection,
|
||||
desc: str,
|
||||
after_callbacks: List[_CallbackListEntry],
|
||||
async_after_callbacks: List[_AsyncCallbackListEntry],
|
||||
exception_callbacks: List[_CallbackListEntry],
|
||||
func: Callable[Concatenate[LoggingTransaction, P], R],
|
||||
*args: P.args,
|
||||
|
@ -597,6 +633,7 @@ class DatabasePool:
|
|||
conn
|
||||
desc
|
||||
after_callbacks
|
||||
async_after_callbacks
|
||||
exception_callbacks
|
||||
func
|
||||
*args
|
||||
|
@ -659,6 +696,7 @@ class DatabasePool:
|
|||
cursor = conn.cursor(
|
||||
txn_name=name,
|
||||
after_callbacks=after_callbacks,
|
||||
async_after_callbacks=async_after_callbacks,
|
||||
exception_callbacks=exception_callbacks,
|
||||
)
|
||||
try:
|
||||
|
@ -798,6 +836,7 @@ class DatabasePool:
|
|||
|
||||
async def _runInteraction() -> R:
|
||||
after_callbacks: List[_CallbackListEntry] = []
|
||||
async_after_callbacks: List[_AsyncCallbackListEntry] = []
|
||||
exception_callbacks: List[_CallbackListEntry] = []
|
||||
|
||||
if not current_context():
|
||||
|
@ -809,6 +848,7 @@ class DatabasePool:
|
|||
self.new_transaction,
|
||||
desc,
|
||||
after_callbacks,
|
||||
async_after_callbacks,
|
||||
exception_callbacks,
|
||||
func,
|
||||
*args,
|
||||
|
@ -817,15 +857,17 @@ class DatabasePool:
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
# We order these assuming that async functions call out to external
|
||||
# systems (e.g. to invalidate a cache) and the sync functions make these
|
||||
# changes on any local in-memory caches/similar, and thus must be second.
|
||||
for async_callback, async_args, async_kwargs in async_after_callbacks:
|
||||
await async_callback(*async_args, **async_kwargs)
|
||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||
await maybe_awaitable(after_callback(*after_args, **after_kwargs))
|
||||
|
||||
after_callback(*after_args, **after_kwargs)
|
||||
return cast(R, result)
|
||||
except Exception:
|
||||
for exception_callback, after_args, after_kwargs in exception_callbacks:
|
||||
await maybe_awaitable(
|
||||
exception_callback(*after_args, **after_kwargs)
|
||||
)
|
||||
exception_callback(*after_args, **after_kwargs)
|
||||
raise
|
||||
|
||||
# To handle cancellation, we ensure that `after_callback`s and
|
||||
|
|
|
@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
|||
# changed its content in the database. We can't call
|
||||
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
|
||||
# right type.
|
||||
txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
|
||||
self.invalidate_get_event_cache_after_txn(txn, event.event_id)
|
||||
# Send that invalidation to replication so that other workers also invalidate
|
||||
# the event cache.
|
||||
self._send_invalidation_to_replication(
|
||||
|
|
|
@ -1293,7 +1293,7 @@ class PersistEventsStore:
|
|||
depth_updates: Dict[str, int] = {}
|
||||
for event, context in events_and_contexts:
|
||||
# Remove the any existing cache entries for the event_ids
|
||||
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
|
||||
self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
|
||||
# Then update the `stream_ordering` position to mark the latest
|
||||
# event as the front of the room. This should not be done for
|
||||
# backfilled events because backfilled events have negative
|
||||
|
@ -1675,7 +1675,7 @@ class PersistEventsStore:
|
|||
(cache_entry.event.event_id,), cache_entry
|
||||
)
|
||||
|
||||
txn.call_after(prefill)
|
||||
txn.async_call_after(prefill)
|
||||
|
||||
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
|
||||
"""Invalidate the caches for the redacted event.
|
||||
|
@ -1684,7 +1684,7 @@ class PersistEventsStore:
|
|||
_invalidate_caches_for_event.
|
||||
"""
|
||||
assert event.redacts is not None
|
||||
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
|
||||
self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
|
||||
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
|
||||
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
|
||||
|
||||
|
|
|
@ -712,17 +712,41 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
return event_entry_map
|
||||
|
||||
async def _invalidate_get_event_cache(self, event_id: str) -> None:
|
||||
# First we invalidate the asynchronous cache instance. This may include
|
||||
# out-of-process caches such as Redis/memcache. Once complete we can
|
||||
# invalidate any in memory cache. The ordering is important here to
|
||||
# ensure we don't pull in any remote invalid value after we invalidate
|
||||
# the in-memory cache.
|
||||
def invalidate_get_event_cache_after_txn(
|
||||
self, txn: LoggingTransaction, event_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Prepares a database transaction to invalidate the get event cache for a given
|
||||
event ID when executed successfully. This is achieved by attaching two callbacks
|
||||
to the transaction, one to invalidate the async cache and one for the in memory
|
||||
sync cache (importantly called in that order).
|
||||
|
||||
Arguments:
|
||||
txn: the database transaction to attach the callbacks to
|
||||
event_id: the event ID to be invalidated from caches
|
||||
"""
|
||||
|
||||
txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
|
||||
txn.call_after(self._invalidate_local_get_event_cache, event_id)
|
||||
|
||||
async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
|
||||
"""
|
||||
Invalidates an event in the asyncronous get event cache, which may be remote.
|
||||
|
||||
Arguments:
|
||||
event_id: the event ID to invalidate
|
||||
"""
|
||||
|
||||
await self._get_event_cache.invalidate((event_id,))
|
||||
self._event_ref.pop(event_id, None)
|
||||
self._current_event_fetches.pop(event_id, None)
|
||||
|
||||
def _invalidate_local_get_event_cache(self, event_id: str) -> None:
|
||||
"""
|
||||
Invalidates an event in local in-memory get event caches.
|
||||
|
||||
Arguments:
|
||||
event_id: the event ID to invalidate
|
||||
"""
|
||||
|
||||
self._get_event_cache.invalidate_local((event_id,))
|
||||
self._event_ref.pop(event_id, None)
|
||||
self._current_event_fetches.pop(event_id, None)
|
||||
|
@ -958,7 +982,13 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
}
|
||||
|
||||
row_dict = self.db_pool.new_transaction(
|
||||
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
|
||||
conn,
|
||||
"do_fetch",
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
self._fetch_event_rows,
|
||||
events_to_fetch,
|
||||
)
|
||||
|
||||
# We only want to resolve deferreds from the main thread
|
||||
|
|
|
@ -66,6 +66,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
|
|||
"initialise_mau_threepids",
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
self._initialise_reserved_users,
|
||||
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
|
||||
)
|
||||
|
|
|
@ -304,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
|||
self._invalidate_cache_and_stream(
|
||||
txn, self.have_seen_event, (room_id, event_id)
|
||||
)
|
||||
txn.call_after(self._invalidate_get_event_cache, event_id)
|
||||
self.invalidate_get_event_cache_after_txn(txn, event_id)
|
||||
|
||||
logger.info("[purge] done")
|
||||
|
||||
|
|
Loading…
Reference in New Issue