Allow for ignoring some arguments when caching. (#12189)
* `@cached` can now take an `uncached_args` which is an iterable of names to not use in the cache key.
* Requires `@cached`, `@cachedList` and `@lru_cache` to use keyword arguments for clarity.
* Asserts that keyword-only arguments in cached functions are not accepted. (I tested keyword-only arguments briefly and I don't believe they work properly.)
This commit is contained in:
parent
032688854b
commit
690cb4f3b3
|
@ -0,0 +1 @@
|
||||||
|
Support skipping some arguments when generating cache keys.
|
|
@ -1286,7 +1286,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
return {eid for ((_rid, eid), have_event) in res.items() if have_event}
|
return {eid for ((_rid, eid), have_event) in res.items() if have_event}
|
||||||
|
|
||||||
@cachedList("have_seen_event", "keys")
|
@cachedList(cached_method_name="have_seen_event", list_name="keys")
|
||||||
async def _have_seen_events_dict(
|
async def _have_seen_events_dict(
|
||||||
self, keys: Iterable[Tuple[str, str]]
|
self, keys: Iterable[Tuple[str, str]]
|
||||||
) -> Dict[Tuple[str, str], bool]:
|
) -> Dict[Tuple[str, str], bool]:
|
||||||
|
@ -1954,7 +1954,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
get_event_id_for_timestamp_txn,
|
get_event_id_for_timestamp_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedList("is_partial_state_event", list_name="event_ids")
|
@cachedList(cached_method_name="is_partial_state_event", list_name="event_ids")
|
||||||
async def get_partial_state_events(
|
async def get_partial_state_events(
|
||||||
self, event_ids: Collection[str]
|
self, event_ids: Collection[str]
|
||||||
) -> Dict[str, bool]:
|
) -> Dict[str, bool]:
|
||||||
|
|
|
@ -20,6 +20,7 @@ from typing import (
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
Generic,
|
Generic,
|
||||||
Hashable,
|
Hashable,
|
||||||
|
@ -69,6 +70,7 @@ class _CacheDescriptorBase:
|
||||||
self,
|
self,
|
||||||
orig: Callable[..., Any],
|
orig: Callable[..., Any],
|
||||||
num_args: Optional[int],
|
num_args: Optional[int],
|
||||||
|
uncached_args: Optional[Collection[str]] = None,
|
||||||
cache_context: bool = False,
|
cache_context: bool = False,
|
||||||
):
|
):
|
||||||
self.orig = orig
|
self.orig = orig
|
||||||
|
@ -76,6 +78,13 @@ class _CacheDescriptorBase:
|
||||||
arg_spec = inspect.getfullargspec(orig)
|
arg_spec = inspect.getfullargspec(orig)
|
||||||
all_args = arg_spec.args
|
all_args = arg_spec.args
|
||||||
|
|
||||||
|
# There's no reason that keyword-only arguments couldn't be supported,
|
||||||
|
# but right now they're buggy so do not allow them.
|
||||||
|
if arg_spec.kwonlyargs:
|
||||||
|
raise ValueError(
|
||||||
|
"_CacheDescriptorBase does not support keyword-only arguments."
|
||||||
|
)
|
||||||
|
|
||||||
if "cache_context" in all_args:
|
if "cache_context" in all_args:
|
||||||
if not cache_context:
|
if not cache_context:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -88,6 +97,9 @@ class _CacheDescriptorBase:
|
||||||
" named `cache_context`"
|
" named `cache_context`"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if num_args is not None and uncached_args is not None:
|
||||||
|
raise ValueError("Cannot provide both num_args and uncached_args")
|
||||||
|
|
||||||
if num_args is None:
|
if num_args is None:
|
||||||
num_args = len(all_args) - 1
|
num_args = len(all_args) - 1
|
||||||
if cache_context:
|
if cache_context:
|
||||||
|
@ -105,6 +117,12 @@ class _CacheDescriptorBase:
|
||||||
# list of the names of the args used as the cache key
|
# list of the names of the args used as the cache key
|
||||||
self.arg_names = all_args[1 : num_args + 1]
|
self.arg_names = all_args[1 : num_args + 1]
|
||||||
|
|
||||||
|
# If there are args to not cache on, filter them out (and fix the size of num_args).
|
||||||
|
if uncached_args is not None:
|
||||||
|
include_arg_in_cache_key = [n not in uncached_args for n in self.arg_names]
|
||||||
|
else:
|
||||||
|
include_arg_in_cache_key = [True] * len(self.arg_names)
|
||||||
|
|
||||||
# self.arg_defaults is a map of arg name to its default value for each
|
# self.arg_defaults is a map of arg name to its default value for each
|
||||||
# argument that has a default value
|
# argument that has a default value
|
||||||
if arg_spec.defaults:
|
if arg_spec.defaults:
|
||||||
|
@ -119,8 +137,8 @@ class _CacheDescriptorBase:
|
||||||
|
|
||||||
self.add_cache_context = cache_context
|
self.add_cache_context = cache_context
|
||||||
|
|
||||||
self.cache_key_builder = get_cache_key_builder(
|
self.cache_key_builder = _get_cache_key_builder(
|
||||||
self.arg_names, self.arg_defaults
|
self.arg_names, include_arg_in_cache_key, self.arg_defaults
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -130,8 +148,7 @@ class _LruCachedFunction(Generic[F]):
|
||||||
|
|
||||||
|
|
||||||
def lru_cache(
|
def lru_cache(
|
||||||
max_entries: int = 1000,
|
*, max_entries: int = 1000, cache_context: bool = False
|
||||||
cache_context: bool = False,
|
|
||||||
) -> Callable[[F], _LruCachedFunction[F]]:
|
) -> Callable[[F], _LruCachedFunction[F]]:
|
||||||
"""A method decorator that applies a memoizing cache around the function.
|
"""A method decorator that applies a memoizing cache around the function.
|
||||||
|
|
||||||
|
@ -186,7 +203,9 @@ class LruCacheDescriptor(_CacheDescriptorBase):
|
||||||
max_entries: int = 1000,
|
max_entries: int = 1000,
|
||||||
cache_context: bool = False,
|
cache_context: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(orig, num_args=None, cache_context=cache_context)
|
super().__init__(
|
||||||
|
orig, num_args=None, uncached_args=None, cache_context=cache_context
|
||||||
|
)
|
||||||
self.max_entries = max_entries
|
self.max_entries = max_entries
|
||||||
|
|
||||||
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
|
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
|
||||||
|
@ -260,6 +279,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||||
num_args: number of positional arguments (excluding ``self`` and
|
num_args: number of positional arguments (excluding ``self`` and
|
||||||
``cache_context``) to use as cache keys. Defaults to all named
|
``cache_context``) to use as cache keys. Defaults to all named
|
||||||
args of the function.
|
args of the function.
|
||||||
|
uncached_args: a list of argument names to not use as the cache key.
|
||||||
|
(``self`` and ``cache_context`` are always ignored.) Cannot be used
|
||||||
|
with num_args.
|
||||||
tree:
|
tree:
|
||||||
cache_context:
|
cache_context:
|
||||||
iterable:
|
iterable:
|
||||||
|
@ -273,12 +295,18 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||||
orig: Callable[..., Any],
|
orig: Callable[..., Any],
|
||||||
max_entries: int = 1000,
|
max_entries: int = 1000,
|
||||||
num_args: Optional[int] = None,
|
num_args: Optional[int] = None,
|
||||||
|
uncached_args: Optional[Collection[str]] = None,
|
||||||
tree: bool = False,
|
tree: bool = False,
|
||||||
cache_context: bool = False,
|
cache_context: bool = False,
|
||||||
iterable: bool = False,
|
iterable: bool = False,
|
||||||
prune_unread_entries: bool = True,
|
prune_unread_entries: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__(orig, num_args=num_args, cache_context=cache_context)
|
super().__init__(
|
||||||
|
orig,
|
||||||
|
num_args=num_args,
|
||||||
|
uncached_args=uncached_args,
|
||||||
|
cache_context=cache_context,
|
||||||
|
)
|
||||||
|
|
||||||
if tree and self.num_args < 2:
|
if tree and self.num_args < 2:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
@ -369,7 +397,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||||
but including list_name) to use as cache keys. Defaults to all
|
but including list_name) to use as cache keys. Defaults to all
|
||||||
named args of the function.
|
named args of the function.
|
||||||
"""
|
"""
|
||||||
super().__init__(orig, num_args=num_args)
|
super().__init__(orig, num_args=num_args, uncached_args=None)
|
||||||
|
|
||||||
self.list_name = list_name
|
self.list_name = list_name
|
||||||
|
|
||||||
|
@ -530,8 +558,10 @@ class _CacheContext:
|
||||||
|
|
||||||
|
|
||||||
def cached(
|
def cached(
|
||||||
|
*,
|
||||||
max_entries: int = 1000,
|
max_entries: int = 1000,
|
||||||
num_args: Optional[int] = None,
|
num_args: Optional[int] = None,
|
||||||
|
uncached_args: Optional[Collection[str]] = None,
|
||||||
tree: bool = False,
|
tree: bool = False,
|
||||||
cache_context: bool = False,
|
cache_context: bool = False,
|
||||||
iterable: bool = False,
|
iterable: bool = False,
|
||||||
|
@ -541,6 +571,7 @@ def cached(
|
||||||
orig,
|
orig,
|
||||||
max_entries=max_entries,
|
max_entries=max_entries,
|
||||||
num_args=num_args,
|
num_args=num_args,
|
||||||
|
uncached_args=uncached_args,
|
||||||
tree=tree,
|
tree=tree,
|
||||||
cache_context=cache_context,
|
cache_context=cache_context,
|
||||||
iterable=iterable,
|
iterable=iterable,
|
||||||
|
@ -551,7 +582,7 @@ def cached(
|
||||||
|
|
||||||
|
|
||||||
def cachedList(
|
def cachedList(
|
||||||
cached_method_name: str, list_name: str, num_args: Optional[int] = None
|
*, cached_method_name: str, list_name: str, num_args: Optional[int] = None
|
||||||
) -> Callable[[F], _CachedFunction[F]]:
|
) -> Callable[[F], _CachedFunction[F]]:
|
||||||
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
|
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
|
||||||
|
|
||||||
|
@ -590,13 +621,16 @@ def cachedList(
|
||||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
return cast(Callable[[F], _CachedFunction[F]], func)
|
||||||
|
|
||||||
|
|
||||||
def get_cache_key_builder(
|
def _get_cache_key_builder(
|
||||||
param_names: Sequence[str], param_defaults: Mapping[str, Any]
|
param_names: Sequence[str],
|
||||||
|
include_params: Sequence[bool],
|
||||||
|
param_defaults: Mapping[str, Any],
|
||||||
) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
|
) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
|
||||||
"""Construct a function which will build cache keys suitable for a cached function
|
"""Construct a function which will build cache keys suitable for a cached function
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
param_names: list of formal parameter names for the cached function
|
param_names: list of formal parameter names for the cached function
|
||||||
|
include_params: list of bools of whether to include the parameter name in the cache key
|
||||||
param_defaults: a mapping from parameter name to default value for that param
|
param_defaults: a mapping from parameter name to default value for that param
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -608,6 +642,7 @@ def get_cache_key_builder(
|
||||||
|
|
||||||
if len(param_names) == 1:
|
if len(param_names) == 1:
|
||||||
nm = param_names[0]
|
nm = param_names[0]
|
||||||
|
assert include_params[0] is True
|
||||||
|
|
||||||
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
|
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
|
||||||
if nm in kwargs:
|
if nm in kwargs:
|
||||||
|
@ -620,13 +655,18 @@ def get_cache_key_builder(
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
|
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
|
||||||
return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
|
return tuple(
|
||||||
|
_get_cache_key_gen(
|
||||||
|
param_names, include_params, param_defaults, args, kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return get_cache_key
|
return get_cache_key
|
||||||
|
|
||||||
|
|
||||||
def _get_cache_key_gen(
|
def _get_cache_key_gen(
|
||||||
param_names: Iterable[str],
|
param_names: Iterable[str],
|
||||||
|
include_params: Iterable[bool],
|
||||||
param_defaults: Mapping[str, Any],
|
param_defaults: Mapping[str, Any],
|
||||||
args: Sequence[Any],
|
args: Sequence[Any],
|
||||||
kwargs: Mapping[str, Any],
|
kwargs: Mapping[str, Any],
|
||||||
|
@ -637,16 +677,18 @@ def _get_cache_key_gen(
|
||||||
This is essentially the same operation as `inspect.getcallargs`, but optimised so
|
This is essentially the same operation as `inspect.getcallargs`, but optimised so
|
||||||
that we don't need to inspect the target function for each call.
|
that we don't need to inspect the target function for each call.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# We loop through each arg name, looking up if its in the `kwargs`,
|
# We loop through each arg name, looking up if its in the `kwargs`,
|
||||||
# otherwise using the next argument in `args`. If there are no more
|
# otherwise using the next argument in `args`. If there are no more
|
||||||
# args then we try looking the arg name up in the defaults.
|
# args then we try looking the arg name up in the defaults.
|
||||||
pos = 0
|
pos = 0
|
||||||
for nm in param_names:
|
for nm, inc in zip(param_names, include_params):
|
||||||
if nm in kwargs:
|
if nm in kwargs:
|
||||||
|
if inc:
|
||||||
yield kwargs[nm]
|
yield kwargs[nm]
|
||||||
elif pos < len(args):
|
elif pos < len(args):
|
||||||
|
if inc:
|
||||||
yield args[pos]
|
yield args[pos]
|
||||||
pos += 1
|
pos += 1
|
||||||
else:
|
else:
|
||||||
|
if inc:
|
||||||
yield param_defaults[nm]
|
yield param_defaults[nm]
|
||||||
|
|
|
@ -141,6 +141,84 @@ class DescriptorTestCase(unittest.TestCase):
|
||||||
self.assertEqual(r, "chips")
|
self.assertEqual(r, "chips")
|
||||||
obj.mock.assert_not_called()
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cache_uncached_args(self):
|
||||||
|
"""
|
||||||
|
Only the arguments not named in uncached_args should matter to the cache
|
||||||
|
|
||||||
|
Note that this is identical to test_cache_num_args, but provides the
|
||||||
|
arguments differently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Cls:
|
||||||
|
# Note that it is important that this is not the last argument to
|
||||||
|
# test behaviour of skipping arguments properly.
|
||||||
|
@descriptors.cached(uncached_args=("arg2",))
|
||||||
|
def fn(self, arg1, arg2, arg3):
|
||||||
|
return self.mock(arg1, arg2, arg3)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
obj.mock.return_value = "fish"
|
||||||
|
r = yield obj.fn(1, 2, 3)
|
||||||
|
self.assertEqual(r, "fish")
|
||||||
|
obj.mock.assert_called_once_with(1, 2, 3)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# a call with different params should call the mock again
|
||||||
|
obj.mock.return_value = "chips"
|
||||||
|
r = yield obj.fn(2, 3, 4)
|
||||||
|
self.assertEqual(r, "chips")
|
||||||
|
obj.mock.assert_called_once_with(2, 3, 4)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# the two values should now be cached; we should be able to vary
|
||||||
|
# the second argument and still get the cached result.
|
||||||
|
r = yield obj.fn(1, 4, 3)
|
||||||
|
self.assertEqual(r, "fish")
|
||||||
|
r = yield obj.fn(2, 5, 4)
|
||||||
|
self.assertEqual(r, "chips")
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cache_kwargs(self):
|
||||||
|
"""Test that keyword arguments are treated properly"""
|
||||||
|
|
||||||
|
class Cls:
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
@descriptors.cached()
|
||||||
|
def fn(self, arg1, kwarg1=2):
|
||||||
|
return self.mock(arg1, kwarg1=kwarg1)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
obj.mock.return_value = "fish"
|
||||||
|
r = yield obj.fn(1, kwarg1=2)
|
||||||
|
self.assertEqual(r, "fish")
|
||||||
|
obj.mock.assert_called_once_with(1, kwarg1=2)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# a call with different params should call the mock again
|
||||||
|
obj.mock.return_value = "chips"
|
||||||
|
r = yield obj.fn(1, kwarg1=3)
|
||||||
|
self.assertEqual(r, "chips")
|
||||||
|
obj.mock.assert_called_once_with(1, kwarg1=3)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# the values should now be cached.
|
||||||
|
r = yield obj.fn(1, kwarg1=2)
|
||||||
|
self.assertEqual(r, "fish")
|
||||||
|
# We should be able to not provide kwarg1 and get the cached value back.
|
||||||
|
r = yield obj.fn(1)
|
||||||
|
self.assertEqual(r, "fish")
|
||||||
|
# Keyword arguments can be in any order.
|
||||||
|
r = yield obj.fn(kwarg1=2, arg1=1)
|
||||||
|
self.assertEqual(r, "fish")
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
def test_cache_with_sync_exception(self):
|
def test_cache_with_sync_exception(self):
|
||||||
"""If the wrapped function throws synchronously, things should continue to work"""
|
"""If the wrapped function throws synchronously, things should continue to work"""
|
||||||
|
|
||||||
|
@ -656,7 +734,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
def fn(self, arg1, arg2):
|
def fn(self, arg1, arg2):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@descriptors.cachedList("fn", "args1")
|
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
|
||||||
async def list_fn(self, args1, arg2):
|
async def list_fn(self, args1, arg2):
|
||||||
assert current_context().name == "c1"
|
assert current_context().name == "c1"
|
||||||
# we want this to behave like an asynchronous function
|
# we want this to behave like an asynchronous function
|
||||||
|
@ -715,7 +793,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
def fn(self, arg1):
|
def fn(self, arg1):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@descriptors.cachedList("fn", "args1")
|
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
|
||||||
def list_fn(self, args1) -> "Deferred[dict]":
|
def list_fn(self, args1) -> "Deferred[dict]":
|
||||||
return self.mock(args1)
|
return self.mock(args1)
|
||||||
|
|
||||||
|
@ -758,7 +836,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
def fn(self, arg1, arg2):
|
def fn(self, arg1, arg2):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@descriptors.cachedList("fn", "args1")
|
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
|
||||||
async def list_fn(self, args1, arg2):
|
async def list_fn(self, args1, arg2):
|
||||||
# we want this to behave like an asynchronous function
|
# we want this to behave like an asynchronous function
|
||||||
await run_on_reactor()
|
await run_on_reactor()
|
||||||
|
|
Loading…
Reference in New Issue