Merge pull request #1030 from matrix-org/erikj/cache_contexts

Add concept of cache contexts
This commit is contained in:
Erik Johnston 2016-08-19 16:29:58 +01:00 committed by GitHub
commit e6784daf07
14 changed files with 458 additions and 87 deletions

View File

@ -40,7 +40,7 @@ class ActionGenerator:
def handle_push_actions_for_event(self, event, context): def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "evaluator_for_event"): with Measure(self.clock, "evaluator_for_event"):
bulk_evaluator = yield evaluator_for_event( bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store, context.current_state event, self.hs, self.store, context.state_group, context.current_state
) )
with Measure(self.clock, "action_for_event_by_user"): with Measure(self.clock, "action_for_event_by_user"):

View File

@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_event(event, hs, store, current_state): def evaluator_for_event(event, hs, store, state_group, current_state):
room_id = event.room_id rules_by_user = yield store.bulk_get_push_rules_for_room(
# We also will want to generate notifs for other people in the room so event.room_id, state_group, current_state
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
local_users_in_room = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and hs.is_mine_id(e.state_key)
) )
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield store.get_if_users_have_pushers(
local_users_in_room
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
# if this event is an invite event, we may need to run rules for the user # if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited # who's been invited, otherwise they won't get told they've been invited
if event.type == 'm.room.member' and event.content['membership'] == 'invite': if event.type == 'm.room.member' and event.content['membership'] == 'invite':
@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state):
if invited_user and hs.is_mine_id(invited_user): if invited_user and hs.is_mine_id(invited_user):
has_pusher = yield store.user_has_pusher(invited_user) has_pusher = yield store.user_has_pusher(invited_user)
if has_pusher: if has_pusher:
user_ids.add(invited_user) rules_by_user[invited_user] = yield store.get_push_rules_for_user(
invited_user
rules_by_user = yield _get_rules(room_id, user_ids, store) )
defer.returnValue(BulkPushRuleEvaluator( defer.returnValue(BulkPushRuleEvaluator(
room_id, rules_by_user, user_ids, store event.room_id, rules_by_user, store
)) ))
@ -90,10 +66,9 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562) (see https://matrix.org/jira/browse/SYN-562)
""" """
def __init__(self, room_id, rules_by_user, users_in_room, store): def __init__(self, room_id, rules_by_user, store):
self.room_id = room_id self.room_id = room_id
self.rules_by_user = rules_by_user self.rules_by_user = rules_by_user
self.users_in_room = users_in_room
self.store = store self.store = store
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -166,7 +166,7 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters() self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters() self._get_event_counters = PerformanceCounters()
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache( self._state_group_cache = DictionaryCache(

View File

@ -56,7 +56,7 @@ class EventPushActionsStore(SQLBaseStore):
) )
self._simple_insert_many_txn(txn, "event_push_actions", values) self._simple_insert_many_txn(txn, "event_push_actions", values)
@cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000) @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user( def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id self, room_id, user_id, last_read_event_id
): ):

View File

@ -16,6 +16,7 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.push.baserules import list_with_base_rules from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes, Membership
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -48,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks(lru=True) @cachedInlineCallbacks()
def get_push_rules_for_user(self, user_id): def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="push_rules", table="push_rules",
@ -72,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rules) defer.returnValue(rules)
@cachedInlineCallbacks(lru=True) @cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_id): def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list( results = yield self._simple_select_list(
table="push_rules_enable", table="push_rules_enable",
@ -123,6 +124,61 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
cache_context):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
# We also will want to generate notifs for other people in the room so
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
local_users_in_room = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and self.hs.is_mine_id(e.state_key)
)
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers(
local_users_in_room, on_invalidate=cache_context.invalidate,
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
room_id, on_invalidate=cache_context.invalidate,
)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate,
)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
defer.returnValue(rules_by_user)
@cachedList(cached_method_name="get_push_rules_enabled_for_user", @cachedList(cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True) list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules_enabled(self, user_ids): def bulk_get_push_rules_enabled(self, user_ids):

View File

@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers", get_all_updated_pushers_txn "get_all_updated_pushers", get_all_updated_pushers_txn
) )
@cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000) @cachedInlineCallbacks(num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id): def get_if_user_has_pusher(self, user_id):
result = yield self._simple_select_many_batch( result = yield self._simple_select_many_batch(
table='pushers', table='pushers',

View File

@ -120,7 +120,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res]) defer.returnValue([ev for res in results.values() for ev in res])
@cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True) @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients. """Get receipts for a single room for sending to clients.

View File

@ -25,7 +25,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureStore(SQLBaseStore): class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes""" """Persistence for event signatures and hashes"""
@cached(lru=True) @cached()
def get_event_reference_hash(self, event_id): def get_event_reference_hash(self, event_id):
return self._get_event_reference_hashes_txn(event_id) return self._get_event_reference_hashes_txn(event_id)

