Allow streaming cache invalidate all to workers. (#6749)

Erik Johnston 2020-01-22 10:37:00 +00:00 committed by GitHub
parent 2093f83ea0
commit 5d7a6ad223
6 changed files with 69 additions and 15 deletions

changelog.d/6749.misc (new file)

@@ -0,0 +1 @@
+Allow streaming cache 'invalidate all' to workers.


@@ -254,6 +254,11 @@ and the key to invalidate. For example:
 
     > RDATA caches 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251]
 
+Alternatively, an entire cache can be invalidated by sending down a `null`
+instead of the key. For example:
+
+    > RDATA caches 550953772 ["get_user_by_id", null, 1550574873252]
+
 However, there are times when a number of caches need to be invalidated
 at the same time with the same key. To reduce traffic we batch those
 invalidations into a single poke by defining a special cache name that
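For illustration, each caches-stream row is a JSON array of `[cache_func, keys, invalidation_ts]`. A minimal sketch of how a receiving side could decode the two examples above (the helper name is invented; this is not Synapse's actual RDATA parser):

    import json

    def decode_caches_row(raw: str) -> None:
        # Row layout: [cache_func, keys, invalidation_ts]; a JSON null for
        # keys now means "drop every entry of this cache".
        cache_func, keys, invalidation_ts = json.loads(raw)
        if keys is None:
            print("invalidate all of %s (ts=%d)" % (cache_func, invalidation_ts))
        else:
            print("invalidate %s%r (ts=%d)" % (cache_func, tuple(keys), invalidation_ts))

    decode_caches_row('["get_user_by_id", ["@bob:example.com"], 1550574873251]')
    decode_caches_row('["get_user_by_id", null, 1550574873252]')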


@@ -66,11 +66,16 @@ class BaseSlavedStore(SQLBaseStore):
                 self._cache_id_gen.advance(token)
             for row in rows:
                 if row.cache_func == CURRENT_STATE_CACHE_NAME:
+                    if row.keys is None:
+                        raise Exception(
+                            "Can't send an 'invalidate all' for current state cache"
+                        )
+
                     room_id = row.keys[0]
                     members_changed = set(row.keys[1:])
                     self._invalidate_state_caches(room_id, members_changed)
                 else:
-                    self._attempt_to_invalidate_cache(row.cache_func, tuple(row.keys))
+                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
 
     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         txn.call_after(cache_func.invalidate, keys)
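The new guard exists because the special current-state row encodes data in `keys` itself, so there is no meaningful "invalidate all" form for it. A sketch of that row's shape (the example values are invented):

    # Shape of the special current-state row's keys (example values invented):
    keys = ["!room:example.com", "@alice:example.com", "@bob:example.com"]
    room_id = keys[0]                # the room whose state changed
    members_changed = set(keys[1:])  # members whose membership changed
    # keys=None would leave no room ID to act on, hence the exception above.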


@@ -17,7 +17,9 @@
 import itertools
 import logging
 from collections import namedtuple
-from typing import Any
+from typing import Any, List, Optional
+
+import attr
 
 logger = logging.getLogger(__name__)
@@ -65,10 +67,24 @@ PushersStreamRow = namedtuple(
     "PushersStreamRow",
     ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
 )
-CachesStreamRow = namedtuple(
-    "CachesStreamRow",
-    ("cache_func", "keys", "invalidation_ts"),  # str  # list(str)  # int
-)
+
+
+@attr.s
+class CachesStreamRow:
+    """Stream to inform workers they should invalidate their cache.
+
+    Attributes:
+        cache_func: Name of the cached function.
+        keys: The entry in the cache to invalidate. If None then will
+            invalidate all.
+        invalidation_ts: Timestamp of when the invalidation took place.
+    """
+
+    cache_func = attr.ib(type=str)
+    keys = attr.ib(type=Optional[List[Any]])
+    invalidation_ts = attr.ib(type=int)
+
+
 PublicRoomsStreamRow = namedtuple(
     "PublicRoomsStreamRow",
     (
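Switching from `namedtuple` to `attr.s` keeps positional construction working while letting `keys` be typed as `Optional`. A quick round-trip using the standard attrs API (whether replication serialises rows exactly this way is an assumption):

    import attr
    from typing import Any, List, Optional

    @attr.s
    class CachesStreamRow:  # as defined in the hunk above
        cache_func = attr.ib(type=str)
        keys = attr.ib(type=Optional[List[Any]])
        invalidation_ts = attr.ib(type=int)

    row = CachesStreamRow("get_user_by_id", None, 1550574873252)
    print(row.keys is None)    # True: this row means "invalidate the whole cache"
    print(attr.astuple(row))   # ('get_user_by_id', None, 1550574873252)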


@@ -17,6 +17,7 @@
 import logging
 import random
 from abc import ABCMeta
+from typing import Any, Optional
 
 from six import PY2
 from six.moves import builtins
@@ -26,7 +27,7 @@ from canonicaljson import json
 from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import Database
-from synapse.types import get_domain_from_id
+from synapse.types import Collection, get_domain_from_id
 
 logger = logging.getLogger(__name__)
@@ -63,17 +64,24 @@ class SQLBaseStore(metaclass=ABCMeta):
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
         self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
 
-    def _attempt_to_invalidate_cache(self, cache_name, key):
+    def _attempt_to_invalidate_cache(
+        self, cache_name: str, key: Optional[Collection[Any]]
+    ):
         """Attempts to invalidate the cache of the given name, ignoring if the
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
 
         Args:
-            cache_name (str)
-            key (tuple)
+            cache_name
+            key: Entry to invalidate. If None then invalidates the entire
+                cache.
         """
         try:
-            getattr(self, cache_name).invalidate(key)
+            if key is None:
+                getattr(self, cache_name).invalidate_all()
+            else:
+                getattr(self, cache_name).invalidate(tuple(key))
         except AttributeError:
             # We probably haven't pulled in the cache in this worker,
             # which is fine.
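The `getattr` lookup relies only on the named cache object exposing `invalidate` and `invalidate_all`. A toy stand-in showing that contract (this dict-backed cache is hypothetical, not Synapse's descriptor cache):

    from typing import Any, Dict, Tuple

    class ToyCache:
        """Bare-minimum object satisfying _attempt_to_invalidate_cache's needs."""

        def __init__(self) -> None:
            self._entries: Dict[Tuple[Any, ...], str] = {
                ("@bob:example.com",): "cached user"
            }

        def invalidate(self, key: Tuple[Any, ...]) -> None:
            # Drop a single entry; a missing key is ignored, since an
            # invalidation may arrive for an entry never cached here.
            self._entries.pop(key, None)

        def invalidate_all(self) -> None:
            # The new code path taken when a row arrives with keys=None.
            self._entries.clear()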


@@ -16,6 +16,7 @@
 
 import itertools
 import logging
+from typing import Any, Iterable, Optional
 
 from twisted.internet import defer
 
@@ -43,6 +44,14 @@ class CacheInvalidationStore(SQLBaseStore):
         txn.call_after(cache_func.invalidate, keys)
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
+    def _invalidate_all_cache_and_stream(self, txn, cache_func):
+        """Invalidates the entire cache and adds it to the cache stream so slaves
+        will know to invalidate their caches.
+        """
+
+        txn.call_after(cache_func.invalidate_all)
+        self._send_invalidation_to_replication(txn, cache_func.__name__, None)
+
     def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
         """Special case invalidation of caches based on current state.
@@ -73,17 +82,24 @@ class CacheInvalidationStore(SQLBaseStore):
             txn, CURRENT_STATE_CACHE_NAME, [room_id]
         )
 
-    def _send_invalidation_to_replication(self, txn, cache_name, keys):
+    def _send_invalidation_to_replication(
+        self, txn, cache_name: str, keys: Optional[Iterable[Any]]
+    ):
         """Notifies replication that given cache has been invalidated.
 
         Note that this does *not* invalidate the cache locally.
 
         Args:
             txn
-            cache_name (str)
-            keys (iterable[str])
+            cache_name
+            keys: Entry to invalidate. If None will invalidate all.
         """
 
+        if cache_name == CURRENT_STATE_CACHE_NAME and keys is None:
+            raise Exception(
+                "Can't stream invalidate all with magic current state cache"
+            )
+
         if isinstance(self.database_engine, PostgresEngine):
             # get_next() returns a context manager which is designed to wrap
             # the transaction. However, we want to only get an ID when we want
@@ -95,13 +111,16 @@ class CacheInvalidationStore(SQLBaseStore):
             txn.call_after(ctx.__exit__, None, None, None)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
+            if keys is not None:
+                keys = list(keys)
+
             self.db.simple_insert_txn(
                 txn,
                 table="cache_invalidation_stream",
                 values={
                     "stream_id": stream_id,
                     "cache_func": cache_name,
-                    "keys": list(keys),
+                    "keys": keys,
                     "invalidation_ts": self.clock.time_msec(),
                 },
             )
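A caller would pair the new helper with a database transaction so the local flush and the replicated poke happen together. A hypothetical storage method showing the intended usage (the method name and SQL are invented for illustration):

    def disable_all_users(self, txn):
        # Hypothetical bulk write that makes every get_user_by_id entry stale.
        txn.execute("UPDATE users SET deactivated = 1")
        # Flush the whole cache locally after the transaction commits, and
        # stream a row with keys=None so workers call invalidate_all() too.
        self._invalidate_all_cache_and_stream(txn, self.get_user_by_id)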