Type annotations for LruCache (#8562)
* type annotations for LruCache * changelog * Apply suggestions from code review Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com> * review comments Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
commit
d6094176d1
|
@ -0,0 +1 @@
|
|||
Add type annotations for `LruCache`.
|
|
@ -69,7 +69,9 @@ class Auth:
|
|||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
self.token_cache = LruCache(10000, "token_cache")
|
||||
self.token_cache = LruCache(
|
||||
10000, "token_cache"
|
||||
) # type: LruCache[str, Tuple[str, bool]]
|
||||
|
||||
self._auth_blocking = AuthBlocking(self.hs)
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Pattern, Union
|
||||
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
|
||||
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import UserID
|
||||
|
@ -173,19 +173,21 @@ class PushRuleEvaluatorForEvent:
|
|||
# Similar to _glob_matches, but do not treat display_name as a glob.
|
||||
r = regex_cache.get((display_name, False, True), None)
|
||||
if not r:
|
||||
r = re.escape(display_name)
|
||||
r = _re_word_boundary(r)
|
||||
r = re.compile(r, flags=re.IGNORECASE)
|
||||
r1 = re.escape(display_name)
|
||||
r1 = _re_word_boundary(r1)
|
||||
r = re.compile(r1, flags=re.IGNORECASE)
|
||||
regex_cache[(display_name, False, True)] = r
|
||||
|
||||
return r.search(body)
|
||||
return bool(r.search(body))
|
||||
|
||||
def _get_value(self, dotted_key: str) -> Optional[str]:
|
||||
return self._value_cache.get(dotted_key, None)
|
||||
|
||||
|
||||
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
|
||||
regex_cache = LruCache(50000, "regex_push_cache")
|
||||
regex_cache = LruCache(
|
||||
50000, "regex_push_cache"
|
||||
) # type: LruCache[Tuple[str, bool, bool], Pattern]
|
||||
|
||||
|
||||
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
|
||||
|
@ -203,7 +205,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
|
|||
if not r:
|
||||
r = _glob_to_re(glob, word_boundary)
|
||||
regex_cache[(glob, True, word_boundary)] = r
|
||||
return r.search(value)
|
||||
return bool(r.search(value))
|
||||
except re.error:
|
||||
logger.warning("Failed to parse glob to regex: %r", glob)
|
||||
return False
|
||||
|
|
|
@ -98,7 +98,7 @@ class DeferredCache(Generic[KT, VT]):
|
|||
size_callback=(lambda d: len(d)) if iterable else None,
|
||||
metrics_collection_callback=metrics_cb,
|
||||
apply_cache_factor_from_config=apply_cache_factor_from_config,
|
||||
)
|
||||
) # type: LruCache[KT, VT]
|
||||
|
||||
self.thread = None # type: Optional[threading.Thread]
|
||||
|
||||
|
@ -240,11 +240,12 @@ class DeferredCache(Generic[KT, VT]):
|
|||
self.check_thread()
|
||||
if not isinstance(key, tuple):
|
||||
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
|
||||
key = cast(KT, key)
|
||||
self.cache.del_multi(key)
|
||||
|
||||
# if we have a pending lookup for this key, remove it from the
|
||||
# _pending_deferred_cache, as above
|
||||
entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
|
||||
entry_dict = self._pending_deferred_cache.pop(key, None)
|
||||
if entry_dict is not None:
|
||||
for entry in iterate_tree_cache_entry(entry_dict):
|
||||
entry.invalidate()
|
||||
|
|
|
@ -12,10 +12,11 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import threading
|
||||
from collections import namedtuple
|
||||
from typing import Any
|
||||
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
||||
|
@ -38,23 +39,26 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
|
|||
return len(self.value)
|
||||
|
||||
|
||||
class _Sentinel(enum.Enum):
|
||||
# defining a sentinel in this way allows mypy to correctly handle the
|
||||
# type of a dictionary lookup.
|
||||
sentinel = object()
|
||||
|
||||
|
||||
class DictionaryCache:
|
||||
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
|
||||
fetching a subset of dictionary keys for a particular key.
|
||||
"""
|
||||
|
||||
def __init__(self, name, max_entries=1000):
|
||||
self.cache = LruCache(max_size=max_entries, cache_name=name, size_callback=len)
|
||||
self.cache = LruCache(
|
||||
max_size=max_entries, cache_name=name, size_callback=len
|
||||
) # type: LruCache[Any, DictionaryEntry]
|
||||
|
||||
self.name = name
|
||||
self.sequence = 0
|
||||
self.thread = None
|
||||
|
||||
class Sentinel:
|
||||
__slots__ = []
|
||||
|
||||
self.sentinel = Sentinel()
|
||||
|
||||
def check_thread(self):
|
||||
expected_thread = self.thread
|
||||
if expected_thread is None:
|
||||
|
@ -76,8 +80,8 @@ class DictionaryCache:
|
|||
Returns:
|
||||
DictionaryEntry
|
||||
"""
|
||||
entry = self.cache.get(key, self.sentinel)
|
||||
if entry is not self.sentinel:
|
||||
entry = self.cache.get(key, _Sentinel.sentinel)
|
||||
if entry is not _Sentinel.sentinel:
|
||||
if dict_keys is None:
|
||||
return DictionaryEntry(
|
||||
entry.full, entry.known_absent, dict(entry.value)
|
||||
|
|
|
@ -15,12 +15,35 @@
|
|||
|
||||
import threading
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional, Type, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Iterable,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from synapse.config import cache as cache_config
|
||||
from synapse.util.caches import CacheMetric, register_cache
|
||||
from synapse.util.caches.treecache import TreeCache
|
||||
|
||||
# Function type: the type used for invalidation callbacks
|
||||
FT = TypeVar("FT", bound=Callable[..., Any])
|
||||
|
||||
# Key and Value type for the cache
|
||||
KT = TypeVar("KT")
|
||||
VT = TypeVar("VT")
|
||||
|
||||
# a general type var, distinct from either KT or VT
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def enumerate_leaves(node, depth):
|
||||
if depth == 0:
|
||||
|
@ -42,7 +65,7 @@ class _Node:
|
|||
self.callbacks = callbacks
|
||||
|
||||
|
||||
class LruCache:
|
||||
class LruCache(Generic[KT, VT]):
|
||||
"""
|
||||
Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
|
||||
|
||||
|
@ -128,13 +151,13 @@ class LruCache:
|
|||
if metrics:
|
||||
metrics.inc_evictions(evicted_len)
|
||||
|
||||
def synchronized(f):
|
||||
def synchronized(f: FT) -> FT:
|
||||
@wraps(f)
|
||||
def inner(*args, **kwargs):
|
||||
with lock:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
return cast(FT, inner)
|
||||
|
||||
cached_cache_len = [0]
|
||||
if size_callback is not None:
|
||||
|
@ -188,8 +211,31 @@ class LruCache:
|
|||
node.callbacks.clear()
|
||||
return deleted_len
|
||||
|
||||
@overload
|
||||
def cache_get(
|
||||
key: KT,
|
||||
default: Literal[None] = None,
|
||||
callbacks: Iterable[Callable[[], None]] = ...,
|
||||
update_metrics: bool = ...,
|
||||
) -> Optional[VT]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def cache_get(
|
||||
key: KT,
|
||||
default: T,
|
||||
callbacks: Iterable[Callable[[], None]] = ...,
|
||||
update_metrics: bool = ...,
|
||||
) -> Union[T, VT]:
|
||||
...
|
||||
|
||||
@synchronized
|
||||
def cache_get(key, default=None, callbacks=[], update_metrics=True):
|
||||
def cache_get(
|
||||
key: KT,
|
||||
default: Optional[T] = None,
|
||||
callbacks: Iterable[Callable[[], None]] = [],
|
||||
update_metrics: bool = True,
|
||||
):
|
||||
node = cache.get(key, None)
|
||||
if node is not None:
|
||||
move_node_to_front(node)
|
||||
|
@ -203,7 +249,7 @@ class LruCache:
|
|||
return default
|
||||
|
||||
@synchronized
|
||||
def cache_set(key, value, callbacks=[]):
|
||||
def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
|
||||
node = cache.get(key, None)
|
||||
if node is not None:
|
||||
# We sometimes store large objects, e.g. dicts, which cause
|
||||
|
@ -232,7 +278,7 @@ class LruCache:
|
|||
evict()
|
||||
|
||||
@synchronized
|
||||
def cache_set_default(key, value):
|
||||
def cache_set_default(key: KT, value: VT) -> VT:
|
||||
node = cache.get(key, None)
|
||||
if node is not None:
|
||||
return node.value
|
||||
|
@ -241,8 +287,16 @@ class LruCache:
|
|||
evict()
|
||||
return value
|
||||
|
||||
@overload
|
||||
def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def cache_pop(key: KT, default: T) -> Union[T, VT]:
|
||||
...
|
||||
|
||||
@synchronized
|
||||
def cache_pop(key, default=None):
|
||||
def cache_pop(key: KT, default: Optional[T] = None):
|
||||
node = cache.get(key, None)
|
||||
if node:
|
||||
delete_node(node)
|
||||
|
@ -252,18 +306,18 @@ class LruCache:
|
|||
return default
|
||||
|
||||
@synchronized
|
||||
def cache_del_multi(key):
|
||||
def cache_del_multi(key: KT) -> None:
|
||||
"""
|
||||
This will only work if constructed with cache_type=TreeCache
|
||||
"""
|
||||
popped = cache.pop(key)
|
||||
if popped is None:
|
||||
return
|
||||
for leaf in enumerate_leaves(popped, keylen - len(key)):
|
||||
for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
|
||||
delete_node(leaf)
|
||||
|
||||
@synchronized
|
||||
def cache_clear():
|
||||
def cache_clear() -> None:
|
||||
list_root.next_node = list_root
|
||||
list_root.prev_node = list_root
|
||||
for node in cache.values():
|
||||
|
@ -274,7 +328,7 @@ class LruCache:
|
|||
cached_cache_len[0] = 0
|
||||
|
||||
@synchronized
|
||||
def cache_contains(key):
|
||||
def cache_contains(key: KT) -> bool:
|
||||
return key in cache
|
||||
|
||||
self.sentinel = object()
|
||||
|
|
Loading…
Reference in New Issue