Add cache invalidation across workers to module API (#13667)
Signed-off-by: Mathieu Velten <mathieuv@matrix.org>
This commit is contained in:
parent
16e1a9d9a7
commit
6bd8763804
|
@ -0,0 +1 @@
|
||||||
|
Add cache invalidation across workers to module API.
|
|
@ -29,7 +29,7 @@ class SynapsePlugin(Plugin):
|
||||||
self, fullname: str
|
self, fullname: str
|
||||||
) -> Optional[Callable[[MethodSigContext], CallableType]]:
|
) -> Optional[Callable[[MethodSigContext], CallableType]]:
|
||||||
if fullname.startswith(
|
if fullname.startswith(
|
||||||
"synapse.util.caches.descriptors._CachedFunction.__call__"
|
"synapse.util.caches.descriptors.CachedFunction.__call__"
|
||||||
) or fullname.startswith(
|
) or fullname.startswith(
|
||||||
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
|
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
|
||||||
):
|
):
|
||||||
|
@ -38,7 +38,7 @@ class SynapsePlugin(Plugin):
|
||||||
|
|
||||||
|
|
||||||
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
|
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
|
||||||
"""Fixes the `_CachedFunction.__call__` signature to be correct.
|
"""Fixes the `CachedFunction.__call__` signature to be correct.
|
||||||
|
|
||||||
It already has *almost* the correct signature, except:
|
It already has *almost* the correct signature, except:
|
||||||
|
|
||||||
|
|
|
@ -125,7 +125,7 @@ from synapse.types import (
|
||||||
)
|
)
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.async_helpers import maybe_awaitable
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import CachedFunction, cached
|
||||||
from synapse.util.frozenutils import freeze
|
from synapse.util.frozenutils import freeze
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -836,6 +836,37 @@ class ModuleApi:
|
||||||
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
|
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def register_cached_function(self, cached_func: CachedFunction) -> None:
|
||||||
|
"""Register a cached function that should be invalidated across workers.
|
||||||
|
Invalidation local to a worker can be done directly using `cached_func.invalidate`,
|
||||||
|
however invalidation that needs to go to other workers needs to call `invalidate_cache`
|
||||||
|
on the module API instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cached_function: The cached function that will be registered to receive invalidation
|
||||||
|
locally and from other workers.
|
||||||
|
"""
|
||||||
|
self._store.register_external_cached_function(
|
||||||
|
f"{cached_func.__module__}.{cached_func.__name__}", cached_func
|
||||||
|
)
|
||||||
|
|
||||||
|
async def invalidate_cache(
|
||||||
|
self, cached_func: CachedFunction, keys: Tuple[Any, ...]
|
||||||
|
) -> None:
|
||||||
|
"""Invalidate a cache entry of a cached function across workers. The cached function
|
||||||
|
needs to be registered on all workers first with `register_cached_function`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cached_function: The cached function that needs an invalidation
|
||||||
|
keys: keys of the entry to invalidate, usually matching the arguments of the
|
||||||
|
cached function.
|
||||||
|
"""
|
||||||
|
cached_func.invalidate(keys)
|
||||||
|
await self._store.send_invalidation_to_replication(
|
||||||
|
f"{cached_func.__module__}.{cached_func.__name__}",
|
||||||
|
keys,
|
||||||
|
)
|
||||||
|
|
||||||
async def complete_sso_login_async(
|
async def complete_sso_login_async(
|
||||||
self,
|
self,
|
||||||
registered_user_id: str,
|
registered_user_id: str,
|
||||||
|
|
|
@ -15,12 +15,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from abc import ABCMeta
|
from abc import ABCMeta
|
||||||
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
|
from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union
|
||||||
|
|
||||||
from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401
|
from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401
|
||||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
|
from synapse.util.caches.descriptors import CachedFunction
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -47,6 +48,8 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
self.database_engine = database.engine
|
self.database_engine = database.engine
|
||||||
self.db_pool = database
|
self.db_pool = database
|
||||||
|
|
||||||
|
self.external_cached_functions: Dict[str, CachedFunction] = {}
|
||||||
|
|
||||||
def process_replication_rows(
|
def process_replication_rows(
|
||||||
self,
|
self,
|
||||||
stream_name: str,
|
stream_name: str,
|
||||||
|
@ -95,7 +98,7 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
|
|
||||||
def _attempt_to_invalidate_cache(
|
def _attempt_to_invalidate_cache(
|
||||||
self, cache_name: str, key: Optional[Collection[Any]]
|
self, cache_name: str, key: Optional[Collection[Any]]
|
||||||
) -> None:
|
) -> bool:
|
||||||
"""Attempts to invalidate the cache of the given name, ignoring if the
|
"""Attempts to invalidate the cache of the given name, ignoring if the
|
||||||
cache doesn't exist. Mainly used for invalidating caches on workers,
|
cache doesn't exist. Mainly used for invalidating caches on workers,
|
||||||
where they may not have the cache.
|
where they may not have the cache.
|
||||||
|
@ -113,9 +116,12 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
try:
|
try:
|
||||||
cache = getattr(self, cache_name)
|
cache = getattr(self, cache_name)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# We probably haven't pulled in the cache in this worker,
|
# Check if an externally defined module cache has been registered
|
||||||
# which is fine.
|
cache = self.external_cached_functions.get(cache_name)
|
||||||
return
|
if not cache:
|
||||||
|
# We probably haven't pulled in the cache in this worker,
|
||||||
|
# which is fine.
|
||||||
|
return False
|
||||||
|
|
||||||
if key is None:
|
if key is None:
|
||||||
cache.invalidate_all()
|
cache.invalidate_all()
|
||||||
|
@ -125,6 +131,13 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
|
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
|
||||||
invalidate_method(tuple(key))
|
invalidate_method(tuple(key))
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def register_external_cached_function(
|
||||||
|
self, cache_name: str, func: CachedFunction
|
||||||
|
) -> None:
|
||||||
|
self.external_cached_functions[cache_name] = func
|
||||||
|
|
||||||
|
|
||||||
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
|
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -33,7 +33,7 @@ from synapse.storage.database import (
|
||||||
)
|
)
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
from synapse.util.caches.descriptors import _CachedFunction
|
from synapse.util.caches.descriptors import CachedFunction
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
return
|
return
|
||||||
|
|
||||||
cache_func.invalidate(keys)
|
cache_func.invalidate(keys)
|
||||||
await self.db_pool.runInteraction(
|
await self.send_invalidation_to_replication(
|
||||||
"invalidate_cache_and_stream",
|
|
||||||
self._send_invalidation_to_replication,
|
|
||||||
cache_func.__name__,
|
cache_func.__name__,
|
||||||
keys,
|
keys,
|
||||||
)
|
)
|
||||||
|
@ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
def _invalidate_cache_and_stream(
|
def _invalidate_cache_and_stream(
|
||||||
self,
|
self,
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
cache_func: _CachedFunction,
|
cache_func: CachedFunction,
|
||||||
keys: Tuple[Any, ...],
|
keys: Tuple[Any, ...],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||||
|
@ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
|
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
|
||||||
|
|
||||||
def _invalidate_all_cache_and_stream(
|
def _invalidate_all_cache_and_stream(
|
||||||
self, txn: LoggingTransaction, cache_func: _CachedFunction
|
self, txn: LoggingTransaction, cache_func: CachedFunction
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Invalidates the entire cache and adds it to the cache stream so slaves
|
"""Invalidates the entire cache and adds it to the cache stream so slaves
|
||||||
will know to invalidate their caches.
|
will know to invalidate their caches.
|
||||||
|
@ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
txn, CURRENT_STATE_CACHE_NAME, [room_id]
|
txn, CURRENT_STATE_CACHE_NAME, [room_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def send_invalidation_to_replication(
|
||||||
|
self, cache_name: str, keys: Optional[Collection[Any]]
|
||||||
|
) -> None:
|
||||||
|
await self.db_pool.runInteraction(
|
||||||
|
"send_invalidation_to_replication",
|
||||||
|
self._send_invalidation_to_replication,
|
||||||
|
cache_name,
|
||||||
|
keys,
|
||||||
|
)
|
||||||
|
|
||||||
def _send_invalidation_to_replication(
|
def _send_invalidation_to_replication(
|
||||||
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
|
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -53,7 +53,7 @@ CacheKey = Union[Tuple, Any]
|
||||||
F = TypeVar("F", bound=Callable[..., Any])
|
F = TypeVar("F", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
class _CachedFunction(Generic[F]):
|
class CachedFunction(Generic[F]):
|
||||||
invalidate: Any = None
|
invalidate: Any = None
|
||||||
invalidate_all: Any = None
|
invalidate_all: Any = None
|
||||||
prefill: Any = None
|
prefill: Any = None
|
||||||
|
@ -242,7 +242,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
|
||||||
|
|
||||||
return ret2
|
return ret2
|
||||||
|
|
||||||
wrapped = cast(_CachedFunction, _wrapped)
|
wrapped = cast(CachedFunction, _wrapped)
|
||||||
wrapped.cache = cache
|
wrapped.cache = cache
|
||||||
obj.__dict__[self.name] = wrapped
|
obj.__dict__[self.name] = wrapped
|
||||||
|
|
||||||
|
@ -363,7 +363,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||||
|
|
||||||
return make_deferred_yieldable(ret)
|
return make_deferred_yieldable(ret)
|
||||||
|
|
||||||
wrapped = cast(_CachedFunction, _wrapped)
|
wrapped = cast(CachedFunction, _wrapped)
|
||||||
|
|
||||||
if self.num_args == 1:
|
if self.num_args == 1:
|
||||||
assert not self.tree
|
assert not self.tree
|
||||||
|
@ -572,7 +572,7 @@ def cached(
|
||||||
iterable: bool = False,
|
iterable: bool = False,
|
||||||
prune_unread_entries: bool = True,
|
prune_unread_entries: bool = True,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
) -> Callable[[F], _CachedFunction[F]]:
|
) -> Callable[[F], CachedFunction[F]]:
|
||||||
func = lambda orig: DeferredCacheDescriptor(
|
func = lambda orig: DeferredCacheDescriptor(
|
||||||
orig,
|
orig,
|
||||||
max_entries=max_entries,
|
max_entries=max_entries,
|
||||||
|
@ -585,7 +585,7 @@ def cached(
|
||||||
name=name,
|
name=name,
|
||||||
)
|
)
|
||||||
|
|
||||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
return cast(Callable[[F], CachedFunction[F]], func)
|
||||||
|
|
||||||
|
|
||||||
def cachedList(
|
def cachedList(
|
||||||
|
@ -594,7 +594,7 @@ def cachedList(
|
||||||
list_name: str,
|
list_name: str,
|
||||||
num_args: Optional[int] = None,
|
num_args: Optional[int] = None,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
) -> Callable[[F], _CachedFunction[F]]:
|
) -> Callable[[F], CachedFunction[F]]:
|
||||||
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
|
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
|
||||||
|
|
||||||
Used to do batch lookups for an already created cache. One of the arguments
|
Used to do batch lookups for an already created cache. One of the arguments
|
||||||
|
@ -631,7 +631,7 @@ def cachedList(
|
||||||
name=name,
|
name=name,
|
||||||
)
|
)
|
||||||
|
|
||||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
return cast(Callable[[F], CachedFunction[F]], func)
|
||||||
|
|
||||||
|
|
||||||
def _get_cache_key_builder(
|
def _get_cache_key_builder(
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import synapse
|
||||||
|
from synapse.module_api import cached
|
||||||
|
|
||||||
|
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FIRST_VALUE = "one"
|
||||||
|
SECOND_VALUE = "two"
|
||||||
|
|
||||||
|
KEY = "mykey"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCache:
|
||||||
|
current_value = FIRST_VALUE
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
async def cached_function(self, user_id: str) -> str:
|
||||||
|
return self.current_value
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_module_cache_full_invalidation(self):
|
||||||
|
main_cache = TestCache()
|
||||||
|
self.hs.get_module_api().register_cached_function(main_cache.cached_function)
|
||||||
|
|
||||||
|
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
|
||||||
|
|
||||||
|
worker_cache = TestCache()
|
||||||
|
worker_hs.get_module_api().register_cached_function(
|
||||||
|
worker_cache.cached_function
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
|
||||||
|
self.assertEqual(
|
||||||
|
FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
|
||||||
|
)
|
||||||
|
|
||||||
|
main_cache.current_value = SECOND_VALUE
|
||||||
|
worker_cache.current_value = SECOND_VALUE
|
||||||
|
# No invalidation yet, should return the cached value on both the main process and the worker
|
||||||
|
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
|
||||||
|
self.assertEqual(
|
||||||
|
FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Full invalidation on the main process, should be replicated on the worker that
|
||||||
|
# should returned the updated value too
|
||||||
|
self.get_success(
|
||||||
|
self.hs.get_module_api().invalidate_cache(
|
||||||
|
main_cache.cached_function, (KEY,)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY))
|
||||||
|
)
|
Loading…
Reference in New Issue