Merge pull request #8572 from matrix-org/rav/cache_hacking/2
Push some deferred wrangling down into DeferredCache
commit 9146a8a691
@@ -0,0 +1 @@
+Modify `DeferredCache.get()` to return `Deferred`s instead of `ObservableDeferred`s.
@@ -57,7 +57,7 @@ class DeferredCache(Generic[KT, VT]):
    """Wraps an LruCache, adding support for Deferred results.

    It expects that each entry added with set() will be a Deferred; likewise get()
-    may return an ObservableDeferred.
+    will return a Deferred.
    """

    __slots__ = (
@@ -130,16 +130,22 @@ class DeferredCache(Generic[KT, VT]):
        key: KT,
        callback: Optional[Callable[[], None]] = None,
        update_metrics: bool = True,
-    ) -> Union[ObservableDeferred, VT]:
+    ) -> defer.Deferred:
        """Looks the key up in the caches.

+        For symmetry with set(), this method does *not* follow the synapse logcontext
+        rules: the logcontext will not be cleared on return, and the Deferred will run
+        its callbacks in the sentinel context. In other words: wrap the result with
+        make_deferred_yieldable() before `await`ing it.
+
        Args:
-            key(tuple)
-            callback(fn): Gets called when the entry in the cache is invalidated
+            key:
+            callback: Gets called when the entry in the cache is invalidated
            update_metrics (bool): whether to update the cache hit rate metrics

        Returns:
-            Either an ObservableDeferred or the result itself
+            A Deferred which completes with the result. Note that this may later fail
+            if there is an ongoing set() operation which later completes with a failure.

        Raises:
            KeyError if the key is not found in the cache
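As an aside, here is a minimal sketch (not part of this diff) of how a caller is expected to consume the new `get()` contract described above; the cache instance, function name and key are hypothetical:

    from synapse.logging.context import make_deferred_yieldable
    from synapse.util.caches.deferred_cache import DeferredCache

    async def lookup_thing(cache: DeferredCache, key):
        # get() now always returns a Deferred (possibly already completed), and it
        # may run its callbacks in the sentinel context, so re-enter our logcontext
        # before awaiting it.
        try:
            return await make_deferred_yieldable(cache.get(key))
        except KeyError:
            return None  # not cached; a real caller would compute and cache.set() here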
@@ -152,7 +158,7 @@ class DeferredCache(Generic[KT, VT]):
                m = self.cache.metrics
                assert m  # we always have a name, so should always have metrics
                m.inc_hits()
-            return val.deferred
+            return val.deferred.observe()

        val2 = self.cache.get(
            key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -160,7 +166,7 @@ class DeferredCache(Generic[KT, VT]):
        if val2 is _Sentinel.sentinel:
            raise KeyError()
        else:
-            return val2
+            return defer.succeed(val2)

    def get_immediate(
        self, key: KT, default: T, update_metrics: bool = True
@@ -173,7 +179,36 @@ class DeferredCache(Generic[KT, VT]):
        key: KT,
        value: defer.Deferred,
        callback: Optional[Callable[[], None]] = None,
-    ) -> ObservableDeferred:
+    ) -> defer.Deferred:
+        """Adds a new entry to the cache (or updates an existing one).
+
+        The given `value` *must* be a Deferred.
+
+        First any existing entry for the same key is invalidated. Then a new entry
+        is added to the cache for the given key.
+
+        Until the `value` completes, calls to `get()` for the key will also result in an
+        incomplete Deferred, which will ultimately complete with the same result as
+        `value`.
+
+        If `value` completes successfully, subsequent calls to `get()` will then return
+        a completed deferred with the same result. If it *fails*, the cache is
+        invalidated and subsequent calls to `get()` will raise a KeyError.
+
+        If another call to `set()` happens before `value` completes, then (a) any
+        invalidation callbacks registered in the interim will be called, (b) any
+        `get()`s in the interim will continue to complete with the result from the
+        *original* `value`, (c) any future calls to `get()` will complete with the
+        result from the *new* `value`.
+
+        It is expected that `value` does *not* follow the synapse logcontext rules - ie,
+        if it is incomplete, it runs its callbacks in the sentinel context.
+
+        Args:
+            key: Key to be set
+            value: a deferred which will complete with a result to add to the cache
+            callback: An optional callback to be called when the entry is invalidated
+        """
        if not isinstance(value, defer.Deferred):
            raise TypeError("not a Deferred")

@@ -187,6 +222,8 @@ class DeferredCache(Generic[KT, VT]):
        if existing_entry:
            existing_entry.invalidate()

+        # XXX: why don't we invalidate the entry in `self.cache` yet?
+
        self._pending_deferred_cache[key] = entry

        def compare_and_pop():
@@ -230,7 +267,9 @@ class DeferredCache(Generic[KT, VT]):
            # _pending_deferred_cache to the real cache.
            #
            observer.addCallbacks(cb, eb)
-        return observable
+
+        # we return a new Deferred which will be called before any subsequent observers.
+        return observable.observe()

    def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
        callbacks = [callback] if callback else []
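To make the new set()/get() contract concrete, here is a small illustrative snippet (not part of the diff; it mirrors the docstring above and the tests further down). The cache name, key and values are arbitrary:

    from twisted.internet import defer
    from synapse.util.caches.deferred_cache import DeferredCache

    cache = DeferredCache("example")

    origin_d = defer.Deferred()
    set_d = cache.set("k1", origin_d)

    # while origin_d is pending, get() returns an incomplete Deferred ...
    get_d = cache.get("k1")
    assert not get_d.called

    # ... which completes with the same result once origin_d does
    origin_d.callback(42)
    assert get_d.result == 42

    # had origin_d failed instead, the entry would have been invalidated and a
    # subsequent get("k1") would raise KeyError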
@@ -23,7 +23,6 @@ from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.deferred_cache import DeferredCache

logger = logging.getLogger(__name__)
@@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase):
            keylen=self.num_args,
            tree=self.tree,
            iterable=self.iterable,
-        )  # type: DeferredCache[Tuple, Any]
+        )  # type: DeferredCache[CacheKey, Any]

        def get_cache_key_gen(args, kwargs):
            """Given some args/kwargs return a generator that resolves into
@@ -208,26 +207,12 @@ class CacheDescriptor(_CacheDescriptorBase):
                kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)

            try:
-                cached_result_d = cache.get(cache_key, callback=invalidate_callback)
-
-                if isinstance(cached_result_d, ObservableDeferred):
-                    observer = cached_result_d.observe()
-                else:
-                    observer = defer.succeed(cached_result_d)
-
+                ret = cache.get(cache_key, callback=invalidate_callback)
            except KeyError:
                ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)

-                def onErr(f):
-                    cache.invalidate(cache_key)
-                    return f
-
-                ret.addErrback(onErr)
-
-                result_d = cache.set(cache_key, ret, callback=invalidate_callback)
-                observer = result_d.observe()
-
-            return make_deferred_yieldable(observer)
+                ret = cache.set(cache_key, ret, callback=invalidate_callback)
+
+            return make_deferred_yieldable(ret)

        wrapped = cast(_CachedFunction, _wrapped)
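The error handling that used to live in the descriptor (the onErr/addErrback dance and the explicit observe() calls) is now pushed down into DeferredCache.set(), which invalidates the entry itself when the pending Deferred fails. A rough sketch of the resulting lookup-or-populate pattern, with hypothetical names (not part of the diff):

    from twisted.internet import defer
    from synapse.logging.context import make_deferred_yieldable, preserve_fn

    def get_or_compute(cache, key, compute):
        # `compute` is a zero-argument callable; on a cache miss we kick it off and
        # store the resulting Deferred, relying on set() to clean up on failure.
        try:
            ret = cache.get(key)
        except KeyError:
            ret = defer.maybeDeferred(preserve_fn(compute))
            ret = cache.set(key, ret)
        return make_deferred_yieldable(ret)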
@@ -286,7 +271,7 @@ class CacheListDescriptor(_CacheDescriptorBase):

    def __get__(self, obj, objtype=None):
        cached_method = getattr(obj, self.cached_method_name)
-        cache = cached_method.cache
+        cache = cached_method.cache  # type: DeferredCache[CacheKey, Any]
        num_args = cached_method.num_args

        @functools.wraps(self.orig)
@@ -326,14 +311,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
            for arg in list_args:
                try:
                    res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
-                    if not isinstance(res, ObservableDeferred):
-                        results[arg] = res
-                    elif not res.has_succeeded():
-                        res = res.observe()
+                    if not res.called:
                        res.addCallback(update_results_dict, arg)
                        cached_defers.append(res)
                    else:
-                        results[arg] = res.get_result()
+                        results[arg] = res.result
                except KeyError:
                    missing.add(arg)

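Because get() now returns a plain Deferred, the list descriptor can rely on Twisted's own `called` flag and `result` attribute instead of ObservableDeferred.has_succeeded()/get_result(). A standalone illustration of that distinction (not part of the diff):

    from twisted.internet import defer

    pending = defer.Deferred()
    assert not pending.called      # still in flight: attach a callback instead

    done = defer.succeed("hello")
    assert done.called             # already fired ...
    assert done.result == "hello"  # ... so the result can be read off directly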
@@ -15,237 +15,9 @@
# limitations under the License.

-from mock import Mock
-
-from twisted.internet import defer
-
-from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached
-
from tests import unittest


-class CacheDecoratorTestCase(unittest.HomeserverTestCase):
-    @defer.inlineCallbacks
-    def test_passthrough(self):
-        class A:
-            @cached()
-            def func(self, key):
-                return key
-
-        a = A()
-
-        self.assertEquals((yield a.func("foo")), "foo")
-        self.assertEquals((yield a.func("bar")), "bar")
-
-    @defer.inlineCallbacks
-    def test_hit(self):
-        callcount = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 1)
-
-        self.assertEquals((yield a.func("foo")), "foo")
-        self.assertEquals(callcount[0], 1)
-
-    @defer.inlineCallbacks
-    def test_invalidate(self):
-        callcount = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 1)
-
-        a.func.invalidate(("foo",))
-
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-
-    def test_invalidate_missing(self):
-        class A:
-            @cached()
-            def func(self, key):
-                return key
-
-        A().func.invalidate(("what",))
-
-    @defer.inlineCallbacks
-    def test_max_entries(self):
-        callcount = [0]
-
-        class A:
-            @cached(max_entries=10)
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-
-        for k in range(0, 12):
-            yield a.func(k)
-
-        self.assertEquals(callcount[0], 12)
-
-        # There must have been at least 2 evictions, meaning if we calculate
-        # all 12 values again, we must get called at least 2 more times
-        for k in range(0, 12):
-            yield a.func(k)
-
-        self.assertTrue(
-            callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
-        )
-
-    def test_prefill(self):
-        callcount = [0]
-
-        d = defer.succeed(123)
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return d
-
-        a = A()
-
-        a.func.prefill(("foo",), ObservableDeferred(d))
-
-        self.assertEquals(a.func("foo").result, d.result)
-        self.assertEquals(callcount[0], 0)
-
-    @defer.inlineCallbacks
-    def test_invalidate_context(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 1)
-
-        a.func.invalidate(("foo",))
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 1)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-    @defer.inlineCallbacks
-    def test_eviction_context(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached(max_entries=2)
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        yield a.func2("foo")
-        yield a.func2("foo2")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func("foo3")
-
-        self.assertEquals(callcount[0], 3)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 4)
-        self.assertEquals(callcount2[0], 3)
-
-    @defer.inlineCallbacks
-    def test_double_get(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 1)
-
-        a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
-
-        yield a.func2("foo")
-        a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 2)
-
-        a.func.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 3)
-
-
class UpsertManyTests(unittest.HomeserverTestCase):
    def prepare(self, reactor, clock, hs):
        self.storage = hs.get_datastore()
@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import unittest
from functools import partial

from twisted.internet import defer

from synapse.util.caches.deferred_cache import DeferredCache

+from tests.unittest import TestCase
+

-class DeferredCacheTestCase(unittest.TestCase):
+class DeferredCacheTestCase(TestCase):
    def test_empty(self):
        cache = DeferredCache("test")
        failed = False
@@ -36,7 +37,102 @@ class DeferredCacheTestCase(unittest.TestCase):
        cache = DeferredCache("test")
        cache.prefill("foo", 123)

-        self.assertEquals(cache.get("foo"), 123)
+        self.assertEquals(self.successResultOf(cache.get("foo")), 123)
+
+    def test_hit_deferred(self):
+        cache = DeferredCache("test")
+        origin_d = defer.Deferred()
+        set_d = cache.set("k1", origin_d)
+
+        # get should return an incomplete deferred
+        get_d = cache.get("k1")
+        self.assertFalse(get_d.called)
+
+        # add a callback that will make sure that the set_d gets called before the get_d
+        def check1(r):
+            self.assertTrue(set_d.called)
+            return r
+
+        # TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
+        # maybe we should fix that?
+        # get_d.addCallback(check1)
+
+        # now fire off all the deferreds
+        origin_d.callback(99)
+        self.assertEqual(self.successResultOf(origin_d), 99)
+        self.assertEqual(self.successResultOf(set_d), 99)
+        self.assertEqual(self.successResultOf(get_d), 99)
+
+    def test_callbacks(self):
+        """Invalidation callbacks are called at the right time"""
+        cache = DeferredCache("test")
+        callbacks = set()
+
+        # start with an entry, with a callback
+        cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+        # now replace that entry with a pending result
+        origin_d = defer.Deferred()
+        set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+        # ... and also make a get request
+        get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+        # we don't expect the invalidation callback for the original value to have
+        # been called yet, even though get() will now return a different result.
+        # I'm not sure if that is by design or not.
+        self.assertEqual(callbacks, set())
+
+        # now fire off all the deferreds
+        origin_d.callback(20)
+        self.assertEqual(self.successResultOf(set_d), 20)
+        self.assertEqual(self.successResultOf(get_d), 20)
+
+        # now the original invalidation callback should have been called, but none of
+        # the others
+        self.assertEqual(callbacks, {"prefill"})
+        callbacks.clear()
+
+        # another update should invalidate both the previous results
+        cache.prefill("k1", 30)
+        self.assertEqual(callbacks, {"set", "get"})
+
+    def test_set_fail(self):
+        cache = DeferredCache("test")
+        callbacks = set()
+
+        # start with an entry, with a callback
+        cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+        # now replace that entry with a pending result
+        origin_d = defer.Deferred()
+        set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+        # ... and also make a get request
+        get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+        # none of the callbacks should have been called yet
+        self.assertEqual(callbacks, set())
+
+        # oh noes! fails!
+        e = Exception("oops")
+        origin_d.errback(e)
+        self.assertIs(self.failureResultOf(set_d, Exception).value, e)
+        self.assertIs(self.failureResultOf(get_d, Exception).value, e)
+
+        # the callbacks for the failed requests should have been called.
+        # I'm not sure if this is deliberate or not.
+        self.assertEqual(callbacks, {"get", "set"})
+        callbacks.clear()
+
+        # the old value should still be returned now?
+        get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2"))
+        self.assertEqual(self.successResultOf(get_d2), 10)
+
+        # replacing the value now should run the callbacks for those requests
+        # which got the original result
+        cache.prefill("k1", 30)
+        self.assertEqual(callbacks, {"prefill", "get2"})

    def test_get_immediate(self):
        cache = DeferredCache("test")
@@ -82,16 +178,15 @@ class DeferredCacheTestCase(unittest.TestCase):
        d2 = defer.Deferred()
        cache.set("key2", d2, partial(record_callback, 1))

-        # lookup should return observable deferreds
-        self.assertFalse(cache.get("key1").has_called())
-        self.assertFalse(cache.get("key2").has_called())
+        # lookup should return pending deferreds
+        self.assertFalse(cache.get("key1").called)
+        self.assertFalse(cache.get("key2").called)

        # let one of the lookups complete
        d2.callback("result2")

-        # for now at least, the cache will return real results rather than an
-        # observabledeferred
-        self.assertEqual(cache.get("key2"), "result2")
+        # now the cache will return a completed deferred
+        self.assertEqual(self.successResultOf(cache.get("key2")), "result2")

        # now do the invalidation
        cache.invalidate_all()
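The rewritten tests drive the pending Deferreds by hand and then unwrap them with successResultOf/failureResultOf, which come from Twisted's trial TestCase API. A minimal standalone example of the same pattern (not part of the diff):

    from twisted.internet import defer
    from twisted.trial import unittest

    class ExampleTest(unittest.TestCase):
        def test_manual_firing(self):
            d = defer.Deferred()
            self.assertFalse(d.called)

            d.callback("value")
            self.assertEqual(self.successResultOf(d), "value")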
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Set

import mock

@@ -130,6 +131,57 @@ class DescriptorTestCase(unittest.TestCase):
        d = obj.fn(1)
        self.failureResultOf(d, SynapseError)

+    def test_cache_with_async_exception(self):
+        """The wrapped function returns a failure
+        """
+
+        class Cls:
+            result = None
+            call_count = 0
+
+            @cached()
+            def fn(self, arg1):
+                self.call_count += 1
+                return self.result
+
+        obj = Cls()
+        callbacks = set()  # type: Set[str]
+
+        # set off an asynchronous request
+        obj.result = origin_d = defer.Deferred()
+
+        d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
+        self.assertFalse(d1.called)
+
+        # a second request should also return a deferred, but should not call the
+        # function itself.
+        d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
+        self.assertFalse(d2.called)
+        self.assertEqual(obj.call_count, 1)
+
+        # no callbacks yet
+        self.assertEqual(callbacks, set())
+
+        # the original request fails
+        e = Exception("bzz")
+        origin_d.errback(e)
+
+        # ... which should cause the lookups to fail similarly
+        self.assertIs(self.failureResultOf(d1, Exception).value, e)
+        self.assertIs(self.failureResultOf(d2, Exception).value, e)
+
+        # ... and the callbacks to have been, uh, called.
+        self.assertEqual(callbacks, {"d1", "d2"})
+
+        # ... leaving the cache empty
+        self.assertEqual(len(obj.fn.cache.cache), 0)
+
+        # and a second call should work as normal
+        obj.result = defer.succeed(100)
+        d3 = obj.fn(1)
+        self.assertEqual(self.successResultOf(d3), 100)
+        self.assertEqual(obj.call_count, 2)
+
    def test_cache_logcontexts(self):
        """Check that logcontexts are set and restored correctly when
        using the cache."""
@@ -311,6 +363,235 @@ class DescriptorTestCase(unittest.TestCase):
        self.failureResultOf(d, SynapseError)


+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
+    """More tests for @cached
+
+    The following is a set of tests that got lost in a different file for a while.
+
+    There are probably duplicates of the tests in DescriptorTestCase. Ideally the
+    duplicates would be removed and the two sets of classes combined.
+    """
+
+    @defer.inlineCallbacks
+    def test_passthrough(self):
+        class A:
+            @cached()
+            def func(self, key):
+                return key
+
+        a = A()
+
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals((yield a.func("bar")), "bar")
+
+    @defer.inlineCallbacks
+    def test_hit(self):
+        callcount = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 1)
+
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals(callcount[0], 1)
+
+    @defer.inlineCallbacks
+    def test_invalidate(self):
+        callcount = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 1)
+
+        a.func.invalidate(("foo",))
+
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+
+    def test_invalidate_missing(self):
+        class A:
+            @cached()
+            def func(self, key):
+                return key
+
+        A().func.invalidate(("what",))
+
+    @defer.inlineCallbacks
+    def test_max_entries(self):
+        callcount = [0]
+
+        class A:
+            @cached(max_entries=10)
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+
+        for k in range(0, 12):
+            yield a.func(k)
+
+        self.assertEquals(callcount[0], 12)
+
+        # There must have been at least 2 evictions, meaning if we calculate
+        # all 12 values again, we must get called at least 2 more times
+        for k in range(0, 12):
+            yield a.func(k)
+
+        self.assertTrue(
+            callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
+        )
+
+    def test_prefill(self):
+        callcount = [0]
+
+        d = defer.succeed(123)
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return d
+
+        a = A()
+
+        a.func.prefill(("foo",), 456)
+
+        self.assertEquals(a.func("foo").result, 456)
+        self.assertEquals(callcount[0], 0)
+
+    @defer.inlineCallbacks
+    def test_invalidate_context(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func.invalidate(("foo",))
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 1)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+    @defer.inlineCallbacks
+    def test_eviction_context(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached(max_entries=2)
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        yield a.func2("foo")
+        yield a.func2("foo2")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func("foo3")
+
+        self.assertEquals(callcount[0], 3)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 4)
+        self.assertEquals(callcount2[0], 3)
+
+    @defer.inlineCallbacks
+    def test_double_get(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+        yield a.func2("foo")
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 2)
+
+        a.func.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 3)
+
+
class CachedListDescriptorTestCase(unittest.TestCase):
    @defer.inlineCallbacks
    def test_cache(self):