View File

@ -174,7 +174,7 @@ class StateStore(SQLBaseStore):
return [r[0] for r in results] return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f) return self.runInteraction("get_current_state_for_key", f)
@cached(num_args=2, lru=True, max_entries=1000) @cached(num_args=2, max_entries=1000)
def _get_state_group_from_group(self, group, types): def _get_state_group_from_group(self, group, types):
raise NotImplementedError() raise NotImplementedError()
@ -272,7 +272,7 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_for_events([event_id], types) state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@cached(num_args=2, lru=True, max_entries=10000) @cached(num_args=2, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id): def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="event_to_state_groups", table="event_to_state_groups",

View File

@ -25,8 +25,7 @@ from synapse.util.logcontext import (
from . import DEBUG_CACHES, register_cache from . import DEBUG_CACHES, register_cache
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple
from collections import OrderedDict
import os import os
import functools import functools
@ -54,16 +53,11 @@ class Cache(object):
"metrics", "metrics",
) )
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): def __init__(self, name, max_entries=1000, keylen=1, tree=False):
if lru:
cache_type = TreeCache if tree else dict cache_type = TreeCache if tree else dict
self.cache = LruCache( self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type max_size=max_entries, keylen=keylen, cache_type=cache_type
) )
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
self.name = name self.name = name
self.keylen = keylen self.keylen = keylen
@ -81,8 +75,8 @@ class Cache(object):
"Cache objects can only be accessed from the main thread" "Cache objects can only be accessed from the main thread"
) )
def get(self, key, default=_CacheSentinel): def get(self, key, default=_CacheSentinel, callback=None):
val = self.cache.get(key, _CacheSentinel) val = self.cache.get(key, _CacheSentinel, callback=callback)
if val is not _CacheSentinel: if val is not _CacheSentinel:
self.metrics.inc_hits() self.metrics.inc_hits()
return val return val
@ -94,19 +88,15 @@ class Cache(object):
else: else:
return default return default
def update(self, sequence, key, value): def update(self, sequence, key, value, callback=None):
self.check_thread() self.check_thread()
if self.sequence == sequence: if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the # Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369) # number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value) self.prefill(key, value, callback=callback)
def prefill(self, key, value): def prefill(self, key, value, callback=None):
if self.max_entries is not None: self.cache.set(key, value, callback=callback)
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[key] = value
def invalidate(self, key): def invalidate(self, key):
self.check_thread() self.check_thread()
@ -151,9 +141,21 @@ class CacheDescriptor(object):
The wrapped function has another additional callable, called "prefill", The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without which can be used to insert values into the cache specifically, without
calling the calculation function. calling the calculation function.
Cached functions can be "chained" (i.e. a cached function can call other cached
functions and get appropriately invalidated when they called caches are
invalidated) by adding a special "cache_context" argument to the function
and passing that as a kwarg to all caches called. For example::
@cachedInlineCallbacks(cache_context=True)
def foo(self, key, cache_context):
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
defer.returnValue(r1 + r2)
""" """
def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False, def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
inlineCallbacks=False): inlineCallbacks=False, cache_context=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR) max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig self.orig = orig
@ -165,15 +167,33 @@ class CacheDescriptor(object):
self.max_entries = max_entries self.max_entries = max_entries
self.num_args = num_args self.num_args = num_args
self.lru = lru
self.tree = tree self.tree = tree
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
if "cache_context" in all_args.args:
if not cache_context:
raise ValueError(
"Cannot have a 'cache_context' arg without setting"
" cache_context=True"
)
try:
self.arg_names.remove("cache_context")
except ValueError:
pass
elif cache_context:
raise ValueError(
"Cannot have cache_context=True without having an arg"
" named `cache_context`"
)
self.add_cache_context = cache_context
if len(self.arg_names) < self.num_args: if len(self.arg_names) < self.num_args:
raise Exception( raise Exception(
"Not enough explicit positional arguments to key off of for %r." "Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)" " (@cached cannot key off of *args or **kwargs)"
% (orig.__name__,) % (orig.__name__,)
) )
@ -182,16 +202,29 @@ class CacheDescriptor(object):
name=self.orig.__name__, name=self.orig.__name__,
max_entries=self.max_entries, max_entries=self.max_entries,
keylen=self.num_args, keylen=self.num_args,
lru=self.lru,
tree=self.tree, tree=self.tree,
) )
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
# Add temp cache_context so inspect.getcallargs doesn't explode
if self.add_cache_context:
kwargs["cache_context"] = None
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
kwargs["cache_context"] = _CacheContext(cache, cache_key)
try: try:
cached_result_d = cache.get(cache_key) cached_result_d = cache.get(cache_key, callback=invalidate_callback)
observer = cached_result_d.observe() observer = cached_result_d.observe()
if DEBUG_CACHES: if DEBUG_CACHES:
@ -228,7 +261,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr) ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True) ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret) cache.update(sequence, cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe()) return preserve_context_over_deferred(ret.observe())
@ -297,6 +330,10 @@ class CacheListDescriptor(object):
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name] list_args = arg_dict[self.list_name]
@ -311,7 +348,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg key[self.list_pos] = arg
try: try:
res = cache.get(tuple(key)) res = cache.get(tuple(key), callback=invalidate_callback)
if not res.has_succeeded(): if not res.has_succeeded():
res = res.observe() res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg) res.addCallback(lambda r, arg: (arg, r), arg)
@ -345,7 +382,10 @@ class CacheListDescriptor(object):
key = list(keyargs) key = list(keyargs)
key[self.list_pos] = arg key[self.list_pos] = arg
cache.update(sequence, tuple(key), observer) cache.update(
sequence, tuple(key), observer,
callback=invalidate_callback
)
def invalidate(f, key): def invalidate(f, key):
cache.invalidate(key) cache.invalidate(key)
@ -376,24 +416,29 @@ class CacheListDescriptor(object):
return wrapped return wrapped
def cached(max_entries=1000, num_args=1, lru=True, tree=False): class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
def invalidate(self):
self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
num_args=num_args, num_args=num_args,
lru=lru,
tree=tree, tree=tree,
cache_context=cache_context,
) )
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
num_args=num_args, num_args=num_args,
lru=lru,
tree=tree, tree=tree,
inlineCallbacks=True, inlineCallbacks=True,
cache_context=cache_context,
) )

