Fix a few race conditions in the state calculation

Be a bit more careful about how we calculate the state to be returned by
/sync. In a few places, it was possible for /sync to return slightly later
state than that represented by the next_batch token and the timeline. In
particular, the following cases were susceptible:

* On a full state sync, for an active room
* During a per-room incremental sync with a timeline gap
* When the user has just joined a room. (Refactor check_joined_room to make it
  less magical)

Also, use store.get_state_for_events() (and thus the existing stategroups) to
calculate the state corresponding to a particular sync position, rather than
state_handler.compute_event_context(), which recalculates from first principles
(and tends to miss some state).

Merged from PR https://github.com/matrix-org/synapse/pull/372
This commit is contained in:
Richard van der Hoff 2015-11-10 18:27:23 +00:00
parent 5ab4b0afe8
commit fddedd51d9
2 changed files with 78 additions and 61 deletions

View File

@ -254,9 +254,7 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token=timeline_since_token room_id, sync_config, now_token, since_token=timeline_since_token
) )
current_state = yield self.state_handler.get_current_state( current_state = yield self.get_state_at(room_id, now_token)
room_id
)
defer.returnValue(JoinedSyncResult( defer.returnValue(JoinedSyncResult(
room_id=room_id, room_id=room_id,
@ -353,14 +351,12 @@ class SyncHandler(BaseHandler):
room_id, sync_config, leave_token, since_token=timeline_since_token room_id, sync_config, leave_token, since_token=timeline_since_token
) )
leave_state = yield self.store.get_state_for_events( leave_state = yield self.store.get_state_for_event(leave_event_id)
[leave_event_id], None
)
defer.returnValue(ArchivedSyncResult( defer.returnValue(ArchivedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=leave_state[leave_event_id], state=leave_state,
private_user_data=self.private_user_data_for_room( private_user_data=self.private_user_data_for_room(
room_id, tags_by_room room_id, tags_by_room
), ),
@ -424,6 +420,9 @@ class SyncHandler(BaseHandler):
if len(room_events) <= timeline_limit: if len(room_events) <= timeline_limit:
# There is no gap in any of the rooms. Therefore we can just # There is no gap in any of the rooms. Therefore we can just
# partition the new events by room and return them. # partition the new events by room and return them.
logger.debug("Got %i events for incremental sync - not limited",
len(room_events))
invite_events = [] invite_events = []
leave_events = [] leave_events = []
events_by_room_id = {} events_by_room_id = {}
@ -439,9 +438,11 @@ class SyncHandler(BaseHandler):
for room_id in joined_room_ids: for room_id in joined_room_ids:
recents = events_by_room_id.get(room_id, []) recents = events_by_room_id.get(room_id, [])
logger.debug("Events for room %s: %r", room_id, recents)
state = { state = {
(event.type, event.state_key): event (event.type, event.state_key): event
for event in recents if event.is_state()} for event in recents if event.is_state()}
limited = False
if recents: if recents:
prev_batch = now_token.copy_and_replace( prev_batch = now_token.copy_and_replace(
@ -450,9 +451,13 @@ class SyncHandler(BaseHandler):
else: else:
prev_batch = now_token prev_batch = now_token
state, limited = yield self.check_joined_room( just_joined = yield self.check_joined_room(sync_config, state)
sync_config, room_id, state if just_joined:
) logger.debug("User has just joined %s: needs full state",
room_id)
state = yield self.get_state_at(room_id, now_token)
# the timeline is inherently limited if we've just joined
limited = True
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
@ -467,10 +472,15 @@ class SyncHandler(BaseHandler):
room_id, tags_by_room room_id, tags_by_room
), ),
) )
logger.debug("Result for room %s: %r", room_id, room_sync)
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
else: else:
logger.debug("Got %i events for incremental sync - hit limit",
len(room_events))
invite_events = yield self.store.get_invites_for_user( invite_events = yield self.store.get_invites_for_user(
sync_config.user.to_string() sync_config.user.to_string()
) )
@ -563,6 +573,8 @@ class SyncHandler(BaseHandler):
Returns: Returns:
A Deferred JoinedSyncResult A Deferred JoinedSyncResult
""" """
logger.debug("Doing incremental sync for room %s between %s and %s",
room_id, since_token, now_token)
# TODO(mjark): Check for redactions we might have missed. # TODO(mjark): Check for redactions we might have missed.
@ -572,30 +584,26 @@ class SyncHandler(BaseHandler):
logging.debug("Recents %r", batch) logging.debug("Recents %r", batch)
# TODO(mjark): This seems racy since this isn't being passed a current_state = yield self.get_state_at(room_id, now_token)
# token to indicate what point in the stream this is
current_state = yield self.state_handler.get_current_state( state_at_previous_sync = yield self.get_state_at(
room_id room_id, stream_position=since_token
) )
state_at_previous_sync = yield self.get_state_at_previous_sync( state = yield self.compute_state_delta(
room_id, since_token=since_token
)
state_events_delta = yield self.compute_state_delta(
since_token=since_token, since_token=since_token,
previous_state=state_at_previous_sync, previous_state=state_at_previous_sync,
current_state=current_state, current_state=current_state,
) )
state_events_delta, _ = yield self.check_joined_room( just_joined = yield self.check_joined_room(sync_config, state)
sync_config, room_id, state_events_delta if just_joined:
) state = yield self.get_state_at(room_id, now_token)
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room( private_user_data=self.private_user_data_for_room(
room_id, tags_by_room room_id, tags_by_room
@ -627,16 +635,12 @@ class SyncHandler(BaseHandler):
logging.debug("Recents %r", batch) logging.debug("Recents %r", batch)
# TODO(mjark): This seems racy since this isn't being passed a state_events_at_leave = yield self.store.get_state_for_event(
# token to indicate what point in the stream this is leave_event.event_id
leave_state = yield self.store.get_state_for_events(
[leave_event.event_id], None
) )
state_events_at_leave = leave_state[leave_event.event_id] state_at_previous_sync = yield self.get_state_at(
leave_event.room_id, stream_position=since_token
state_at_previous_sync = yield self.get_state_at_previous_sync(
leave_event.room_id, since_token=since_token
) )
state_events_delta = yield self.compute_state_delta( state_events_delta = yield self.compute_state_delta(
@ -659,26 +663,36 @@ class SyncHandler(BaseHandler):
defer.returnValue(room_sync) defer.returnValue(room_sync)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_at_previous_sync(self, room_id, since_token): def get_state_after_event(self, event):
""" Get the room state at the previous sync the client made. """
Returns: Get the room state after the given event
A Deferred map from ((type, state_key)->Event)
:param synapse.events.EventBase event: event of interest
:return: A Deferred map from ((type, state_key)->Event)
"""
state = yield self.store.get_state_for_event(event.event_id)
if event.is_state():
state = state.copy()
state[(event.type, event.state_key)] = event
defer.returnValue(state)
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position):
""" Get the room state at a particular stream position
:param str room_id: room for which to get state
:param StreamToken stream_position: point at which to get state
:returns: A Deferred map from ((type, state_key)->Event)
""" """
last_events, token = yield self.store.get_recent_events_for_room( last_events, token = yield self.store.get_recent_events_for_room(
room_id, end_token=since_token.room_key, limit=1, room_id, end_token=stream_position.room_key, limit=1,
) )
if last_events: if last_events:
last_event = last_events[0] last_event = last_events[-1]
last_context = yield self.state_handler.compute_event_context( state = yield self.get_state_after_event(last_event)
last_event
)
if last_event.is_state():
state = last_context.current_state.copy()
state[(last_event.type, last_event.state_key)] = last_event
else:
state = last_context.current_state
else: else:
# no events in this room - so presumably no state
state = {} state = {}
defer.returnValue(state) defer.returnValue(state)
@ -706,31 +720,20 @@ class SyncHandler(BaseHandler):
state_delta[key] = event state_delta[key] = event
return state_delta return state_delta
@defer.inlineCallbacks def check_joined_room(self, sync_config, state_delta):
def check_joined_room(self, sync_config, room_id, state_delta):
""" """
Check if the user has just joined the given room. If so, return the Check if the user has just joined the given room (so should
full state for the room, instead of the delta since the last sync. be given the full state)
:param sync_config: :param sync_config:
:param room_id:
:param dict[(str,str), synapse.events.FrozenEvent] state_delta: the :param dict[(str,str), synapse.events.FrozenEvent] state_delta: the
difference in state since the last sync difference in state since the last sync
:returns A deferred Tuple (state_delta, limited) :returns A deferred Tuple (state_delta, limited)
""" """
joined = False
limited = False
join_event = state_delta.get(( join_event = state_delta.get((
EventTypes.Member, sync_config.user.to_string()), None) EventTypes.Member, sync_config.user.to_string()), None)
if join_event is not None: if join_event is not None:
if join_event.content["membership"] == Membership.JOIN: if join_event.content["membership"] == Membership.JOIN:
joined = True return True
return False
if joined:
state_delta = yield self.state_handler.get_current_state(room_id)
# the timeline is inherently limited if we've just joined
limited = True
defer.returnValue((state_delta, limited))

View File

@ -237,6 +237,20 @@ class StateStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids}) defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
def get_state_for_event(self, event_id, types=None):
"""
Get the state dict corresponding to a particular event
:param str event_id: event whose state should be returned
:param list[(str, str)]|None types: List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
:return: a deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
@cached(num_args=2, lru=True, max_entries=10000) @cached(num_args=2, lru=True, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id): def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(