Add some more type annotations to Cache
This commit is contained in:
parent
629a951b49
commit
7eff59ec91
|
@ -26,7 +26,7 @@ class SlavedClientIpStore(BaseSlavedStore):
|
|||
|
||||
self.client_ip_last_seen = Cache(
|
||||
name="client_ip_last_seen", keylen=4, max_entries=50000
|
||||
)
|
||||
) # type: Cache[tuple, int]
|
||||
|
||||
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
|
||||
now = int(self._clock.time_msec())
|
||||
|
|
|
@ -13,12 +13,23 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Iterable,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from prometheus_client import Gauge
|
||||
|
@ -38,6 +49,8 @@ logger = logging.getLogger(__name__)
|
|||
CacheKey = Union[Tuple, Any]
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
KT = TypeVar("KT")
|
||||
VT = TypeVar("VT")
|
||||
|
||||
|
||||
class _CachedFunction(Generic[F]):
|
||||
|
@ -61,13 +74,19 @@ cache_pending_metric = Gauge(
|
|||
["name"],
|
||||
)
|
||||
|
||||
_CacheSentinel = object()
|
||||
|
||||
class _Sentinel(enum.Enum):
|
||||
# defining a sentinel in this way allows mypy to correctly handle the
|
||||
# type of a dictionary lookup.
|
||||
sentinel = object()
|
||||
|
||||
|
||||
class CacheEntry:
|
||||
__slots__ = ["deferred", "callbacks", "invalidated"]
|
||||
|
||||
def __init__(self, deferred, callbacks):
|
||||
def __init__(
|
||||
self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
|
||||
):
|
||||
self.deferred = deferred
|
||||
self.callbacks = set(callbacks)
|
||||
self.invalidated = False
|
||||
|
@ -80,7 +99,13 @@ class CacheEntry:
|
|||
self.callbacks.clear()
|
||||
|
||||
|
||||
class Cache:
|
||||
class Cache(Generic[KT, VT]):
|
||||
"""Wraps an LruCache, adding support for Deferred results.
|
||||
|
||||
It expects that each entry added with set() will be a Deferred; likewise get()
|
||||
may return an ObservableDeferred.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"cache",
|
||||
"name",
|
||||
|
@ -103,19 +128,23 @@ class Cache:
|
|||
Args:
|
||||
name: The name of the cache
|
||||
max_entries: Maximum amount of entries that the cache will hold
|
||||
keylen: The length of the tuple used as the cache key
|
||||
keylen: The length of the tuple used as the cache key. Ignored unless
|
||||
`tree` is True.
|
||||
tree: Use a TreeCache instead of a dict as the underlying cache type
|
||||
iterable: If True, count each item in the cached object as an entry,
|
||||
rather than each cached object
|
||||
apply_cache_factor_from_config: Whether cache factors specified in the
|
||||
config file affect `max_entries`
|
||||
|
||||
Returns:
|
||||
Cache
|
||||
"""
|
||||
cache_type = TreeCache if tree else dict
|
||||
self._pending_deferred_cache = cache_type()
|
||||
|
||||
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
|
||||
self._pending_deferred_cache = (
|
||||
cache_type()
|
||||
) # type: MutableMapping[KT, CacheEntry]
|
||||
|
||||
# cache is used for completed results and maps to the result itself, rather than
|
||||
# a Deferred.
|
||||
self.cache = LruCache(
|
||||
max_size=max_entries,
|
||||
keylen=keylen,
|
||||
|
@ -155,7 +184,13 @@ class Cache:
|
|||
"Cache objects can only be accessed from the main thread"
|
||||
)
|
||||
|
||||
def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
|
||||
def get(
|
||||
self,
|
||||
key: KT,
|
||||
default=_Sentinel.sentinel,
|
||||
callback: Optional[Callable[[], None]] = None,
|
||||
update_metrics: bool = True,
|
||||
):
|
||||
"""Looks the key up in the caches.
|
||||
|
||||
Args:
|
||||
|
@ -166,30 +201,32 @@ class Cache:
|
|||
update_metrics (bool): whether to update the cache hit rate metrics
|
||||
|
||||
Returns:
|
||||
Either an ObservableDeferred or the raw result
|
||||
Either an ObservableDeferred or the result itself
|
||||
"""
|
||||
callbacks = [callback] if callback else []
|
||||
val = self._pending_deferred_cache.get(key, _CacheSentinel)
|
||||
if val is not _CacheSentinel:
|
||||
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
|
||||
if val is not _Sentinel.sentinel:
|
||||
val.callbacks.update(callbacks)
|
||||
if update_metrics:
|
||||
self.metrics.inc_hits()
|
||||
return val.deferred
|
||||
|
||||
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
|
||||
if val is not _CacheSentinel:
|
||||
val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
|
||||
if val is not _Sentinel.sentinel:
|
||||
self.metrics.inc_hits()
|
||||
return val
|
||||
|
||||
if update_metrics:
|
||||
self.metrics.inc_misses()
|
||||
|
||||
if default is _CacheSentinel:
|
||||
if default is _Sentinel.sentinel:
|
||||
raise KeyError()
|
||||
else:
|
||||
return default
|
||||
|
||||
def set(self, key, value, callback=None):
|
||||
def set(
|
||||
self, key: KT, value: defer.Deferred, callback: Optional[Callable[[], None]] = None
|
||||
) -> ObservableDeferred:
|
||||
if not isinstance(value, defer.Deferred):
|
||||
raise TypeError("not a Deferred")
|
||||
|
||||
|
@ -248,7 +285,7 @@ class Cache:
|
|||
observer.addCallbacks(cb, eb)
|
||||
return observable
|
||||
|
||||
def prefill(self, key, value, callback=None):
|
||||
def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
|
||||
callbacks = [callback] if callback else []
|
||||
self.cache.set(key, value, callbacks=callbacks)
|
||||
|
||||
|
@ -267,7 +304,7 @@ class Cache:
|
|||
if entry:
|
||||
entry.invalidate()
|
||||
|
||||
def invalidate_many(self, key):
|
||||
def invalidate_many(self, key: KT):
|
||||
self.check_thread()
|
||||
if not isinstance(key, tuple):
|
||||
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
|
||||
|
@ -275,7 +312,7 @@ class Cache:
|
|||
|
||||
# if we have a pending lookup for this key, remove it from the
|
||||
# _pending_deferred_cache, as above
|
||||
entry_dict = self._pending_deferred_cache.pop(key, None)
|
||||
entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
|
||||
if entry_dict is not None:
|
||||
for entry in iterate_tree_cache_entry(entry_dict):
|
||||
entry.invalidate()
|
||||
|
@ -396,7 +433,7 @@ class CacheDescriptor(_CacheDescriptorBase):
|
|||
keylen=self.num_args,
|
||||
tree=self.tree,
|
||||
iterable=self.iterable,
|
||||
)
|
||||
) # type: Cache[Tuple, Any]
|
||||
|
||||
def get_cache_key_gen(args, kwargs):
|
||||
"""Given some args/kwargs return a generator that resolves into
|
||||
|
|
|
@ -64,7 +64,8 @@ class LruCache:
|
|||
Args:
|
||||
max_size: The maximum amount of entries the cache can hold
|
||||
|
||||
keylen: The length of the tuple used as the cache key
|
||||
keylen: The length of the tuple used as the cache key. Ignored unless
|
||||
cache_type is `TreeCache`.
|
||||
|
||||
cache_type (type):
|
||||
type of underlying cache to be used. Typically one of dict
|
||||
|
|
Loading…
Reference in New Issue