Merge pull request #2075 from matrix-org/erikj/cache_speed

Speed up cached function access
This commit is contained in:
Erik Johnston 2017-03-31 10:10:56 +01:00 committed by GitHub
commit 9cee0ce7db
7 changed files with 87 additions and 16 deletions

View File

@ -17,15 +17,12 @@ from twisted.internet import defer
from synapse.push.presentable_names import ( from synapse.push.presentable_names import (
calculate_room_name, name_from_member_event calculate_room_name, name_from_member_event
) )
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@defer.inlineCallbacks @defer.inlineCallbacks
def get_badge_count(store, user_id): def get_badge_count(store, user_id):
invites, joins = yield preserve_context_over_deferred(defer.gatherResults([ invites = yield store.get_invited_rooms_for_user(user_id)
preserve_fn(store.get_invited_rooms_for_user)(user_id), joins = yield store.get_rooms_for_user(user_id)
preserve_fn(store.get_rooms_for_user)(user_id),
], consumeErrors=True))
my_receipts_by_room = yield store.get_receipts_for_user( my_receipts_by_room = yield store.get_receipts_for_user(
user_id, "m.read", user_id, "m.read",

View File

@ -89,6 +89,11 @@ class ObservableDeferred(object):
deferred.addCallbacks(callback, errback) deferred.addCallbacks(callback, errback)
def observe(self): def observe(self):
"""Observe the underlying deferred.
Can return either a deferred if the underlying deferred is still pending
(or has failed), or the actual value. Callers may need to use maybeDeferred.
"""
if not self._result: if not self._result:
d = defer.Deferred() d = defer.Deferred()
@ -101,7 +106,7 @@ class ObservableDeferred(object):
return d return d
else: else:
success, res = self._result success, res = self._result
return defer.succeed(res) if success else defer.fail(res) return res if success else defer.fail(res)
def observers(self): def observers(self):
return self._observers return self._observers

View File

@ -224,8 +224,20 @@ class _CacheDescriptorBase(object):
) )
self.num_args = num_args self.num_args = num_args
# list of the names of the args used as the cache key
self.arg_names = all_args[1:num_args + 1] self.arg_names = all_args[1:num_args + 1]
# self.arg_defaults is a map of arg name to its default value for each
# argument that has a default value
if arg_spec.defaults:
self.arg_defaults = dict(zip(
all_args[-len(arg_spec.defaults):],
arg_spec.defaults
))
else:
self.arg_defaults = {}
if "cache_context" in self.arg_names: if "cache_context" in self.arg_names:
raise Exception( raise Exception(
"cache_context arg cannot be included among the cache keys" "cache_context arg cannot be included among the cache keys"
@ -289,18 +301,31 @@ class CacheDescriptor(_CacheDescriptorBase):
iterable=self.iterable, iterable=self.iterable,
) )
def get_cache_key(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
the cache_key.
We loop through each arg name, looking up if its in the `kwargs`,
otherwise using the next argument in `args`. If there are no more
args then we try looking the arg name up in the defaults
"""
pos = 0
for nm in self.arg_names:
if nm in kwargs:
yield kwargs[nm]
elif pos < len(args):
yield args[pos]
pos += 1
else:
yield self.arg_defaults[nm]
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate() # If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated # whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None) invalidate_callback = kwargs.pop("on_invalidate", None)
# Add temp cache_context so inspect.getcallargs doesn't explode cache_key = tuple(get_cache_key(args, kwargs))
if self.add_cache_context:
kwargs["cache_context"] = None
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
# Add our own `cache_context` to argument list if the wrapped function # Add our own `cache_context` to argument list if the wrapped function
# has asked for one # has asked for one
@ -341,7 +366,10 @@ class CacheDescriptor(_CacheDescriptorBase):
cache.set(cache_key, result_d, callback=invalidate_callback) cache.set(cache_key, result_d, callback=invalidate_callback)
observer = result_d.observe() observer = result_d.observe()
return logcontext.make_deferred_yieldable(observer) if isinstance(observer, defer.Deferred):
return logcontext.make_deferred_yieldable(observer)
else:
return observer
wrapped.invalidate = cache.invalidate wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_all = cache.invalidate_all

View File

@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
events ([synapse.events.EventBase]): list of events to filter events ([synapse.events.EventBase]): list of events to filter
""" """
forgotten = yield preserve_context_over_deferred(defer.gatherResults([ forgotten = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(store.who_forgot_in_room)( defer.maybeDeferred(
preserve_fn(store.who_forgot_in_room),
room_id, room_id,
) )
for room_id in frozenset(e.room_id for e in events) for room_id in frozenset(e.room_id for e in events)

View File

@ -199,7 +199,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
a.func.prefill(("foo",), ObservableDeferred(d)) a.func.prefill(("foo",), ObservableDeferred(d))
self.assertEquals(a.func("foo").result, d.result) self.assertEquals(a.func("foo"), d.result)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -175,3 +175,41 @@ class DescriptorTestCase(unittest.TestCase):
logcontext.LoggingContext.sentinel) logcontext.LoggingContext.sentinel)
return d1 return d1
@defer.inlineCallbacks
def test_cache_default_args(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached()
def fn(self, arg1, arg2=2, arg3=3):
return self.mock(arg1, arg2, arg3)
obj = Cls()
obj.mock.return_value = 'fish'
r = yield obj.fn(1, 2, 3)
self.assertEqual(r, 'fish')
obj.mock.assert_called_once_with(1, 2, 3)
obj.mock.reset_mock()
# a call with same params shouldn't call the mock again
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
obj.mock.assert_not_called()
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = 'chips'
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_called_once_with(2, 3, 3)
obj.mock.reset_mock()
# the two values should now be cached
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()

View File

@ -53,7 +53,9 @@ class SnapshotCacheTestCase(unittest.TestCase):
# before the cache expires returns a resolved deferred. # before the cache expires returns a resolved deferred.
get_result_at_11 = self.cache.get(11, "key") get_result_at_11 = self.cache.get(11, "key")
self.assertIsNotNone(get_result_at_11) self.assertIsNotNone(get_result_at_11)
self.assertTrue(get_result_at_11.called) if isinstance(get_result_at_11, Deferred):
# The cache may return the actual result rather than a deferred
self.assertTrue(get_result_at_11.called)
# Check that getting the key after the deferred has resolved # Check that getting the key after the deferred has resolved
# after the cache expires returns None # after the cache expires returns None