Merge pull request #2075 from matrix-org/erikj/cache_speed
Speed up cached function access
commit 9cee0ce7db
@@ -17,15 +17,12 @@ from twisted.internet import defer
 from synapse.push.presentable_names import (
     calculate_room_name, name_from_member_event
 )
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 
 
 @defer.inlineCallbacks
 def get_badge_count(store, user_id):
-    invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.get_invited_rooms_for_user)(user_id),
-        preserve_fn(store.get_rooms_for_user)(user_id),
-    ], consumeErrors=True))
+    invites = yield store.get_invited_rooms_for_user(user_id)
+    joins = yield store.get_rooms_for_user(user_id)
 
     my_receipts_by_room = yield store.get_receipts_for_user(
         user_id, "m.read",
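
The gatherResults dance above only existed to run the two lookups concurrently; with the descriptor changes further down, a cache hit can come back as a plain value, and yielding either a plain value or an already-fired Deferred inside @defer.inlineCallbacks resumes immediately. A minimal sketch (illustrative only, not part of this change) of that behaviour:

from twisted.internet import defer


@defer.inlineCallbacks
def example():
    # stand-ins for the two cached storage calls
    invites = yield defer.succeed(["!a:example.org"])  # already-fired Deferred
    joins = yield ["!b:example.org"]                   # plain value, e.g. a cache hit
    defer.returnValue(len(invites) + len(joins))


assert example().result == 2  # resolves synchronously
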
@@ -89,6 +89,11 @@ class ObservableDeferred(object):
         deferred.addCallbacks(callback, errback)
 
     def observe(self):
+        """Observe the underlying deferred.
+
+        Can return either a deferred if the underlying deferred is still pending
+        (or has failed), or the actual value. Callers may need to use maybeDeferred.
+        """
         if not self._result:
             d = defer.Deferred()
 
@@ -101,7 +106,7 @@ class ObservableDeferred(object):
             return d
         else:
             success, res = self._result
-            return defer.succeed(res) if success else defer.fail(res)
+            return res if success else defer.fail(res)
 
     def observers(self):
         return self._observers
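
The new docstring changes the contract of observe(): a completed result now comes back as the bare value rather than a Deferred. A small self-contained sketch (a stand-in class, not the Synapse implementation) of how a caller can normalise either shape with defer.maybeDeferred:

from twisted.internet import defer


class FakeObservable(object):
    """Stand-in mimicking the new observe() contract."""

    def __init__(self, result=None, pending=False):
        self._result = None if pending else (True, result)

    def observe(self):
        if not self._result:
            return defer.Deferred()        # still pending: hand back a Deferred
        success, res = self._result
        return res if success else defer.fail(res)


completed = FakeObservable(result="hello")
assert completed.observe() == "hello"        # bare value, no Deferred

d = defer.maybeDeferred(completed.observe)   # normalise to a Deferred
assert d.result == "hello"
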
@@ -224,8 +224,20 @@ class _CacheDescriptorBase(object):
             )
 
         self.num_args = num_args
+
+        # list of the names of the args used as the cache key
         self.arg_names = all_args[1:num_args + 1]
 
+        # self.arg_defaults is a map of arg name to its default value for each
+        # argument that has a default value
+        if arg_spec.defaults:
+            self.arg_defaults = dict(zip(
+                all_args[-len(arg_spec.defaults):],
+                arg_spec.defaults
+            ))
+        else:
+            self.arg_defaults = {}
+
         if "cache_context" in self.arg_names:
             raise Exception(
                 "cache_context arg cannot be included among the cache keys"
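
The arg_defaults map is built by lining up the function's declared defaults with the tail of its argument list. A short illustrative sketch of the same zip (shown with getfullargspec, where the code above uses the older getargspec with the same fields):

import inspect


def fn(self, arg1, arg2=2, arg3=3):
    pass


arg_spec = inspect.getfullargspec(fn)
all_args = arg_spec.args                  # ['self', 'arg1', 'arg2', 'arg3']
arg_defaults = dict(zip(
    all_args[-len(arg_spec.defaults):],   # the names that carry defaults
    arg_spec.defaults,
))
assert arg_defaults == {"arg2": 2, "arg3": 3}
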
@@ -289,18 +301,31 @@ class CacheDescriptor(_CacheDescriptorBase):
             iterable=self.iterable,
         )
 
+        def get_cache_key(args, kwargs):
+            """Given some args/kwargs return a generator that resolves into
+            the cache_key.
+
+            We loop through each arg name, looking up if its in the `kwargs`,
+            otherwise using the next argument in `args`. If there are no more
+            args then we try looking the arg name up in the defaults
+            """
+            pos = 0
+            for nm in self.arg_names:
+                if nm in kwargs:
+                    yield kwargs[nm]
+                elif pos < len(args):
+                    yield args[pos]
+                    pos += 1
+                else:
+                    yield self.arg_defaults[nm]
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
-            # Add temp cache_context so inspect.getcallargs doesn't explode
-            if self.add_cache_context:
-                kwargs["cache_context"] = None
-
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+            cache_key = tuple(get_cache_key(args, kwargs))
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
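
get_cache_key replaces the old inspect.getcallargs call with a cheap walk over arg_names, so a call that relies on a default and a call that spells it out land on the same cache entry. A standalone sketch (same logic, outside the descriptor) of the keys it produces:

def make_cache_key(arg_names, arg_defaults, args, kwargs):
    # same lookup order as get_cache_key: kwargs, then positionals, then defaults
    pos = 0
    key = []
    for nm in arg_names:
        if nm in kwargs:
            key.append(kwargs[nm])
        elif pos < len(args):
            key.append(args[pos])
            pos += 1
        else:
            key.append(arg_defaults[nm])
    return tuple(key)


arg_names = ["arg1", "arg2", "arg3"]
arg_defaults = {"arg2": 2, "arg3": 3}

assert make_cache_key(arg_names, arg_defaults, (1, 2), {}) == (1, 2, 3)
assert make_cache_key(arg_names, arg_defaults, (1,), {"arg2": 2}) == (1, 2, 3)
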
@@ -341,7 +366,10 @@ class CacheDescriptor(_CacheDescriptorBase):
                 cache.set(cache_key, result_d, callback=invalidate_callback)
                 observer = result_d.observe()
 
-            return logcontext.make_deferred_yieldable(observer)
+            if isinstance(observer, defer.Deferred):
+                return logcontext.make_deferred_yieldable(observer)
+            else:
+                return observer
 
         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
@@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
         events ([synapse.events.EventBase]): list of events to filter
     """
     forgotten = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.who_forgot_in_room)(
+        defer.maybeDeferred(
+            preserve_fn(store.who_forgot_in_room),
             room_id,
         )
         for room_id in frozenset(e.room_id for e in events)
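
Because a cached storage call can now return its value synchronously, each who_forgot_in_room call is wrapped in defer.maybeDeferred so that gatherResults still receives a list of Deferreds. A minimal sketch (stand-in function, not the real store) of the pattern:

from twisted.internet import defer


def who_forgot_in_room(room_id):
    # stand-in for the cached call: plain value on a hit, Deferred on a miss
    if room_id == "!cached:example.org":
        return ["@alice:example.org"]
    return defer.succeed(["@bob:example.org"])


d = defer.gatherResults([
    defer.maybeDeferred(who_forgot_in_room, room_id)
    for room_id in ("!cached:example.org", "!uncached:example.org")
], consumeErrors=True)
assert d.result == [["@alice:example.org"], ["@bob:example.org"]]
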
@@ -199,7 +199,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         a.func.prefill(("foo",), ObservableDeferred(d))
 
-        self.assertEquals(a.func("foo").result, d.result)
+        self.assertEquals(a.func("foo"), d.result)
         self.assertEquals(callcount[0], 0)
 
     @defer.inlineCallbacks
@@ -175,3 +175,41 @@ class DescriptorTestCase(unittest.TestCase):
                           logcontext.LoggingContext.sentinel)
 
             return d1
+
+    @defer.inlineCallbacks
+    def test_cache_default_args(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2=2, arg3=3):
+                return self.mock(arg1, arg2, arg3)
+
+        obj = Cls()
+
+        obj.mock.return_value = 'fish'
+        r = yield obj.fn(1, 2, 3)
+        self.assertEqual(r, 'fish')
+        obj.mock.assert_called_once_with(1, 2, 3)
+        obj.mock.reset_mock()
+
+        # a call with same params shouldn't call the mock again
+        r = yield obj.fn(1, 2)
+        self.assertEqual(r, 'fish')
+        obj.mock.assert_not_called()
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = 'chips'
+        r = yield obj.fn(2, 3)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_called_once_with(2, 3, 3)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached
+        r = yield obj.fn(1, 2)
+        self.assertEqual(r, 'fish')
+        r = yield obj.fn(2, 3)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_not_called()
@@ -53,7 +53,9 @@ class SnapshotCacheTestCase(unittest.TestCase):
         # before the cache expires returns a resolved deferred.
         get_result_at_11 = self.cache.get(11, "key")
         self.assertIsNotNone(get_result_at_11)
-        self.assertTrue(get_result_at_11.called)
+        if isinstance(get_result_at_11, Deferred):
+            # The cache may return the actual result rather than a deferred
+            self.assertTrue(get_result_at_11.called)
 
         # Check that getting the key after the deferred has resolved
         # after the cache expires returns None