Merge pull request #695 from matrix-org/markjh/cachesII
Make the cache objects be per instance rather than being global
This commit is contained in:
commit
196ebaf662
|
@ -160,8 +160,8 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
"content": content,
|
"content": content,
|
||||||
}])
|
}])
|
||||||
|
|
||||||
@cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids",
|
@cachedList(cached_method_name="get_linearized_receipts_for_room",
|
||||||
num_args=3, inlineCallbacks=True)
|
list_name="room_ids", num_args=3, inlineCallbacks=True)
|
||||||
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||||
if not room_ids:
|
if not room_ids:
|
||||||
defer.returnValue({})
|
defer.returnValue({})
|
||||||
|
|
|
@ -319,7 +319,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue(res if res else False)
|
defer.returnValue(res if res else False)
|
||||||
|
|
||||||
@cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
|
@cachedList(cached_method_name="is_guest", list_name="user_ids", num_args=1,
|
||||||
inlineCallbacks=True)
|
inlineCallbacks=True)
|
||||||
def are_guests(self, user_ids):
|
def are_guests(self, user_ids):
|
||||||
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
|
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
|
||||||
|
|
|
@ -273,8 +273,8 @@ class StateStore(SQLBaseStore):
|
||||||
desc="_get_state_group_for_event",
|
desc="_get_state_group_for_event",
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
|
@cachedList(cached_method_name="_get_state_group_for_event",
|
||||||
num_args=1, inlineCallbacks=True)
|
list_name="event_ids", num_args=1, inlineCallbacks=True)
|
||||||
def _get_state_group_for_events(self, event_ids):
|
def _get_state_group_for_events(self, event_ids):
|
||||||
"""Returns mapping event_id -> state_group
|
"""Returns mapping event_id -> state_group
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -167,7 +167,8 @@ class CacheDescriptor(object):
|
||||||
% (orig.__name__,)
|
% (orig.__name__,)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cache = Cache(
|
def __get__(self, obj, objtype=None):
|
||||||
|
cache = Cache(
|
||||||
name=self.orig.__name__,
|
name=self.orig.__name__,
|
||||||
max_entries=self.max_entries,
|
max_entries=self.max_entries,
|
||||||
keylen=self.num_args,
|
keylen=self.num_args,
|
||||||
|
@ -175,14 +176,12 @@ class CacheDescriptor(object):
|
||||||
tree=self.tree,
|
tree=self.tree,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __get__(self, obj, objtype=None):
|
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||||
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
|
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
|
||||||
try:
|
try:
|
||||||
cached_result_d = self.cache.get(cache_key)
|
cached_result_d = cache.get(cache_key)
|
||||||
|
|
||||||
observer = cached_result_d.observe()
|
observer = cached_result_d.observe()
|
||||||
if DEBUG_CACHES:
|
if DEBUG_CACHES:
|
||||||
|
@ -204,7 +203,7 @@ class CacheDescriptor(object):
|
||||||
# Get the sequence number of the cache before reading from the
|
# Get the sequence number of the cache before reading from the
|
||||||
# database so that we can tell if the cache is invalidated
|
# database so that we can tell if the cache is invalidated
|
||||||
# while the SELECT is executing (SYN-369)
|
# while the SELECT is executing (SYN-369)
|
||||||
sequence = self.cache.sequence
|
sequence = cache.sequence
|
||||||
|
|
||||||
ret = defer.maybeDeferred(
|
ret = defer.maybeDeferred(
|
||||||
preserve_context_over_fn,
|
preserve_context_over_fn,
|
||||||
|
@ -213,20 +212,21 @@ class CacheDescriptor(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
def onErr(f):
|
def onErr(f):
|
||||||
self.cache.invalidate(cache_key)
|
cache.invalidate(cache_key)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
ret.addErrback(onErr)
|
ret.addErrback(onErr)
|
||||||
|
|
||||||
ret = ObservableDeferred(ret, consumeErrors=True)
|
ret = ObservableDeferred(ret, consumeErrors=True)
|
||||||
self.cache.update(sequence, cache_key, ret)
|
cache.update(sequence, cache_key, ret)
|
||||||
|
|
||||||
return preserve_context_over_deferred(ret.observe())
|
return preserve_context_over_deferred(ret.observe())
|
||||||
|
|
||||||
wrapped.invalidate = self.cache.invalidate
|
wrapped.invalidate = cache.invalidate
|
||||||
wrapped.invalidate_all = self.cache.invalidate_all
|
wrapped.invalidate_all = cache.invalidate_all
|
||||||
wrapped.invalidate_many = self.cache.invalidate_many
|
wrapped.invalidate_many = cache.invalidate_many
|
||||||
wrapped.prefill = self.cache.prefill
|
wrapped.prefill = cache.prefill
|
||||||
|
wrapped.cache = cache
|
||||||
|
|
||||||
obj.__dict__[self.orig.__name__] = wrapped
|
obj.__dict__[self.orig.__name__] = wrapped
|
||||||
|
|
||||||
|
@ -240,11 +240,12 @@ class CacheListDescriptor(object):
|
||||||
the list of missing keys to the wrapped fucntion.
|
the list of missing keys to the wrapped fucntion.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
|
def __init__(self, orig, cached_method_name, list_name, num_args=1,
|
||||||
|
inlineCallbacks=False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
orig (function)
|
orig (function)
|
||||||
cache (Cache)
|
method_name (str); The name of the chached method.
|
||||||
list_name (str): Name of the argument which is the bulk lookup list
|
list_name (str): Name of the argument which is the bulk lookup list
|
||||||
num_args (int)
|
num_args (int)
|
||||||
inlineCallbacks (bool): Whether orig is a generator that should
|
inlineCallbacks (bool): Whether orig is a generator that should
|
||||||
|
@ -263,7 +264,7 @@ class CacheListDescriptor(object):
|
||||||
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
|
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
|
||||||
self.list_pos = self.arg_names.index(self.list_name)
|
self.list_pos = self.arg_names.index(self.list_name)
|
||||||
|
|
||||||
self.cache = cache
|
self.cached_method_name = cached_method_name
|
||||||
|
|
||||||
self.sentinel = object()
|
self.sentinel = object()
|
||||||
|
|
||||||
|
@ -277,11 +278,13 @@ class CacheListDescriptor(object):
|
||||||
if self.list_name not in self.arg_names:
|
if self.list_name not in self.arg_names:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Couldn't see arguments %r for %r."
|
"Couldn't see arguments %r for %r."
|
||||||
% (self.list_name, cache.name,)
|
% (self.list_name, cached_method_name,)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __get__(self, obj, objtype=None):
|
def __get__(self, obj, objtype=None):
|
||||||
|
|
||||||
|
cache = getattr(obj, self.cached_method_name).cache
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||||
|
@ -297,14 +300,14 @@ class CacheListDescriptor(object):
|
||||||
key[self.list_pos] = arg
|
key[self.list_pos] = arg
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res = self.cache.get(tuple(key)).observe()
|
res = cache.get(tuple(key)).observe()
|
||||||
res.addCallback(lambda r, arg: (arg, r), arg)
|
res.addCallback(lambda r, arg: (arg, r), arg)
|
||||||
cached[arg] = res
|
cached[arg] = res
|
||||||
except KeyError:
|
except KeyError:
|
||||||
missing.append(arg)
|
missing.append(arg)
|
||||||
|
|
||||||
if missing:
|
if missing:
|
||||||
sequence = self.cache.sequence
|
sequence = cache.sequence
|
||||||
args_to_call = dict(arg_dict)
|
args_to_call = dict(arg_dict)
|
||||||
args_to_call[self.list_name] = missing
|
args_to_call[self.list_name] = missing
|
||||||
|
|
||||||
|
@ -327,10 +330,10 @@ class CacheListDescriptor(object):
|
||||||
|
|
||||||
key = list(keyargs)
|
key = list(keyargs)
|
||||||
key[self.list_pos] = arg
|
key[self.list_pos] = arg
|
||||||
self.cache.update(sequence, tuple(key), observer)
|
cache.update(sequence, tuple(key), observer)
|
||||||
|
|
||||||
def invalidate(f, key):
|
def invalidate(f, key):
|
||||||
self.cache.invalidate(key)
|
cache.invalidate(key)
|
||||||
return f
|
return f
|
||||||
observer.addErrback(invalidate, tuple(key))
|
observer.addErrback(invalidate, tuple(key))
|
||||||
|
|
||||||
|
@ -370,7 +373,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
|
def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
|
||||||
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
|
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
|
||||||
|
|
||||||
Used to do batch lookups for an already created cache. A single argument
|
Used to do batch lookups for an already created cache. A single argument
|
||||||
|
@ -400,7 +403,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
|
||||||
"""
|
"""
|
||||||
return lambda orig: CacheListDescriptor(
|
return lambda orig: CacheListDescriptor(
|
||||||
orig,
|
orig,
|
||||||
cache=cache,
|
cached_method_name=cached_method_name,
|
||||||
list_name=list_name,
|
list_name=list_name,
|
||||||
num_args=num_args,
|
num_args=num_args,
|
||||||
inlineCallbacks=inlineCallbacks,
|
inlineCallbacks=inlineCallbacks,
|
||||||
|
|
Loading…
Reference in New Issue