diff --git a/synapse/state.py b/synapse/state.py index 2ffdb0b01e..1bf4a0df6f 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -531,24 +531,33 @@ class StateResolutionHandler(object): def resolve_delta_state(self, unchanged_state_groups, changed_groups, conflicted_state, store, event_map, state_map_factory): deltas = set(entry for _, _, group_deltas in changed_groups for entry in group_deltas) + conflicted_state_keys = set((etype, state_key) for etype, state_key, _ in conflicted_state) + conflicted_state_keys.update(deltas) - to_recalculate = event_auth.filter_dependent_state(deltas, conflicted_state) + if not deltas: + logger.info("No deltas") + defer.returnValue({}, conflicted_state) + + to_recalculate = set(event_auth.filter_dependent_state(deltas, conflicted_state)) + to_recalculate = set((etype, state_key) for etype, state_key, _ in to_recalculate) + to_recalculate.update(deltas) new_groups = list(unchanged_state_groups) new_groups.extend(g for _, g, _ in changed_groups) - types = [(etype, state_key) for etype, state_key, _ in to_recalculate] - state_sets = yield store._get_state_groups_from_groups(new_groups, types=types) + state_sets = yield store._get_state_groups_from_groups(new_groups, types=to_recalculate) state_sets = { sg: { key: cs[key] - for key in types + for key in to_recalculate + if key in cs } for sg, cs in state_sets.iteritems() } logger.info("Recalculating: %s", to_recalculate) + logger.info("State Sets: %s", state_sets) group_names = frozenset(new_groups) with (yield self.resolve_linearizer.queue(group_names)): @@ -557,9 +566,10 @@ class StateResolutionHandler(object): cache = self._state_cache.get(group_names, None) if cache: + logger.info("Using cache") new_state = { - (etype, state_key): cache.state[(etype, state_key)] - for etype, state_key, _ in to_recalculate + key: cache.state[key] + for key in to_recalculate } unconflicted_state, _ = _seperate( state_sets.values(), @@ -585,13 +595,23 @@ class StateResolutionHandler(object): _, cs = _seperate(state_sets.values()) needed_state = _get_auth_event_keys(cs, state_map) - sg = new_groups[0] - res = yield store._get_state_for_groups((sg,), types=needed_state) - state = res[sg] - for s in state_sets.itervalues(): - s.update(state) - logger.info("Added state: %s", len(needed_state)) + if needed_state - conflicted_state_keys: + res = yield store._get_state_for_groups((new_groups[0],), types=(needed_state - conflicted_state_keys)) + state = res[new_groups[0]] + for s in state_sets.itervalues(): + s.update(state) + + logger.info("Added unconflicted state: %s", state) + + needed_conflicted_state = needed_state & conflicted_state_keys + if needed_conflicted_state: + for sg, s in state_sets.iteritems(): + res = yield store._get_state_for_groups((sg,), types=needed_conflicted_state) + state = res[sg] + s.update(state) + + logger.info("Added conflicted state: %s", needed_conflicted_state) new_state, unconflicted_state, _ = yield resolve_events_with_factory( state_sets.values(), @@ -601,6 +621,8 @@ class StateResolutionHandler(object): conflicted_state = [e for e in conflicted_state if (e[0], e[1]) not in unconflicted_state] + logger.info("Returning: %s", new_state) + defer.returnValue((new_state, conflicted_state))