Fix race in sync when joining room

The race happens when a user joins a room at the same time as doing a
sync. We fetch the current token and then get the rooms the user is in.
If the join happens after the current token is fetched, but before we get
the rooms, we end up sending down a partial room entry in the sync.

This is fixed by looking at the stream ordering of each membership
returned by get_rooms_for_user_with_stream_ordering, and handling the case
where that stream ordering is after the current token.
This commit is contained in:
Erik Johnston 2018-03-05 12:06:19 +00:00
parent 8ffaacbee3
commit 8cb44da4aa
3 changed files with 102 additions and 30 deletions

View File

@ -235,10 +235,10 @@ class SyncHandler(object):
defer.returnValue(rules) defer.returnValue(rules)
@defer.inlineCallbacks @defer.inlineCallbacks
def ephemeral_by_room(self, sync_config, now_token, since_token=None): def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
"""Get the ephemeral events for each room the user is in """Get the ephemeral events for each room the user is in
Args: Args:
sync_config (SyncConfig): The flags, filters and user for the sync. sync_result_builder(SyncResultBuilder)
now_token (StreamToken): Where the server is currently up to. now_token (StreamToken): Where the server is currently up to.
since_token (StreamToken): Where the server was when the client since_token (StreamToken): Where the server was when the client
last synced. last synced.
@ -248,10 +248,12 @@ class SyncHandler(object):
typing events for that room. typing events for that room.
""" """
sync_config = sync_result_builder.sync_config
with Measure(self.clock, "ephemeral_by_room"): with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0" typing_key = since_token.typing_key if since_token else "0"
room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string()) room_ids = sync_result_builder.joined_room_ids
typing_source = self.event_sources.sources["typing"] typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events( typing, typing_key = yield typing_source.get_new_events(
@ -565,10 +567,22 @@ class SyncHandler(object):
# Always use the `now_token` in `SyncResultBuilder` # Always use the `now_token` in `SyncResultBuilder`
now_token = yield self.event_sources.get_current_token() now_token = yield self.event_sources.get_current_token()
user_id = sync_config.user.to_string()
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
# We no longer support AS users using /sync directly.
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
joined_room_ids = yield self.get_rooms_for_user_at(
user_id, now_token.room_stream_id,
)
sync_result_builder = SyncResultBuilder( sync_result_builder = SyncResultBuilder(
sync_config, full_state, sync_config, full_state,
since_token=since_token, since_token=since_token,
now_token=now_token, now_token=now_token,
joined_room_ids=joined_room_ids,
) )
account_data_by_room = yield self._generate_sync_entry_for_account_data( account_data_by_room = yield self._generate_sync_entry_for_account_data(
@ -603,7 +617,6 @@ class SyncHandler(object):
device_id = sync_config.device_id device_id = sync_config.device_id
one_time_key_counts = {} one_time_key_counts = {}
if device_id: if device_id:
user_id = sync_config.user.to_string()
one_time_key_counts = yield self.store.count_e2e_one_time_keys( one_time_key_counts = yield self.store.count_e2e_one_time_keys(
user_id, device_id user_id, device_id
) )
@ -891,7 +904,7 @@ class SyncHandler(object):
ephemeral_by_room = {} ephemeral_by_room = {}
else: else:
now_token, ephemeral_by_room = yield self.ephemeral_by_room( now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_result_builder.sync_config, sync_result_builder,
now_token=sync_result_builder.now_token, now_token=sync_result_builder.now_token,
since_token=sync_result_builder.since_token, since_token=sync_result_builder.since_token,
) )
@ -996,16 +1009,8 @@ class SyncHandler(object):
if rooms_changed: if rooms_changed:
defer.returnValue(True) defer.returnValue(True)
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
# We no longer support AS users using /sync directly.
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
joined_room_ids = yield self.store.get_rooms_for_user(user_id)
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
for room_id in joined_room_ids: for room_id in sync_result_builder.joined_room_ids:
if self.store.has_room_changed_since(room_id, stream_id): if self.store.has_room_changed_since(room_id, stream_id):
defer.returnValue(True) defer.returnValue(True)
defer.returnValue(False) defer.returnValue(False)
@ -1029,14 +1034,6 @@ class SyncHandler(object):
assert since_token assert since_token
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
# We no longer support AS users using /sync directly.
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
joined_room_ids = yield self.store.get_rooms_for_user(user_id)
# Get a list of membership change events that have happened. # Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user( rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key user_id, since_token.room_key, now_token.room_key
@ -1059,7 +1056,7 @@ class SyncHandler(object):
# we do send down the room, and with full state, where necessary # we do send down the room, and with full state, where necessary
old_state_ids = None old_state_ids = None
if room_id in joined_room_ids and non_joins: if room_id in sync_result_builder.joined_room_ids and non_joins:
# Always include if the user (re)joined the room, especially # Always include if the user (re)joined the room, especially
# important so that device list changes are calculated correctly. # important so that device list changes are calculated correctly.
# If there are non join member events, but we are still in the room, # If there are non join member events, but we are still in the room,
@ -1069,7 +1066,7 @@ class SyncHandler(object):
# User is in the room so we don't need to do the invite/leave checks # User is in the room so we don't need to do the invite/leave checks
continue continue
if room_id in joined_room_ids or has_join: if room_id in sync_result_builder.joined_room_ids or has_join:
old_state_ids = yield self.get_state_at(room_id, since_token) old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None old_mem_ev = None
@ -1081,7 +1078,7 @@ class SyncHandler(object):
newly_joined_rooms.append(room_id) newly_joined_rooms.append(room_id)
# If user is in the room then we don't need to do the invite/leave checks # If user is in the room then we don't need to do the invite/leave checks
if room_id in joined_room_ids: if room_id in sync_result_builder.joined_room_ids:
continue continue
if not non_joins: if not non_joins:
@ -1148,7 +1145,7 @@ class SyncHandler(object):
# Get all events for rooms we're currently joined to. # Get all events for rooms we're currently joined to.
room_to_events = yield self.store.get_room_events_stream_for_rooms( room_to_events = yield self.store.get_room_events_stream_for_rooms(
room_ids=joined_room_ids, room_ids=sync_result_builder.joined_room_ids,
from_key=since_token.room_key, from_key=since_token.room_key,
to_key=now_token.room_key, to_key=now_token.room_key,
limit=timeline_limit + 1, limit=timeline_limit + 1,
@ -1156,7 +1153,7 @@ class SyncHandler(object):
# We loop through all room ids, even if there are no new events, in case # We loop through all room ids, even if there are no new events, in case
# there are non room events taht we need to notify about. # there are non room events taht we need to notify about.
for room_id in joined_room_ids: for room_id in sync_result_builder.joined_room_ids:
room_entry = room_to_events.get(room_id, None) room_entry = room_to_events.get(room_id, None)
if room_entry: if room_entry:
@ -1364,6 +1361,54 @@ class SyncHandler(object):
else: else:
raise Exception("Unrecognized rtype: %r", room_builder.rtype) raise Exception("Unrecognized rtype: %r", room_builder.rtype)
@defer.inlineCallbacks
def get_rooms_for_user_at(self, user_id, stream_ordering):
    """Work out which rooms a user was joined to at a given stream ordering.

    The stream ordering *must* be recent; per the original note, this may
    throw an exception if it is older than a month. (Callers pass the
    current token, which should be perfectly fine.)

    Args:
        user_id (str)
        stream_ordering (int)

    Returns:
        Deferred[frozenset[str]]: The room_ids the user is in at the given
        stream_ordering.
    """
    memberships = yield self.store.get_rooms_for_user_with_stream_ordering(
        user_id,
    )

    rooms_at_token = set()

    # A join may have landed between the caller fetching its token and us
    # calling `get_rooms_for_user_with_stream_ordering`. Any membership whose
    # stream ordering is after the requested one therefore cannot be trusted
    # as-is: we have to go back and check whether the user was already in the
    # room at the requested point.
    for room_id, join_ordering in memberships:
        if join_ordering > stream_ordering:
            logger.info("SH joined_room_ids membership after current token")

            extremities = yield self.store.get_forward_extremeties_for_room(
                room_id, stream_ordering,
            )
            members = yield self.state.get_current_user_in_room(
                room_id, extremities,
            )
            if user_id not in members:
                # Not in the room as of `stream_ordering`; leave it out.
                continue

        rooms_at_token.add(room_id)

    defer.returnValue(frozenset(rooms_at_token))
def _action_has_highlight(actions): def _action_has_highlight(actions):
for action in actions: for action in actions:
@ -1413,7 +1458,8 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
class SyncResultBuilder(object): class SyncResultBuilder(object):
"Used to help build up a new SyncResult for a user" "Used to help build up a new SyncResult for a user"
def __init__(self, sync_config, full_state, since_token, now_token): def __init__(self, sync_config, full_state, since_token, now_token,
joined_room_ids):
""" """
Args: Args:
sync_config(SyncConfig) sync_config(SyncConfig)
@ -1425,6 +1471,7 @@ class SyncResultBuilder(object):
self.full_state = full_state self.full_state = full_state
self.since_token = since_token self.since_token = since_token
self.now_token = now_token self.now_token = now_token
self.joined_room_ids = joined_room_ids
self.presence = [] self.presence = []
self.account_data = [] self.account_data = []

View File

@ -754,7 +754,7 @@ class EventsStore(EventsWorkerStore):
for member in members_changed: for member in members_changed:
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.get_rooms_for_user, (member,) txn, self.get_rooms_for_user_with_stream_ordering, (member,)
) )
for host in set(get_domain_from_id(u) for u in members_changed): for host in set(get_domain_from_id(u) for u in members_changed):

View File

@ -38,6 +38,11 @@ RoomsForUser = namedtuple(
("room_id", "sender", "membership", "event_id", "stream_ordering") ("room_id", "sender", "membership", "event_id", "stream_ordering")
) )
GetRoomsForUserWithStreamOrdering = namedtuple(
"_GetRoomsForUserWithStreamOrdering",
("room_id", "stream_ordering",)
)
# We store this using a namedtuple so that we save about 3x space over using a # We store this using a namedtuple so that we save about 3x space over using a
# dict. # dict.
@ -181,12 +186,32 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results return results
@cachedInlineCallbacks(max_entries=500000, iterable=True) @cachedInlineCallbacks(max_entries=500000, iterable=True)
def get_rooms_for_user(self, user_id): def get_rooms_for_user_with_stream_ordering(self, user_id):
"""Returns a set of room_ids the user is currently joined to """Returns a set of room_ids the user is currently joined to
Args:
user_id (str)
Returns:
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
the rooms the user is in currently, along with the stream ordering
of the most recent join for that user and room.
""" """
rooms = yield self.get_rooms_for_user_where_membership_is( rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN], user_id, membership_list=[Membership.JOIN],
) )
defer.returnValue(frozenset(
GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
for r in rooms
))
@defer.inlineCallbacks
def get_rooms_for_user(self, user_id, on_invalidate=None):
"""Returns a set of room_ids the user is currently joined to
"""
rooms = yield self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate,
)
defer.returnValue(frozenset(r.room_id for r in rooms)) defer.returnValue(frozenset(r.room_id for r in rooms))
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)