View File

@ -30,13 +30,14 @@ def enumerate_leaves(node, depth):
class _Node(object): class _Node(object):
__slots__ = ["prev_node", "next_node", "key", "value"] __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
def __init__(self, prev_node, next_node, key, value): def __init__(self, prev_node, next_node, key, value, callbacks=set()):
self.prev_node = prev_node self.prev_node = prev_node
self.next_node = next_node self.next_node = next_node
self.key = key self.key = key
self.value = value self.value = value
self.callbacks = callbacks
class LruCache(object): class LruCache(object):
@ -44,6 +45,9 @@ class LruCache(object):
Least-recently-used cache. Least-recently-used cache.
Supports del_multi only if cache_type=TreeCache Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples. If cache_type=TreeCache, all keys must be tuples.
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
""" """
def __init__(self, max_size, keylen=1, cache_type=dict): def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type() cache = cache_type()
@ -62,10 +66,10 @@ class LruCache(object):
return inner return inner
def add_node(key, value): def add_node(key, value, callbacks=set()):
prev_node = list_root prev_node = list_root
next_node = prev_node.next_node next_node = prev_node.next_node
node = _Node(prev_node, next_node, key, value) node = _Node(prev_node, next_node, key, value, callbacks)
prev_node.next_node = node prev_node.next_node = node
next_node.prev_node = node next_node.prev_node = node
cache[key] = node cache[key] = node
@ -88,23 +92,41 @@ class LruCache(object):
prev_node.next_node = next_node prev_node.next_node = next_node
next_node.prev_node = prev_node next_node.prev_node = prev_node
for cb in node.callbacks:
cb()
node.callbacks.clear()
@synchronized @synchronized
def cache_get(key, default=None): def cache_get(key, default=None, callback=None):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
move_node_to_front(node) move_node_to_front(node)
if callback:
node.callbacks.add(callback)
return node.value return node.value
else: else:
return default return default
@synchronized @synchronized
def cache_set(key, value): def cache_set(key, value, callback=None):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
if value != node.value:
for cb in node.callbacks:
cb()
node.callbacks.clear()
if callback:
node.callbacks.add(callback)
move_node_to_front(node) move_node_to_front(node)
node.value = value node.value = value
else: else:
add_node(key, value) if callback:
callbacks = set([callback])
else:
callbacks = set()
add_node(key, value, callbacks)
if len(cache) > max_size: if len(cache) > max_size:
todelete = list_root.prev_node todelete = list_root.prev_node
delete_node(todelete) delete_node(todelete)
@ -148,6 +170,9 @@ class LruCache(object):
def cache_clear(): def cache_clear():
list_root.next_node = list_root list_root.next_node = list_root
list_root.prev_node = list_root list_root.prev_node = list_root
for node in cache.values():
for cb in node.callbacks:
cb()
cache.clear() cache.clear()
@synchronized @synchronized

