Split _get_state_for_group_from_cache into two

This commit is contained in:
Erik Johnston 2015-08-12 17:06:21 +01:00
parent 7b0e797080
commit df361d08f7
1 changed files with 53 additions and 32 deletions

View File

@ -287,42 +287,39 @@ class StateStore(SQLBaseStore):
f,
)
def _get_state_for_group_from_cache(self, group, types=None):
def _get_some_state_from_cache(self, group, types):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
`missing_types` is the list of types that aren't in the cache for that
group, or None if `types` is None. `got_all` is a bool indicating if
we successfully retrieved all requests state from the cache, if False
we need to query the DB for the missing state.
group. `got_all` is a bool indicating if we successfully retrieved all
requests state from the cache, if False we need to query the DB for the
missing state.
Args:
group: The state group to lookup
types (list): List of 2-tuples of the form (`type`, `state_key`),
where a `state_key` of `None` matches all state_keys for the
`type`.
"""
is_all, state_dict = self._state_group_cache.get(group)
type_to_key = {}
missing_types = set()
if types is not None:
for typ, state_key in types:
if state_key is None:
type_to_key[typ] = None
for typ, state_key in types:
if state_key is None:
type_to_key[typ] = None
missing_types.add((typ, state_key))
else:
if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key)
if (typ, state_key) not in state_dict:
missing_types.add((typ, state_key))
else:
if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key)
if (typ, state_key) not in state_dict:
missing_types.add((typ, state_key))
if is_all:
missing_types = set()
if types is None:
return state_dict, set(), True
sentinel = object()
def include(typ, state_key):
if types is None:
return True
valid_state_keys = type_to_key.get(typ, sentinel)
if valid_state_keys is sentinel:
return False
@ -340,6 +337,19 @@ class StateStore(SQLBaseStore):
if include(k[0], k[1])
}, missing_types, got_all
def _get_all_state_from_cache(self, group):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
indicating if we successfully retrieved all requests state from the
cache, if False we need to query the DB for the missing state.
Args:
group: The state group to lookup
"""
is_all, state_dict = self._state_group_cache.get(group)
return state_dict, is_all
@defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None):
"""Given list of groups returns dict of group -> list of state events
@ -349,18 +359,29 @@ class StateStore(SQLBaseStore):
"""
results = {}
missing_groups_and_types = []
for group in set(groups):
state_dict, missing_types, got_all = self._get_state_for_group_from_cache(
group, types
)
if types is not None:
for group in set(groups):
state_dict, missing_types, got_all = self._get_some_state_from_cache(
group, types
)
results[group] = state_dict
results[group] = state_dict
if not got_all:
missing_groups_and_types.append((
group,
missing_types if types else None
))
if not got_all:
missing_groups_and_types.append((
group,
missing_types
))
else:
for group in set(groups):
state_dict, got_all = self._get_all_state_from_cache(
group
)
results[group] = state_dict
if not got_all:
missing_groups_and_types.append((group, None))
if not missing_groups_and_types:
defer.returnValue({