diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index b2c94bfaac..ed2ccc4dfb 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -40,7 +40,7 @@ class ActionGenerator: def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "evaluator_for_event"): bulk_evaluator = yield evaluator_for_event( - event, self.hs, self.store, context.current_state + event, self.hs, self.store, context.state_group, context.current_state ) with Measure(self.clock, "action_for_event_by_user"): diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 756e5da513..004eded61f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store): @defer.inlineCallbacks -def evaluator_for_event(event, hs, store, current_state): - room_id = event.room_id - # We also will want to generate notifs for other people in the room so - # their unread countss are correct in the event stream, but to avoid - # generating them for bot / AS users etc, we only do so for people who've - # sent a read receipt into the room. - - local_users_in_room = set( - e.state_key for e in current_state.values() - if e.type == EventTypes.Member and e.membership == Membership.JOIN - and hs.is_mine_id(e.state_key) +def evaluator_for_event(event, hs, store, state_group, current_state): + rules_by_user = yield store.bulk_get_push_rules_for_room( + event.room_id, state_group, current_state ) - # users in the room who have pushers need to get push rules run because - # that's how their pushers work - if_users_with_pushers = yield store.get_if_users_have_pushers( - local_users_in_room - ) - user_ids = set( - uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - ) - - users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id) - - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in local_users_in_room: - user_ids.add(uid) - # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited if event.type == 'm.room.member' and event.content['membership'] == 'invite': @@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state): if invited_user and hs.is_mine_id(invited_user): has_pusher = yield store.user_has_pusher(invited_user) if has_pusher: - user_ids.add(invited_user) - - rules_by_user = yield _get_rules(room_id, user_ids, store) + rules_by_user[invited_user] = yield store.get_push_rules_for_user( + invited_user + ) defer.returnValue(BulkPushRuleEvaluator( - room_id, rules_by_user, user_ids, store + event.room_id, rules_by_user, store )) @@ -90,10 +66,9 @@ class BulkPushRuleEvaluator: the same logic to run the actual rules, but could be optimised further (see https://matrix.org/jira/browse/SYN-562) """ - def __init__(self, room_id, rules_by_user, users_in_room, store): + def __init__(self, room_id, rules_by_user, store): self.room_id = room_id self.rules_by_user = rules_by_user - self.users_in_room = users_in_room self.store = store @defer.inlineCallbacks diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 029f6612e6..49fa8614f2 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -166,7 +166,7 @@ class SQLBaseStore(object): self._txn_perf_counters = PerformanceCounters() self._get_event_counters = PerformanceCounters() - self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, + self._get_event_cache = Cache("*getEvent*", keylen=3, max_entries=hs.config.event_cache_size) self._state_group_cache = DictionaryCache( diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index df4000d0da..c65c9c9c47 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -56,7 +56,7 @@ class EventPushActionsStore(SQLBaseStore): ) self._simple_insert_many_txn(txn, "event_push_actions", values) - @cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000) + @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 8183b7f1b0..78334a98cf 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -16,6 +16,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.push.baserules import list_with_base_rules +from synapse.api.constants import EventTypes, Membership from twisted.internet import defer import logging @@ -48,7 +49,7 @@ def _load_rules(rawrules, enabled_map): class PushRuleStore(SQLBaseStore): - @cachedInlineCallbacks(lru=True) + @cachedInlineCallbacks() def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( table="push_rules", @@ -72,7 +73,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(rules) - @cachedInlineCallbacks(lru=True) + @cachedInlineCallbacks() def get_push_rules_enabled_for_user(self, user_id): results = yield self._simple_select_list( table="push_rules_enable", @@ -123,6 +124,61 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(results) + def bulk_get_push_rules_for_room(self, room_id, state_group, current_state): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._bulk_get_push_rules_for_room(room_id, state_group, current_state) + + @cachedInlineCallbacks(num_args=2, cache_context=True) + def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state, + cache_context): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + # We also will want to generate notifs for other people in the room so + # their unread countss are correct in the event stream, but to avoid + # generating them for bot / AS users etc, we only do so for people who've + # sent a read receipt into the room. + local_users_in_room = set( + e.state_key for e in current_state.values() + if e.type == EventTypes.Member and e.membership == Membership.JOIN + and self.hs.is_mine_id(e.state_key) + ) + + # users in the room who have pushers need to get push rules run because + # that's how their pushers work + if_users_with_pushers = yield self.get_if_users_have_pushers( + local_users_in_room, on_invalidate=cache_context.invalidate, + ) + user_ids = set( + uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher + ) + + users_with_receipts = yield self.get_users_with_read_receipts_in_room( + room_id, on_invalidate=cache_context.invalidate, + ) + + # any users with pushers must be ours: they have pushers + for uid in users_with_receipts: + if uid in local_users_in_room: + user_ids.add(uid) + + rules_by_user = yield self.bulk_get_push_rules( + user_ids, on_invalidate=cache_context.invalidate, + ) + + rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} + + defer.returnValue(rules_by_user) + @cachedList(cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids", num_args=1, inlineCallbacks=True) def bulk_get_push_rules_enabled(self, user_ids): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index a7d7c54d7e..8f5f8f24a9 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore): "get_all_updated_pushers", get_all_updated_pushers_txn ) - @cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000) + @cachedInlineCallbacks(num_args=1, max_entries=15000) def get_if_user_has_pusher(self, user_id): result = yield self._simple_select_many_batch( table='pushers', diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 8c26f39fbb..3ad916103f 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -120,7 +120,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue([ev for res in results.values() for ev in res]) - @cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True) + @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True) def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): """Get receipts for a single room for sending to clients. diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index ea6823f18d..e1dca927d7 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -25,7 +25,7 @@ from synapse.util.caches.descriptors import cached, cachedList class SignatureStore(SQLBaseStore): """Persistence for event signatures and hashes""" - @cached(lru=True) + @cached() def get_event_reference_hash(self, event_id): return self._get_event_reference_hashes_txn(event_id) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 5b743db67a..0e8fa93e1f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -174,7 +174,7 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, lru=True, max_entries=1000) + @cached(num_args=2, max_entries=1000) def _get_state_group_from_group(self, group, types): raise NotImplementedError() @@ -272,7 +272,7 @@ class StateStore(SQLBaseStore): state_map = yield self.get_state_for_events([event_id], types) defer.returnValue(state_map[event_id]) - @cached(num_args=2, lru=True, max_entries=10000) + @cached(num_args=2, max_entries=10000) def _get_state_group_for_event(self, room_id, event_id): return self._simple_select_one_onecol( table="event_to_state_groups", diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index f31dfb22b7..8dba61d49f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -25,8 +25,7 @@ from synapse.util.logcontext import ( from . import DEBUG_CACHES, register_cache from twisted.internet import defer - -from collections import OrderedDict +from collections import namedtuple import os import functools @@ -54,16 +53,11 @@ class Cache(object): "metrics", ) - def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): - if lru: - cache_type = TreeCache if tree else dict - self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type - ) - self.max_entries = None - else: - self.cache = OrderedDict() - self.max_entries = max_entries + def __init__(self, name, max_entries=1000, keylen=1, tree=False): + cache_type = TreeCache if tree else dict + self.cache = LruCache( + max_size=max_entries, keylen=keylen, cache_type=cache_type + ) self.name = name self.keylen = keylen @@ -81,8 +75,8 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, key, default=_CacheSentinel): - val = self.cache.get(key, _CacheSentinel) + def get(self, key, default=_CacheSentinel, callback=None): + val = self.cache.get(key, _CacheSentinel, callback=callback) if val is not _CacheSentinel: self.metrics.inc_hits() return val @@ -94,19 +88,15 @@ class Cache(object): else: return default - def update(self, sequence, key, value): + def update(self, sequence, key, value, callback=None): self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value) + self.prefill(key, value, callback=callback) - def prefill(self, key, value): - if self.max_entries is not None: - while len(self.cache) >= self.max_entries: - self.cache.popitem(last=False) - - self.cache[key] = value + def prefill(self, key, value, callback=None): + self.cache.set(key, value, callback=callback) def invalidate(self, key): self.check_thread() @@ -151,9 +141,21 @@ class CacheDescriptor(object): The wrapped function has another additional callable, called "prefill", which can be used to insert values into the cache specifically, without calling the calculation function. + + Cached functions can be "chained" (i.e. a cached function can call other cached + functions and get appropriately invalidated when they called caches are + invalidated) by adding a special "cache_context" argument to the function + and passing that as a kwarg to all caches called. For example:: + + @cachedInlineCallbacks(cache_context=True) + def foo(self, key, cache_context): + r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate) + r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) + defer.returnValue(r1 + r2) + """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False, - inlineCallbacks=False): + def __init__(self, orig, max_entries=1000, num_args=1, tree=False, + inlineCallbacks=False, cache_context=False): max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.orig = orig @@ -165,15 +167,33 @@ class CacheDescriptor(object): self.max_entries = max_entries self.num_args = num_args - self.lru = lru self.tree = tree - self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] + all_args = inspect.getargspec(orig) + self.arg_names = all_args.args[1:num_args + 1] + + if "cache_context" in all_args.args: + if not cache_context: + raise ValueError( + "Cannot have a 'cache_context' arg without setting" + " cache_context=True" + ) + try: + self.arg_names.remove("cache_context") + except ValueError: + pass + elif cache_context: + raise ValueError( + "Cannot have cache_context=True without having an arg" + " named `cache_context`" + ) + + self.add_cache_context = cache_context if len(self.arg_names) < self.num_args: raise Exception( "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" + " (@cached cannot key off of *args or **kwargs)" % (orig.__name__,) ) @@ -182,16 +202,29 @@ class CacheDescriptor(object): name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, - lru=self.lru, tree=self.tree, ) @functools.wraps(self.orig) def wrapped(*args, **kwargs): + # If we're passed a cache_context then we'll want to call its invalidate() + # whenever we are invalidated + invalidate_callback = kwargs.pop("on_invalidate", None) + + # Add temp cache_context so inspect.getcallargs doesn't explode + if self.add_cache_context: + kwargs["cache_context"] = None + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one + if self.add_cache_context: + kwargs["cache_context"] = _CacheContext(cache, cache_key) + try: - cached_result_d = cache.get(cache_key) + cached_result_d = cache.get(cache_key, callback=invalidate_callback) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -228,7 +261,7 @@ class CacheDescriptor(object): ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - cache.update(sequence, cache_key, ret) + cache.update(sequence, cache_key, ret, callback=invalidate_callback) return preserve_context_over_deferred(ret.observe()) @@ -297,6 +330,10 @@ class CacheListDescriptor(object): @functools.wraps(self.orig) def wrapped(*args, **kwargs): + # If we're passed a cache_context then we'll want to call its invalidate() + # whenever we are invalidated + invalidate_callback = kwargs.pop("on_invalidate", None) + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] list_args = arg_dict[self.list_name] @@ -311,7 +348,7 @@ class CacheListDescriptor(object): key[self.list_pos] = arg try: - res = cache.get(tuple(key)) + res = cache.get(tuple(key), callback=invalidate_callback) if not res.has_succeeded(): res = res.observe() res.addCallback(lambda r, arg: (arg, r), arg) @@ -345,7 +382,10 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - cache.update(sequence, tuple(key), observer) + cache.update( + sequence, tuple(key), observer, + callback=invalidate_callback + ) def invalidate(f, key): cache.invalidate(key) @@ -376,24 +416,29 @@ class CacheListDescriptor(object): return wrapped -def cached(max_entries=1000, num_args=1, lru=True, tree=False): +class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): + def invalidate(self): + self.cache.invalidate(self.key) + + +def cached(max_entries=1000, num_args=1, tree=False, cache_context=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, - lru=lru, tree=tree, + cache_context=cache_context, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, - lru=lru, tree=tree, inlineCallbacks=True, + cache_context=cache_context, ) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index f9df445a8d..9c4c679175 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -30,13 +30,14 @@ def enumerate_leaves(node, depth): class _Node(object): - __slots__ = ["prev_node", "next_node", "key", "value"] + __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] - def __init__(self, prev_node, next_node, key, value): + def __init__(self, prev_node, next_node, key, value, callbacks=set()): self.prev_node = prev_node self.next_node = next_node self.key = key self.value = value + self.callbacks = callbacks class LruCache(object): @@ -44,6 +45,9 @@ class LruCache(object): Least-recently-used cache. Supports del_multi only if cache_type=TreeCache If cache_type=TreeCache, all keys must be tuples. + + Can also set callbacks on objects when getting/setting which are fired + when that key gets invalidated/evicted. """ def __init__(self, max_size, keylen=1, cache_type=dict): cache = cache_type() @@ -62,10 +66,10 @@ class LruCache(object): return inner - def add_node(key, value): + def add_node(key, value, callbacks=set()): prev_node = list_root next_node = prev_node.next_node - node = _Node(prev_node, next_node, key, value) + node = _Node(prev_node, next_node, key, value, callbacks) prev_node.next_node = node next_node.prev_node = node cache[key] = node @@ -88,23 +92,41 @@ class LruCache(object): prev_node.next_node = next_node next_node.prev_node = prev_node + for cb in node.callbacks: + cb() + node.callbacks.clear() + @synchronized - def cache_get(key, default=None): + def cache_get(key, default=None, callback=None): node = cache.get(key, None) if node is not None: move_node_to_front(node) + if callback: + node.callbacks.add(callback) return node.value else: return default @synchronized - def cache_set(key, value): + def cache_set(key, value, callback=None): node = cache.get(key, None) if node is not None: + if value != node.value: + for cb in node.callbacks: + cb() + node.callbacks.clear() + + if callback: + node.callbacks.add(callback) + move_node_to_front(node) node.value = value else: - add_node(key, value) + if callback: + callbacks = set([callback]) + else: + callbacks = set() + add_node(key, value, callbacks) if len(cache) > max_size: todelete = list_root.prev_node delete_node(todelete) @@ -148,6 +170,9 @@ class LruCache(object): def cache_clear(): list_root.next_node = list_root list_root.prev_node = list_root + for node in cache.values(): + for cb in node.callbacks: + cb() cache.clear() @synchronized diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 03bc1401b7..c31585aea3 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -64,6 +64,9 @@ class TreeCache(object): self.size -= cnt return popped + def values(self): + return [e.value for e in self.root.values()] + def __len__(self): return self.size diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 96b7dba5fe..ab6095564a 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -17,6 +17,8 @@ from tests import unittest from twisted.internet import defer +from mock import Mock + from synapse.util.async import ObservableDeferred from synapse.util.caches.descriptors import Cache, cached @@ -72,7 +74,7 @@ class CacheTestCase(unittest.TestCase): cache.get(3) def test_eviction_lru(self): - cache = Cache("test", max_entries=2, lru=True) + cache = Cache("test", max_entries=2) cache.prefill(1, "one") cache.prefill(2, "two") @@ -199,3 +201,115 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) + + @defer.inlineCallbacks + def test_invalidate_context(self): + callcount = [0] + callcount2 = [0] + + class A(object): + @cached() + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + yield a.func2("foo") + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 1) + + a.func.invalidate(("foo",)) + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 1) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + @defer.inlineCallbacks + def test_eviction_context(self): + callcount = [0] + callcount2 = [0] + + class A(object): + @cached(max_entries=2) + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + yield a.func2("foo") + yield a.func2("foo2") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func("foo3") + + self.assertEquals(callcount[0], 3) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 4) + self.assertEquals(callcount2[0], 3) + + @defer.inlineCallbacks + def test_double_get(self): + callcount = [0] + callcount2 = [0] + + class A(object): + @cached() + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + a.func2.cache.cache = Mock(wraps=a.func2.cache.cache) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 1) + + a.func2.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 1) + + yield a.func2("foo") + a.func2.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 2) + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 2) + + a.func.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 3) + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 3) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index bab366fb7f..1eba5b535e 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -19,6 +19,8 @@ from .. import unittest from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache +from mock import Mock + class LruCacheTestCase(unittest.TestCase): @@ -48,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase): self.assertEquals(cache.get("key"), 1) self.assertEquals(cache.setdefault("key", 2), 1) self.assertEquals(cache.get("key"), 1) + cache["key"] = 2 # Make sure overriding works. + self.assertEquals(cache.get("key"), 2) def test_pop(self): cache = LruCache(1) @@ -79,3 +83,152 @@ class LruCacheTestCase(unittest.TestCase): cache["key"] = 1 cache.clear() self.assertEquals(len(cache), 0) + + +class LruCacheCallbacksTestCase(unittest.TestCase): + def test_get(self): + m = Mock() + cache = LruCache(1) + + cache.set("key", "value") + self.assertFalse(m.called) + + cache.get("key", callback=m) + self.assertFalse(m.called) + + cache.get("key", "value") + self.assertFalse(m.called) + + cache.set("key", "value2") + self.assertEquals(m.call_count, 1) + + cache.set("key", "value") + self.assertEquals(m.call_count, 1) + + def test_multi_get(self): + m = Mock() + cache = LruCache(1) + + cache.set("key", "value") + self.assertFalse(m.called) + + cache.get("key", callback=m) + self.assertFalse(m.called) + + cache.get("key", callback=m) + self.assertFalse(m.called) + + cache.set("key", "value2") + self.assertEquals(m.call_count, 1) + + cache.set("key", "value") + self.assertEquals(m.call_count, 1) + + def test_set(self): + m = Mock() + cache = LruCache(1) + + cache.set("key", "value", m) + self.assertFalse(m.called) + + cache.set("key", "value") + self.assertFalse(m.called) + + cache.set("key", "value2") + self.assertEquals(m.call_count, 1) + + cache.set("key", "value") + self.assertEquals(m.call_count, 1) + + def test_pop(self): + m = Mock() + cache = LruCache(1) + + cache.set("key", "value", m) + self.assertFalse(m.called) + + cache.pop("key") + self.assertEquals(m.call_count, 1) + + cache.set("key", "value") + self.assertEquals(m.call_count, 1) + + cache.pop("key") + self.assertEquals(m.call_count, 1) + + def test_del_multi(self): + m1 = Mock() + m2 = Mock() + m3 = Mock() + m4 = Mock() + cache = LruCache(4, 2, cache_type=TreeCache) + + cache.set(("a", "1"), "value", m1) + cache.set(("a", "2"), "value", m2) + cache.set(("b", "1"), "value", m3) + cache.set(("b", "2"), "value", m4) + + self.assertEquals(m1.call_count, 0) + self.assertEquals(m2.call_count, 0) + self.assertEquals(m3.call_count, 0) + self.assertEquals(m4.call_count, 0) + + cache.del_multi(("a",)) + + self.assertEquals(m1.call_count, 1) + self.assertEquals(m2.call_count, 1) + self.assertEquals(m3.call_count, 0) + self.assertEquals(m4.call_count, 0) + + def test_clear(self): + m1 = Mock() + m2 = Mock() + cache = LruCache(5) + + cache.set("key1", "value", m1) + cache.set("key2", "value", m2) + + self.assertEquals(m1.call_count, 0) + self.assertEquals(m2.call_count, 0) + + cache.clear() + + self.assertEquals(m1.call_count, 1) + self.assertEquals(m2.call_count, 1) + + def test_eviction(self): + m1 = Mock(name="m1") + m2 = Mock(name="m2") + m3 = Mock(name="m3") + cache = LruCache(2) + + cache.set("key1", "value", m1) + cache.set("key2", "value", m2) + + self.assertEquals(m1.call_count, 0) + self.assertEquals(m2.call_count, 0) + self.assertEquals(m3.call_count, 0) + + cache.set("key3", "value", m3) + + self.assertEquals(m1.call_count, 1) + self.assertEquals(m2.call_count, 0) + self.assertEquals(m3.call_count, 0) + + cache.set("key3", "value") + + self.assertEquals(m1.call_count, 1) + self.assertEquals(m2.call_count, 0) + self.assertEquals(m3.call_count, 0) + + cache.get("key2") + + self.assertEquals(m1.call_count, 1) + self.assertEquals(m2.call_count, 0) + self.assertEquals(m3.call_count, 0) + + cache.set("key1", "value", m1) + + self.assertEquals(m1.call_count, 1) + self.assertEquals(m2.call_count, 0) + self.assertEquals(m3.call_count, 1)