WIP fast path state

This commit is contained in:
Erik Johnston 2018-03-28 16:19:26 +01:00
parent 01ccc9e6f2
commit cfc2169b31
2 changed files with 134 additions and 12 deletions

View File

@ -676,3 +676,32 @@ def auth_types_for_event(event):
auth_types.append(key)
return auth_types
def filter_dependent_state(keys, state):
if (EventTypes.Create, "") in keys:
return state
if (EventTypes.PowerLevels, "") in keys:
return state
def _filter_state(entry):
etype, state_key, sender = entry
if (etype, state_key) in keys:
return True
if (EventTypes.Member, sender) in keys:
return True
if etype == EventTypes.Member:
if (EventTypes.JoinRules, "") in keys:
return True
for t, _ in keys:
if t == EventTypes.ThirdPartyInvite:
return True
return False
return filter(_filter_state, state)

View File

@ -55,9 +55,9 @@ def _gen_state_id():
class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids", "conflicted_state"]
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
def __init__(self, state, state_group, prev_group=None, delta_ids=None, conflicted_state=None):
# dict[(str, str), str] map from (type, state_key) to event_id
self.state = frozendict(state)
@ -80,6 +80,8 @@ class _StateCacheEntry(object):
else:
self.state_id = _gen_state_id()
self.conflicted_state = conflicted_state
def __len__(self):
return len(self.state)
@ -375,7 +377,7 @@ class StateHandler(object):
}
with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state, _, _ = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.iteritems()
@ -462,18 +464,17 @@ class StateResolutionHandler(object):
for key, e_id in st.iteritems():
state.setdefault(key, set()).add(e_id)
# build a map from state key to the event_ids which set that state,
# including only those where there are state keys in conflict.
conflicted_state = {
k: list(v)
for k, v in state.iteritems()
if len(v) > 1
}
# Check that there is a conflict between the state groups
conflicted_state = False
for values in state.itervalues():
if len(values):
conflicted_state = True
break
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory(
new_state, _, conflicted_state_keys = yield resolve_events_with_factory(
state_groups_ids.values(),
event_map=event_map,
state_map_factory=state_map_factory,
@ -482,6 +483,7 @@ class StateResolutionHandler(object):
new_state = {
key: e_ids.pop() for key, e_ids in state.iteritems()
}
conflicted_state_keys = []
with Measure(self.clock, "state.create_group_ids"):
# if the new state matches any of the input state groups, we can
@ -517,6 +519,7 @@ class StateResolutionHandler(object):
state_group=state_group,
prev_group=prev_group,
delta_ids=delta_ids,
conflicted_state=conflicted_state_keys,
)
if self._state_cache is not None:
@ -524,6 +527,82 @@ class StateResolutionHandler(object):
defer.returnValue(cache)
@defer.inlineCallbacks
def resolve_delta_state(self, unchanged_state_groups, changed_groups, conflicted_state, store,
event_map, state_map_factory):
deltas = set(entry for _, _, group_deltas in changed_groups for entry in group_deltas)
to_recalculate = event_auth.filter_dependent_state(deltas, conflicted_state)
new_groups = list(unchanged_state_groups)
new_groups.extend(g for _, g, _ in changed_groups)
types = [(etype, state_key) for etype, state_key, _ in to_recalculate]
state_sets = yield store._get_state_groups_from_groups(new_groups, types=types)
state_sets = {
sg: {
key: cs[key]
for key in types
}
for sg, cs in state_sets.iteritems()
}
logger.info("Recalculating: %s", to_recalculate)
group_names = frozenset(new_groups)
with (yield self.resolve_linearizer.queue(group_names)):
cache = None
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
new_state = {
(etype, state_key): cache.state[(etype, state_key)]
for etype, state_key, _ in to_recalculate
}
unconflicted_state, _ = _seperate(
state_sets.values(),
)
else:
needed_events = set(
event_id
for state in state_sets.itervalues()
for event_id in state.itervalues()
)
if event_map is not None:
needed_events -= set(event_map.iterkeys())
# logger.info("state_sets: %s", state_sets)
logger.info("Asking for %d conflicted events", len(needed_events))
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict (and those in event_map)
state_map = yield state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
_, cs = _seperate(state_sets.values())
needed_state = _get_auth_event_keys(cs, state_map)
sg = new_groups[0]
res = yield store._get_state_for_groups((sg,), types=needed_state)
state = res[sg]
for s in state_sets.itervalues():
s.update(state)
logger.info("Added state: %s", len(needed_state))
new_state, unconflicted_state, _ = yield resolve_events_with_factory(
state_sets.values(),
event_map=state_map,
state_map_factory=state_map_factory,
)
conflicted_state = [e for e in conflicted_state if (e[0], e[1]) not in unconflicted_state]
defer.returnValue((new_state, conflicted_state))
def _ordered_events(events):
def key_func(e):
@ -677,6 +756,15 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
))
def _get_auth_event_keys(conflicted_state, state_map):
auth_events = set()
for event_ids in conflicted_state.itervalues():
for event_id in event_ids:
if event_id in state_map:
auth_events.update(event_auth.auth_types_for_event(state_map[event_id]))
return auth_events
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
for event_ids in conflicted_state.itervalues():
@ -719,7 +807,12 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
for key, event in resolved_state.iteritems():
new_state[key] = event.event_id
return new_state
return new_state, unconflicted_state_ids, [
(key[0], key[1], event.sender)
for key, evs in conflicted_state.iteritems()
if len(evs) > 1
for event in evs
]
def _resolve_state_events(conflicted_state, auth_events):