WIP fast path state
This commit is contained in:
parent
01ccc9e6f2
commit
cfc2169b31
|
@ -676,3 +676,32 @@ def auth_types_for_event(event):
|
|||
auth_types.append(key)
|
||||
|
||||
return auth_types
|
||||
|
||||
|
||||
def filter_dependent_state(keys, state):
|
||||
if (EventTypes.Create, "") in keys:
|
||||
return state
|
||||
|
||||
if (EventTypes.PowerLevels, "") in keys:
|
||||
return state
|
||||
|
||||
def _filter_state(entry):
|
||||
etype, state_key, sender = entry
|
||||
|
||||
if (etype, state_key) in keys:
|
||||
return True
|
||||
|
||||
if (EventTypes.Member, sender) in keys:
|
||||
return True
|
||||
|
||||
if etype == EventTypes.Member:
|
||||
if (EventTypes.JoinRules, "") in keys:
|
||||
return True
|
||||
|
||||
for t, _ in keys:
|
||||
if t == EventTypes.ThirdPartyInvite:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return filter(_filter_state, state)
|
||||
|
|
117
synapse/state.py
117
synapse/state.py
|
@ -55,9 +55,9 @@ def _gen_state_id():
|
|||
|
||||
|
||||
class _StateCacheEntry(object):
|
||||
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
|
||||
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids", "conflicted_state"]
|
||||
|
||||
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
|
||||
def __init__(self, state, state_group, prev_group=None, delta_ids=None, conflicted_state=None):
|
||||
# dict[(str, str), str] map from (type, state_key) to event_id
|
||||
self.state = frozendict(state)
|
||||
|
||||
|
@ -80,6 +80,8 @@ class _StateCacheEntry(object):
|
|||
else:
|
||||
self.state_id = _gen_state_id()
|
||||
|
||||
self.conflicted_state = conflicted_state
|
||||
|
||||
def __len__(self):
|
||||
return len(self.state)
|
||||
|
||||
|
@ -375,7 +377,7 @@ class StateHandler(object):
|
|||
}
|
||||
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = resolve_events_with_state_map(state_set_ids, state_map)
|
||||
new_state, _, _ = resolve_events_with_state_map(state_set_ids, state_map)
|
||||
|
||||
new_state = {
|
||||
key: state_map[ev_id] for key, ev_id in new_state.iteritems()
|
||||
|
@ -462,18 +464,17 @@ class StateResolutionHandler(object):
|
|||
for key, e_id in st.iteritems():
|
||||
state.setdefault(key, set()).add(e_id)
|
||||
|
||||
# build a map from state key to the event_ids which set that state,
|
||||
# including only those where there are state keys in conflict.
|
||||
conflicted_state = {
|
||||
k: list(v)
|
||||
for k, v in state.iteritems()
|
||||
if len(v) > 1
|
||||
}
|
||||
# Check that there is a conflict between the state groups
|
||||
conflicted_state = False
|
||||
for values in state.itervalues():
|
||||
if len(values):
|
||||
conflicted_state = True
|
||||
break
|
||||
|
||||
if conflicted_state:
|
||||
logger.info("Resolving conflicted state for %r", room_id)
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = yield resolve_events_with_factory(
|
||||
new_state, _, conflicted_state_keys = yield resolve_events_with_factory(
|
||||
state_groups_ids.values(),
|
||||
event_map=event_map,
|
||||
state_map_factory=state_map_factory,
|
||||
|
@ -482,6 +483,7 @@ class StateResolutionHandler(object):
|
|||
new_state = {
|
||||
key: e_ids.pop() for key, e_ids in state.iteritems()
|
||||
}
|
||||
conflicted_state_keys = []
|
||||
|
||||
with Measure(self.clock, "state.create_group_ids"):
|
||||
# if the new state matches any of the input state groups, we can
|
||||
|
@ -517,6 +519,7 @@ class StateResolutionHandler(object):
|
|||
state_group=state_group,
|
||||
prev_group=prev_group,
|
||||
delta_ids=delta_ids,
|
||||
conflicted_state=conflicted_state_keys,
|
||||
)
|
||||
|
||||
if self._state_cache is not None:
|
||||
|
@ -524,6 +527,82 @@ class StateResolutionHandler(object):
|
|||
|
||||
defer.returnValue(cache)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_delta_state(self, unchanged_state_groups, changed_groups, conflicted_state, store,
|
||||
event_map, state_map_factory):
|
||||
deltas = set(entry for _, _, group_deltas in changed_groups for entry in group_deltas)
|
||||
|
||||
to_recalculate = event_auth.filter_dependent_state(deltas, conflicted_state)
|
||||
|
||||
new_groups = list(unchanged_state_groups)
|
||||
new_groups.extend(g for _, g, _ in changed_groups)
|
||||
|
||||
types = [(etype, state_key) for etype, state_key, _ in to_recalculate]
|
||||
state_sets = yield store._get_state_groups_from_groups(new_groups, types=types)
|
||||
|
||||
state_sets = {
|
||||
sg: {
|
||||
key: cs[key]
|
||||
for key in types
|
||||
}
|
||||
for sg, cs in state_sets.iteritems()
|
||||
}
|
||||
|
||||
logger.info("Recalculating: %s", to_recalculate)
|
||||
|
||||
group_names = frozenset(new_groups)
|
||||
with (yield self.resolve_linearizer.queue(group_names)):
|
||||
cache = None
|
||||
if self._state_cache is not None:
|
||||
cache = self._state_cache.get(group_names, None)
|
||||
|
||||
if cache:
|
||||
new_state = {
|
||||
(etype, state_key): cache.state[(etype, state_key)]
|
||||
for etype, state_key, _ in to_recalculate
|
||||
}
|
||||
unconflicted_state, _ = _seperate(
|
||||
state_sets.values(),
|
||||
)
|
||||
else:
|
||||
needed_events = set(
|
||||
event_id
|
||||
for state in state_sets.itervalues()
|
||||
for event_id in state.itervalues()
|
||||
)
|
||||
if event_map is not None:
|
||||
needed_events -= set(event_map.iterkeys())
|
||||
|
||||
# logger.info("state_sets: %s", state_sets)
|
||||
logger.info("Asking for %d conflicted events", len(needed_events))
|
||||
|
||||
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
|
||||
# the state events which are in conflict (and those in event_map)
|
||||
state_map = yield state_map_factory(needed_events)
|
||||
if event_map is not None:
|
||||
state_map.update(event_map)
|
||||
|
||||
_, cs = _seperate(state_sets.values())
|
||||
|
||||
needed_state = _get_auth_event_keys(cs, state_map)
|
||||
sg = new_groups[0]
|
||||
res = yield store._get_state_for_groups((sg,), types=needed_state)
|
||||
state = res[sg]
|
||||
for s in state_sets.itervalues():
|
||||
s.update(state)
|
||||
|
||||
logger.info("Added state: %s", len(needed_state))
|
||||
|
||||
new_state, unconflicted_state, _ = yield resolve_events_with_factory(
|
||||
state_sets.values(),
|
||||
event_map=state_map,
|
||||
state_map_factory=state_map_factory,
|
||||
)
|
||||
|
||||
conflicted_state = [e for e in conflicted_state if (e[0], e[1]) not in unconflicted_state]
|
||||
|
||||
defer.returnValue((new_state, conflicted_state))
|
||||
|
||||
|
||||
def _ordered_events(events):
|
||||
def key_func(e):
|
||||
|
@ -677,6 +756,15 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
|
|||
))
|
||||
|
||||
|
||||
def _get_auth_event_keys(conflicted_state, state_map):
|
||||
auth_events = set()
|
||||
for event_ids in conflicted_state.itervalues():
|
||||
for event_id in event_ids:
|
||||
if event_id in state_map:
|
||||
auth_events.update(event_auth.auth_types_for_event(state_map[event_id]))
|
||||
return auth_events
|
||||
|
||||
|
||||
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
|
||||
auth_events = {}
|
||||
for event_ids in conflicted_state.itervalues():
|
||||
|
@ -719,7 +807,12 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
|
|||
for key, event in resolved_state.iteritems():
|
||||
new_state[key] = event.event_id
|
||||
|
||||
return new_state
|
||||
return new_state, unconflicted_state_ids, [
|
||||
(key[0], key[1], event.sender)
|
||||
for key, evs in conflicted_state.iteritems()
|
||||
if len(evs) > 1
|
||||
for event in evs
|
||||
]
|
||||
|
||||
|
||||
def _resolve_state_events(conflicted_state, auth_events):
|
||||
|
|
Loading…
Reference in New Issue