View File

@ -64,6 +64,9 @@ class TreeCache(object):
self.size -= cnt self.size -= cnt
return popped return popped
def values(self):
return [e.value for e in self.root.values()]
def __len__(self): def __len__(self):
return self.size return self.size

View File

@ -17,6 +17,8 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from mock import Mock
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached from synapse.util.caches.descriptors import Cache, cached
@ -72,7 +74,7 @@ class CacheTestCase(unittest.TestCase):
cache.get(3) cache.get(3)
def test_eviction_lru(self): def test_eviction_lru(self):
cache = Cache("test", max_entries=2, lru=True) cache = Cache("test", max_entries=2)
cache.prefill(1, "one") cache.prefill(1, "one")
cache.prefill(2, "two") cache.prefill(2, "two")
@ -199,3 +201,115 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(a.func("foo").result, d.result) self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
def test_invalidate_context(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 1)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
@defer.inlineCallbacks
def test_eviction_context(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
yield a.func2("foo2")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 4)
self.assertEquals(callcount2[0], 3)
@defer.inlineCallbacks
def test_double_get(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
yield a.func2("foo")
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 2)
a.func.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 3)

View File

@ -19,6 +19,8 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache
from mock import Mock
class LruCacheTestCase(unittest.TestCase): class LruCacheTestCase(unittest.TestCase):
@ -48,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get("key"), 1) self.assertEquals(cache.get("key"), 1)
self.assertEquals(cache.setdefault("key", 2), 1) self.assertEquals(cache.setdefault("key", 2), 1)
self.assertEquals(cache.get("key"), 1) self.assertEquals(cache.get("key"), 1)
cache["key"] = 2 # Make sure overriding works.
self.assertEquals(cache.get("key"), 2)
def test_pop(self): def test_pop(self):
cache = LruCache(1) cache = LruCache(1)
@ -79,3 +83,152 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1 cache["key"] = 1
cache.clear() cache.clear()
self.assertEquals(len(cache), 0) self.assertEquals(len(cache), 0)
class LruCacheCallbacksTestCase(unittest.TestCase):
def test_get(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.get("key", "value")
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_multi_get(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_set(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
self.assertFalse(m.called)
cache.set("key", "value")
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_pop(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
self.assertFalse(m.called)
cache.pop("key")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
cache.pop("key")
self.assertEquals(m.call_count, 1)
def test_del_multi(self):
m1 = Mock()
m2 = Mock()
m3 = Mock()
m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache)
cache.set(("a", "1"), "value", m1)
cache.set(("a", "2"), "value", m2)
cache.set(("b", "1"), "value", m3)
cache.set(("b", "2"), "value", m4)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
self.assertEquals(m4.call_count, 0)
cache.del_multi(("a",))
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 1)
self.assertEquals(m3.call_count, 0)
self.assertEquals(m4.call_count, 0)
def test_clear(self):
m1 = Mock()
m2 = Mock()
cache = LruCache(5)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
cache.clear()
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 1)
def test_eviction(self):
m1 = Mock(name="m1")
m2 = Mock(name="m2")
m3 = Mock(name="m3")
cache = LruCache(2)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value", m3)
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value")
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.get("key2")
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key1", "value", m1)
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 1)