diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 5588c9e697..a04731ae11 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached, cachedInlineCallbacks +from ._base import SQLBaseStore, cached, cachedInlineCallbacks, cachedList from twisted.internet import defer -from synapse.util import unwrapFirstError from synapse.util.stringutils import random_string import logging @@ -50,32 +49,20 @@ class StateStore(SQLBaseStore): The return value is a dict mapping group names to lists of events. """ + if not event_ids: + defer.returnValue({}) - event_and_groups = yield defer.gatherResults( - [ - self._get_state_group_for_event( - room_id, event_id, - ).addCallback(lambda group, event_id: (event_id, group), event_id) - for event_id in event_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, + ) - groups = set(group for _, group in event_and_groups if group) + groups = set(event_to_groups.values()) - group_to_state = yield defer.gatherResults( - [ - self._get_state_for_group( - group, - ).addCallback(lambda state_dict, group: (group, state_dict), group) - for group in groups - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + group_to_state = yield self._get_state_for_groups(groups) defer.returnValue({ group: state_map.values() - for group, state_map in group_to_state + for group, state_map in group_to_state.items() }) @cached(num_args=1) @@ -212,17 +199,48 @@ class StateStore(SQLBaseStore): txn.execute(sql, args) - return group, [ - r[0] - for r in txn.fetchall() - ] + return [r[0] for r in txn.fetchall()] return self.runInteraction( "_get_state_groups_from_group", f, ) - @cached(num_args=3, lru=True, max_entries=20000) + def _get_state_groups_from_groups(self, groups_and_types): + def f(txn): + results = {} + for group, types in groups_and_types: + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), + ) + else: + where_clause = "" + + sql = ( + "SELECT event_id FROM state_groups_state WHERE" + " state_group = ? %s" + ) % (where_clause,) + + args = [group] + if types is not None: + args.extend([i for typ in types for i in typ]) + + txn.execute(sql, args) + + results[group] = [ + r[0] + for r in txn.fetchall() + ] + + return results + + return self.runInteraction( + "_get_state_groups_from_groups", + f, + ) + + @cached(num_args=3, lru=True, max_entries=10000) def _get_state_for_event_id(self, room_id, event_id, types): def f(txn): type_and_state_sql = " OR ".join([ @@ -274,33 +292,19 @@ class StateStore(SQLBaseStore): deferred: A list of dicts corresponding to the event_ids given. The dicts are mappings from (type, state_key) -> state_events """ - event_and_groups = yield defer.gatherResults( - [ - self._get_state_group_for_event( - room_id, event_id, - ).addCallback(lambda group, event_id: (event_id, group), event_id) - for event_id in event_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, + ) - groups = set(group for _, group in event_and_groups) + groups = set(event_to_groups.values()) - res = yield defer.gatherResults( - [ - self._get_state_for_group( - group, types - ).addCallback(lambda state_dict, group: (group, state_dict), group) - for group in groups - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - - group_to_state = dict(res) + group_to_state = yield self._get_state_for_groups( + groups, types + ) event_to_state = { event_id: group_to_state[group] - for event_id, group in event_and_groups + for event_id, group in event_to_groups.items() } defer.returnValue([ @@ -320,8 +324,29 @@ class StateStore(SQLBaseStore): desc="_get_state_group_for_event", ) - @defer.inlineCallbacks - def _get_state_for_group(self, group, types=None): + @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", num_args=2) + def _get_state_group_for_events(self, room_id, event_ids): + def f(txn): + results = {} + for event_id in event_ids: + results[event_id] = self._simple_select_one_onecol_txn( + txn, + table="event_to_state_groups", + keyvalues={ + "event_id": event_id, + }, + retcol="state_group", + allow_none=True, + ) + + return results + + return self.runInteraction( + "_get_state_group_for_events", + f, + ) + + def _get_state_for_group_from_cache(self, group, types=None): is_all, state_dict = self._state_group_cache.get(group) type_to_key = {} @@ -339,7 +364,7 @@ class StateStore(SQLBaseStore): missing_types.add((typ, state_key)) if is_all and types is None: - defer.returnValue(state_dict) + return state_dict, missing_types if is_all or (types is not None and not missing_types): sentinel = object() @@ -354,41 +379,81 @@ class StateStore(SQLBaseStore): return True return False - defer.returnValue({ + return { k: v for k, v in state_dict.items() if v and include(k[0], k[1]) - }) + }, missing_types + + return {}, missing_types + + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, types=None): + results = {} + missing_groups_and_types = [] + for group in groups: + state_dict, missing_types = self._get_state_for_group_from_cache( + group, types + ) + + if types is not None and not missing_types: + results[group] = { + key: value + for key, value in state_dict.items() + if value + } + else: + missing_groups_and_types.append(( + group, + missing_types if types else None + )) + + if not missing_groups_and_types: + defer.returnValue(results) # Okay, so we have some missing_types, lets fetch them. cache_seq_num = self._state_group_cache.sequence - _, state_ids = yield self._get_state_groups_from_group( - group, - frozenset(missing_types) if types else None + + group_state_dict = yield self._get_state_groups_from_groups( + missing_groups_and_types ) - state_events = yield self._get_events(state_ids, get_prev_content=False) - state_dict = { - key: None - for key in missing_types - } - state_dict.update({ - (e.type, e.state_key): e + + state_events = yield self._get_events( + [e_id for l in group_state_dict.values() for e_id in l], + get_prev_content=False + ) + + state_events = { + e.event_id: e for e in state_events - }) + } - # Update the cache - self._state_group_cache.update( - cache_seq_num, - key=group, - value=state_dict, - full=(types is None), - ) + for group, state_ids in group_state_dict.items(): + state_dict = { + key: None + for key in missing_types + } + evs = [state_events[e_id] for e_id in state_ids] + state_dict.update({ + (e.type, e.state_key): e + for e in evs + }) - defer.returnValue({ - key: value - for key, value in state_dict.items() - if value - }) + # Update the cache + self._state_group_cache.update( + cache_seq_num, + key=group, + value=state_dict, + full=(types is None), + ) + + results[group] = { + key: value + for key, value in state_dict.items() + if value + } + + defer.returnValue(results) def _make_group_id(clock):