diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 247dd15694..78334a98cf 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -156,14 +156,14 @@ class PushRuleStore(SQLBaseStore): # users in the room who have pushers need to get push rules run because # that's how their pushers work if_users_with_pushers = yield self.get_if_users_have_pushers( - local_users_in_room, cache_context=cache_context, + local_users_in_room, on_invalidate=cache_context.invalidate, ) user_ids = set( uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher ) users_with_receipts = yield self.get_users_with_read_receipts_in_room( - room_id, cache_context=cache_context, + room_id, on_invalidate=cache_context.invalidate, ) # any users with pushers must be ours: they have pushers @@ -172,7 +172,7 @@ class PushRuleStore(SQLBaseStore): user_ids.add(uid) rules_by_user = yield self.bulk_get_push_rules( - user_ids, cache_context=cache_context + user_ids, on_invalidate=cache_context.invalidate, ) rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index e7a74d3da8..e93ff40dc0 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -148,8 +148,8 @@ class CacheDescriptor(object): @cachedInlineCallbacks(cache_context=True) def foo(self, key, cache_context): - r1 = yield self.bar1(key, cache_context=cache_context) - r2 = yield self.bar2(key, cache_context=cache_context) + r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate) + r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) defer.returnValue(r1 + r2) """ @@ -208,11 +208,7 @@ class CacheDescriptor(object): def wrapped(*args, **kwargs): # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated - cache_context = kwargs.pop("cache_context", None) - if cache_context: - context_callback = cache_context.invalidate - else: - context_callback = None + invalidate_callback = kwargs.pop("on_invalidate", None) # Add our own `cache_context` to argument list if the wrapped function # has asked for one @@ -226,7 +222,7 @@ class CacheDescriptor(object): self_context.key = cache_key try: - cached_result_d = cache.get(cache_key, callback=context_callback) + cached_result_d = cache.get(cache_key, callback=invalidate_callback) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -263,7 +259,7 @@ class CacheDescriptor(object): ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - cache.update(sequence, cache_key, ret, callback=context_callback) + cache.update(sequence, cache_key, ret, callback=invalidate_callback) return preserve_context_over_deferred(ret.observe()) @@ -332,11 +328,9 @@ class CacheListDescriptor(object): @functools.wraps(self.orig) def wrapped(*args, **kwargs): - cache_context = kwargs.pop("cache_context", None) - if cache_context: - context_callback = cache_context.invalidate - else: - context_callback = None + # If we're passed a cache_context then we'll want to call its invalidate() + # whenever we are invalidated + invalidate_callback = kwargs.pop("on_invalidate", None) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] @@ -352,7 +346,7 @@ class CacheListDescriptor(object): key[self.list_pos] = arg try: - res = cache.get(tuple(key), callback=context_callback) + res = cache.get(tuple(key), callback=invalidate_callback) if not res.has_succeeded(): res = res.observe() res.addCallback(lambda r, arg: (arg, r), arg) @@ -388,7 +382,7 @@ class CacheListDescriptor(object): key[self.list_pos] = arg cache.update( sequence, tuple(key), observer, - callback=context_callback + callback=invalidate_callback ) def invalidate(f, key): diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index eab0c8d219..4fc3639de0 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -214,7 +214,7 @@ class CacheDecoratorTestCase(unittest.TestCase): @cached(cache_context=True) def func2(self, key, cache_context): callcount2[0] += 1 - return self.func(key, cache_context=cache_context) + return self.func(key, on_invalidate=cache_context.invalidate) a = A() yield a.func2("foo") @@ -247,7 +247,7 @@ class CacheDecoratorTestCase(unittest.TestCase): @cached(cache_context=True) def func2(self, key, cache_context): callcount2[0] += 1 - return self.func(key, cache_context=cache_context) + return self.func(key, on_invalidate=cache_context.invalidate) a = A() yield a.func2("foo")