synapse-old/synapse/util/caches/treecache.py

114 lines
2.7 KiB
Python

from typing import Dict
SENTINEL = object()
class TreeCache:
"""
Tree-based backing store for LruCache. Allows subtrees of data to be deleted
efficiently.
Keys must be tuples.
"""
def __init__(self):
self.size = 0
self.root = {} # type: Dict
def __setitem__(self, key, value):
return self.set(key, value)
def __contains__(self, key):
return self.get(key, SENTINEL) is not SENTINEL
def set(self, key, value):
node = self.root
for k in key[:-1]:
node = node.setdefault(k, {})
node[key[-1]] = _Entry(value)
self.size += 1
def get(self, key, default=None):
node = self.root
for k in key[:-1]:
node = node.get(k, None)
if node is None:
return default
return node.get(key[-1], _Entry(default)).value
def clear(self):
self.size = 0
self.root = {}
def pop(self, key, default=None):
nodes = []
node = self.root
for k in key[:-1]:
node = node.get(k, None)
nodes.append(node) # don't add the root node
if node is None:
return default
popped = node.pop(key[-1], SENTINEL)
if popped is SENTINEL:
return default
node_and_keys = list(zip(nodes, key))
node_and_keys.reverse()
node_and_keys.append((self.root, None))
for i in range(len(node_and_keys) - 1):
n, k = node_and_keys[i]
if n:
break
node_and_keys[i + 1][0].pop(k)
popped, cnt = _strip_and_count_entires(popped)
self.size -= cnt
return popped
def values(self):
return list(iterate_tree_cache_entry(self.root))
def __len__(self):
return self.size
def iterate_tree_cache_entry(d):
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
can contain dicts.
"""
if isinstance(d, dict):
for value_d in d.values():
for value in iterate_tree_cache_entry(value_d):
yield value
else:
if isinstance(d, _Entry):
yield d.value
else:
yield d
class _Entry:
__slots__ = ["value"]
def __init__(self, value):
self.value = value
def _strip_and_count_entires(d):
"""Takes an _Entry or dict with leaves of _Entry's, and either returns the
value or a dictionary with _Entry's replaced by their values.
Also returns the count of _Entry's
"""
if isinstance(d, dict):
cnt = 0
for key, value in d.items():
v, n = _strip_and_count_entires(value)
d[key] = v
cnt += n
return d, cnt
else:
return d.value, 1