Merge pull request #695 from matrix-org/markjh/cachesII

Make the cache objects be per instance rather than being global
This commit is contained in:
Mark Haines 2016-04-06 13:21:19 +01:00
commit 196ebaf662
4 changed files with 29 additions and 26 deletions

View File

@ -160,8 +160,8 @@ class ReceiptsStore(SQLBaseStore):
"content": content, "content": content,
}]) }])
@cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids", @cachedList(cached_method_name="get_linearized_receipts_for_room",
num_args=3, inlineCallbacks=True) list_name="room_ids", num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids: if not room_ids:
defer.returnValue({}) defer.returnValue({})

View File

@ -319,7 +319,7 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False) defer.returnValue(res if res else False)
@cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1, @cachedList(cached_method_name="is_guest", list_name="user_ids", num_args=1,
inlineCallbacks=True) inlineCallbacks=True)
def are_guests(self, user_ids): def are_guests(self, user_ids):
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % ( sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (

View File

@ -273,8 +273,8 @@ class StateStore(SQLBaseStore):
desc="_get_state_group_for_event", desc="_get_state_group_for_event",
) )
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", @cachedList(cached_method_name="_get_state_group_for_event",
num_args=1, inlineCallbacks=True) list_name="event_ids", num_args=1, inlineCallbacks=True)
def _get_state_group_for_events(self, event_ids): def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group """Returns mapping event_id -> state_group
""" """

View File

@ -167,7 +167,8 @@ class CacheDescriptor(object):
% (orig.__name__,) % (orig.__name__,)
) )
self.cache = Cache( def __get__(self, obj, objtype=None):
cache = Cache(
name=self.orig.__name__, name=self.orig.__name__,
max_entries=self.max_entries, max_entries=self.max_entries,
keylen=self.num_args, keylen=self.num_args,
@ -175,14 +176,12 @@ class CacheDescriptor(object):
tree=self.tree, tree=self.tree,
) )
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try: try:
cached_result_d = self.cache.get(cache_key) cached_result_d = cache.get(cache_key)
observer = cached_result_d.observe() observer = cached_result_d.observe()
if DEBUG_CACHES: if DEBUG_CACHES:
@ -204,7 +203,7 @@ class CacheDescriptor(object):
# Get the sequence number of the cache before reading from the # Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated # database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369) # while the SELECT is executing (SYN-369)
sequence = self.cache.sequence sequence = cache.sequence
ret = defer.maybeDeferred( ret = defer.maybeDeferred(
preserve_context_over_fn, preserve_context_over_fn,
@ -213,20 +212,21 @@ class CacheDescriptor(object):
) )
def onErr(f): def onErr(f):
self.cache.invalidate(cache_key) cache.invalidate(cache_key)
return f return f
ret.addErrback(onErr) ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True) ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret) cache.update(sequence, cache_key, ret)
return preserve_context_over_deferred(ret.observe()) return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = self.cache.invalidate_many wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = self.cache.prefill wrapped.prefill = cache.prefill
wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped obj.__dict__[self.orig.__name__] = wrapped
@ -240,11 +240,12 @@ class CacheListDescriptor(object):
the list of missing keys to the wrapped fucntion. the list of missing keys to the wrapped fucntion.
""" """
def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): def __init__(self, orig, cached_method_name, list_name, num_args=1,
inlineCallbacks=False):
""" """
Args: Args:
orig (function) orig (function)
cache (Cache) method_name (str); The name of the chached method.
list_name (str): Name of the argument which is the bulk lookup list list_name (str): Name of the argument which is the bulk lookup list
num_args (int) num_args (int)
inlineCallbacks (bool): Whether orig is a generator that should inlineCallbacks (bool): Whether orig is a generator that should
@ -263,7 +264,7 @@ class CacheListDescriptor(object):
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name) self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache self.cached_method_name = cached_method_name
self.sentinel = object() self.sentinel = object()
@ -277,11 +278,13 @@ class CacheListDescriptor(object):
if self.list_name not in self.arg_names: if self.list_name not in self.arg_names:
raise Exception( raise Exception(
"Couldn't see arguments %r for %r." "Couldn't see arguments %r for %r."
% (self.list_name, cache.name,) % (self.list_name, cached_method_name,)
) )
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
cache = getattr(obj, self.cached_method_name).cache
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
@ -297,14 +300,14 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg key[self.list_pos] = arg
try: try:
res = self.cache.get(tuple(key)).observe() res = cache.get(tuple(key)).observe()
res.addCallback(lambda r, arg: (arg, r), arg) res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res cached[arg] = res
except KeyError: except KeyError:
missing.append(arg) missing.append(arg)
if missing: if missing:
sequence = self.cache.sequence sequence = cache.sequence
args_to_call = dict(arg_dict) args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing args_to_call[self.list_name] = missing
@ -327,10 +330,10 @@ class CacheListDescriptor(object):
key = list(keyargs) key = list(keyargs)
key[self.list_pos] = arg key[self.list_pos] = arg
self.cache.update(sequence, tuple(key), observer) cache.update(sequence, tuple(key), observer)
def invalidate(f, key): def invalidate(f, key):
self.cache.invalidate(key) cache.invalidate(key)
return f return f
observer.addErrback(invalidate, tuple(key)) observer.addErrback(invalidate, tuple(key))
@ -370,7 +373,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
) )
def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`. """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument Used to do batch lookups for an already created cache. A single argument
@ -400,7 +403,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
""" """
return lambda orig: CacheListDescriptor( return lambda orig: CacheListDescriptor(
orig, orig,
cache=cache, cached_method_name=cached_method_name,
list_name=list_name, list_name=list_name,
num_args=num_args, num_args=num_args,
inlineCallbacks=inlineCallbacks, inlineCallbacks=inlineCallbacks,