Speed up `@cachedList` (#13591)

This speeds things up by ~2x.

The vast majority of the time is now spent in `LruCache` moving things around the linked lists.

We do this via two things:
1. Don't create a deferred per-key during bulk set operations in `DeferredCache`. Instead, only create them if a subsequent caller asks for the key.
2. Add a bulk lookup API to `DeferredCache` rather than use a loop.
This commit is contained in:
Erik Johnston 2022-08-23 15:53:27 +01:00 committed by GitHub
parent 05c9c7363b
commit f7ddfe17a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 298 additions and 141 deletions

1
changelog.d/13591.misc Normal file
View File

@ -0,0 +1 @@
Improve performance of `@cachedList`.

View File

@ -14,15 +14,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import enum import enum
import threading import threading
from typing import ( from typing import (
Callable, Callable,
Collection,
Dict,
Generic, Generic,
Iterable,
MutableMapping, MutableMapping,
Optional, Optional,
Set,
Sized, Sized,
Tuple,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -31,7 +35,6 @@ from typing import (
from prometheus_client import Gauge from prometheus_client import Gauge
from twisted.internet import defer from twisted.internet import defer
from twisted.python import failure
from twisted.python.failure import Failure from twisted.python.failure import Failure
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache maps from the key value to a `CacheEntry` object. # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache: Union[ self._pending_deferred_cache: Union[
TreeCache, "MutableMapping[KT, CacheEntry]" TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
] = cache_type() ] = cache_type()
def metrics_cb() -> None: def metrics_cb() -> None:
@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
Raises: Raises:
KeyError if the key is not found in the cache KeyError if the key is not found in the cache
""" """
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel: if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks) val.add_invalidation_callback(key, callback)
if update_metrics: if update_metrics:
m = self.cache.metrics m = self.cache.metrics
assert m # we always have a name, so should always have metrics assert m # we always have a name, so should always have metrics
m.inc_hits() m.inc_hits()
return val.deferred.observe() return val.deferred(key)
callbacks = (callback,) if callback else ()
val2 = self.cache.get( val2 = self.cache.get(
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
else: else:
return defer.succeed(val2) return defer.succeed(val2)
def get_bulk(
self,
keys: Collection[KT],
callback: Optional[Callable[[], None]] = None,
) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
"""Bulk lookup of items in the cache.
Returns:
A 3-tuple of:
1. a dict of key/value of items already cached;
2. a deferred that resolves to a dict of key/value of items
we're already fetching; and
3. a collection of keys that don't appear in the previous two.
"""
# The cached results
cached = {}
# List of pending deferreds
pending = []
# Dict that gets filled out when the pending deferreds complete
pending_results = {}
# List of keys that aren't in either cache
missing = []
callbacks = (callback,) if callback else ()
for key in keys:
# Check if its in the main cache.
immediate_value = self.cache.get(
key,
_Sentinel.sentinel,
callbacks=callbacks,
)
if immediate_value is not _Sentinel.sentinel:
cached[key] = immediate_value
continue
# Check if its in the pending cache
pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if pending_value is not _Sentinel.sentinel:
pending_value.add_invalidation_callback(key, callback)
def completed_cb(value: VT, key: KT) -> VT:
pending_results[key] = value
return value
# Add a callback to fill out `pending_results` when that completes
d = pending_value.deferred(key).addCallback(completed_cb, key)
pending.append(d)
continue
# Not in either cache
missing.append(key)
# If we've got pending deferreds, squash them into a single one that
# returns `pending_results`.
pending_deferred = None
if pending:
pending_deferred = defer.gatherResults(
pending, consumeErrors=True
).addCallback(lambda _: pending_results)
return (cached, pending_deferred, missing)
def get_immediate( def get_immediate(
self, key: KT, default: T, update_metrics: bool = True self, key: KT, default: T, update_metrics: bool = True
) -> Union[VT, T]: ) -> Union[VT, T]:
@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
value: a deferred which will complete with a result to add to the cache value: a deferred which will complete with a result to add to the cache
callback: An optional callback to be called when the entry is invalidated callback: An optional callback to be called when the entry is invalidated
""" """
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread() self.check_thread()
existing_entry = self._pending_deferred_cache.pop(key, None) self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
# XXX: why don't we invalidate the entry in `self.cache` yet? # XXX: why don't we invalidate the entry in `self.cache` yet?
# we can save a whole load of effort if the deferred is ready.
if value.called:
result = value.result
if not isinstance(result, failure.Failure):
self.cache.set(key, cast(VT, result), callbacks)
return value
# otherwise, we'll add an entry to the _pending_deferred_cache for now, # otherwise, we'll add an entry to the _pending_deferred_cache for now,
# and add callbacks to add it to the cache properly later. # and add callbacks to add it to the cache properly later.
entry = CacheEntrySingle[KT, VT](value)
observable = ObservableDeferred(value, consumeErrors=True) entry.add_invalidation_callback(key, callback)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
self._pending_deferred_cache[key] = entry self._pending_deferred_cache[key] = entry
deferred = entry.deferred(key).addCallbacks(
def compare_and_pop() -> bool: self._completed_callback,
"""Check if our entry is still the one in _pending_deferred_cache, and self._error_callback,
if so, pop it. callbackArgs=(entry, key),
errbackArgs=(entry, key),
Returns true if the entries matched. )
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
return False
def cb(result: VT) -> None:
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
def eb(_fail: Failure) -> None:
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
# we return a new Deferred which will be called before any subsequent observers. # we return a new Deferred which will be called before any subsequent observers.
return observable.observe() return deferred
def start_bulk_input(
self,
keys: Collection[KT],
callback: Optional[Callable[[], None]] = None,
) -> "CacheMultipleEntries[KT, VT]":
"""Bulk set API for use when fetching multiple keys at once from the DB.
Called *before* starting the fetch from the DB, and the caller *must*
call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
"""
entry = CacheMultipleEntries[KT, VT]()
entry.add_global_invalidation_callback(callback)
for key in keys:
self._pending_deferred_cache[key] = entry
return entry
def _completed_callback(
self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
) -> VT:
"""Called when a deferred is completed."""
# We check if the current entry matches the entry associated with the
# deferred. If they don't match then it got invalidated.
current_entry = self._pending_deferred_cache.pop(key, None)
if current_entry is not entry:
if current_entry:
self._pending_deferred_cache[key] = current_entry
return value
self.cache.set(key, value, entry.get_invalidation_callbacks(key))
return value
def _error_callback(
self,
failure: Failure,
entry: "CacheEntry[KT, VT]",
key: KT,
) -> Failure:
"""Called when a deferred errors."""
# We check if the current entry matches the entry associated with the
# deferred. If they don't match then it got invalidated.
current_entry = self._pending_deferred_cache.pop(key, None)
if current_entry is not entry:
if current_entry:
self._pending_deferred_cache[key] = current_entry
return failure
for cb in entry.get_invalidation_callbacks(key):
cb()
return failure
def prefill( def prefill(
self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
) -> None: ) -> None:
callbacks = [callback] if callback else [] callbacks = (callback,) if callback else ()
self.cache.set(key, value, callbacks=callbacks) self.cache.set(key, value, callbacks=callbacks)
self._pending_deferred_cache.pop(key, None)
def invalidate(self, key: KT) -> None: def invalidate(self, key: KT) -> None:
"""Delete a key, or tree of entries """Delete a key, or tree of entries
@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
self.cache.del_multi(key) self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the # if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned # _pending_deferred_cache, which will (a) stop it being returned for
# for future queries and (b) stop it being persisted as a proper entry # future queries and (b) stop it being persisted as a proper entry
# in self.cache. # in self.cache.
entry = self._pending_deferred_cache.pop(key, None) entry = self._pending_deferred_cache.pop(key, None)
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry: if entry:
# _pending_deferred_cache.pop should either return a CacheEntry, or, in the # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
# case of a TreeCache, a dict of keys to cache entries. Either way calling # case of a TreeCache, a dict of keys to cache entries. Either way calling
# iterate_tree_cache_entry on it will do the right thing. # iterate_tree_cache_entry on it will do the right thing.
for entry in iterate_tree_cache_entry(entry): for entry in iterate_tree_cache_entry(entry):
entry.invalidate() for cb in entry.get_invalidation_callbacks(key):
cb()
def invalidate_all(self) -> None: def invalidate_all(self) -> None:
self.check_thread() self.check_thread()
self.cache.clear() self.cache.clear()
for entry in self._pending_deferred_cache.values(): for key, entry in self._pending_deferred_cache.items():
entry.invalidate() for cb in entry.get_invalidation_callbacks(key):
cb()
self._pending_deferred_cache.clear() self._pending_deferred_cache.clear()
class CacheEntry: class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
__slots__ = ["deferred", "callbacks", "invalidated"] """Abstract class for entries in `DeferredCache[KT, VT]`"""
def __init__( @abc.abstractmethod
self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] def deferred(self, key: KT) -> "defer.Deferred[VT]":
): """Get a deferred that a caller can wait on to get the value at the
self.deferred = deferred given key"""
self.callbacks = set(callbacks) ...
self.invalidated = False
def invalidate(self) -> None: @abc.abstractmethod
if not self.invalidated: def add_invalidation_callback(
self.invalidated = True self, key: KT, callback: Optional[Callable[[], None]]
for callback in self.callbacks: ) -> None:
callback() """Add an invalidation callback"""
self.callbacks.clear() ...
@abc.abstractmethod
def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
"""Get all invalidation callbacks"""
...
class CacheEntrySingle(CacheEntry[KT, VT]):
"""An implementation of `CacheEntry` wrapping a deferred that results in a
single cache entry.
"""
__slots__ = ["_deferred", "_callbacks"]
def __init__(self, deferred: "defer.Deferred[VT]") -> None:
self._deferred = ObservableDeferred(deferred, consumeErrors=True)
self._callbacks: Set[Callable[[], None]] = set()
def deferred(self, key: KT) -> "defer.Deferred[VT]":
return self._deferred.observe()
def add_invalidation_callback(
self, key: KT, callback: Optional[Callable[[], None]]
) -> None:
if callback is None:
return
self._callbacks.add(callback)
def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
return self._callbacks
class CacheMultipleEntries(CacheEntry[KT, VT]):
"""Cache entry that is used for bulk lookups and insertions."""
__slots__ = ["_deferred", "_callbacks", "_global_callbacks"]
def __init__(self) -> None:
self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
self._global_callbacks: Set[Callable[[], None]] = set()
def deferred(self, key: KT) -> "defer.Deferred[VT]":
if not self._deferred:
self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
return self._deferred.observe().addCallback(lambda res: res.get(key))
def add_invalidation_callback(
self, key: KT, callback: Optional[Callable[[], None]]
) -> None:
if callback is None:
return
self._callbacks.setdefault(key, set()).add(callback)
def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
return self._callbacks.get(key, set()) | self._global_callbacks
def add_global_invalidation_callback(
self, callback: Optional[Callable[[], None]]
) -> None:
"""Add a callback for when any keys get invalidated."""
if callback is None:
return
self._global_callbacks.add(callback)
def complete_bulk(
self,
cache: DeferredCache[KT, VT],
result: Dict[KT, VT],
) -> None:
"""Called when there is a result"""
for key, value in result.items():
cache._completed_callback(value, self, key)
if self._deferred:
self._deferred.callback(result)
def error_bulk(
self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
) -> None:
"""Called when bulk lookup failed."""
for key in keys:
cache._error_callback(failure, self, key)
if self._deferred:
self._deferred.errback(failure)

View File

@ -25,6 +25,7 @@ from typing import (
Generic, Generic,
Hashable, Hashable,
Iterable, Iterable,
List,
Mapping, Mapping,
Optional, Optional,
Sequence, Sequence,
@ -440,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name] list_args = arg_dict[self.list_name]
results = {}
def update_results_dict(res: Any, arg: Hashable) -> None:
results[arg] = res
# list of deferreds to wait for
cached_defers = []
missing = set()
# If the cache takes a single arg then that is used as the key, # If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used. # otherwise a tuple is used.
if num_args == 1: if num_args == 1:
@ -457,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
def arg_to_cache_key(arg: Hashable) -> Hashable: def arg_to_cache_key(arg: Hashable) -> Hashable:
return arg return arg
def cache_key_to_arg(key: tuple) -> Hashable:
return key
else: else:
keylist = list(keyargs) keylist = list(keyargs)
@ -464,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keylist[self.list_pos] = arg keylist[self.list_pos] = arg
return tuple(keylist) return tuple(keylist)
for arg in list_args: def cache_key_to_arg(key: tuple) -> Hashable:
try: return key[self.list_pos]
res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
if not res.called: cache_keys = [arg_to_cache_key(arg) for arg in list_args]
res.addCallback(update_results_dict, arg) immediate_results, pending_deferred, missing = cache.get_bulk(
cached_defers.append(res) cache_keys, callback=invalidate_callback
else: )
results[arg] = res.result
except KeyError: results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
missing.add(arg)
cached_defers: List["defer.Deferred[Any]"] = []
if pending_deferred:
def update_results(r: Dict) -> None:
for k, v in r.items():
results[cache_key_to_arg(k)] = v
pending_deferred.addCallback(update_results)
cached_defers.append(pending_deferred)
if missing: if missing:
# we need a deferred for each entry in the list, cache_entry = cache.start_bulk_input(missing, invalidate_callback)
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
for arg in missing:
deferred: "defer.Deferred[Any]" = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
cached_defers.append(
cache.set(key, deferred, callback=invalidate_callback)
)
def complete_all(res: Dict[Hashable, Any]) -> None: def complete_all(res: Dict[Hashable, Any]) -> None:
# the wrapped function has completed. It returns a dict. missing_results = {}
# We can now update our own result map, and then resolve the for key in missing:
# observable deferreds in the cache. arg = cache_key_to_arg(key)
for e, d1 in deferreds_map.items(): val = res.get(arg, None)
val = res.get(e, None)
# make sure we update the results map before running the results[arg] = val
# deferreds, because as soon as we run the last deferred, the missing_results[key] = val
# gatherResults() below will complete and return the result
# dict to our caller. cache_entry.complete_bulk(cache, missing_results)
results[e] = val
d1.callback(val)
def errback_all(f: Failure) -> None: def errback_all(f: Failure) -> None:
# the wrapped function has failed. Propagate the failure into cache_entry.error_bulk(cache, missing, f)
# the cache, which will invalidate the entry, and cause the
# relevant cached_deferreds to fail, which will propagate the
# failure to our caller.
for d1 in deferreds_map.values():
d1.errback(f)
args_to_call = dict(arg_dict) args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing args_to_call[self.list_name] = {
cache_key_to_arg(key) for key in missing
}
# dispatch the call, and attach the two handlers # dispatch the call, and attach the two handlers
defer.maybeDeferred( missing_d = defer.maybeDeferred(
preserve_fn(self.orig), **args_to_call preserve_fn(self.orig), **args_to_call
).addCallbacks(complete_all, errback_all) ).addCallbacks(complete_all, errback_all)
cached_defers.append(missing_d)
if cached_defers: if cached_defers:
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(

View File

@ -135,6 +135,9 @@ class TreeCache:
def values(self): def values(self):
return iterate_tree_cache_entry(self.root) return iterate_tree_cache_entry(self.root)
def items(self):
return iterate_tree_cache_items((), self.root)
def __len__(self) -> int: def __len__(self) -> int:
return self.size return self.size