Merge branch 'develop' into markjh/direct_to_device_synchrotron

This commit is contained in:
Mark Haines 2016-09-02 10:59:24 +01:00
commit 965168a842
28 changed files with 431 additions and 235 deletions

View File

@ -134,6 +134,12 @@ Installing prerequisites on Raspbian::
sudo pip install --upgrade ndg-httpsclient sudo pip install --upgrade ndg-httpsclient
sudo pip install --upgrade virtualenv sudo pip install --upgrade virtualenv
Installing prerequisites on openSUSE::
sudo zypper in -t pattern devel_basis
sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
python-devel libffi-devel libopenssl-devel libjpeg62-devel
To install the synapse homeserver run:: To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse virtualenv -p python2.7 ~/.synapse
@ -230,9 +236,6 @@ The advantages of Postgres include:
pointing at the same DB master, as well as enabling DB replication in pointing at the same DB master, as well as enabling DB replication in
synapse itself. synapse itself.
The only disadvantage is that the code is relatively new as of April 2015 and
may have a few regressions relative to SQLite.
For information on how to install and use PostgreSQL, please see For information on how to install and use PostgreSQL, please see
`docs/postgres.rst <docs/postgres.rst>`_. `docs/postgres.rst <docs/postgres.rst>`_.

View File

@ -66,7 +66,7 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True): def check_from_context(self, event, context, do_sig_check=True):
auth_events_ids = yield self.compute_auth_events( auth_events_ids = yield self.compute_auth_events(
event, context.current_state_ids, for_verification=True, event, context.prev_state_ids, for_verification=True,
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = yield self.store.get_events(auth_events_ids)
auth_events = { auth_events = {
@ -281,11 +281,13 @@ class Auth(object):
with Measure(self.clock, "check_host_in_room"): with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
group, curr_state_ids = yield self.state.resolve_state_groups( entry = yield self.state.resolve_state_groups(
room_id, latest_event_ids room_id, latest_event_ids
) )
ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids) ret = yield self.store.is_host_joined(
room_id, host, entry.state_group, entry.state
)
defer.returnValue(ret) defer.returnValue(ret)
def check_event_sender_in_room(self, event, auth_events): def check_event_sender_in_room(self, event, auth_events):
@ -852,7 +854,7 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def add_auth_events(self, builder, context): def add_auth_events(self, builder, context):
auth_ids = yield self.compute_auth_events(builder, context.current_state_ids) auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
auth_events_entries = yield self.store.add_event_hashes( auth_events_entries = yield self.store.add_event_hashes(
auth_ids auth_ids

View File

@ -67,6 +67,8 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def query_user(self, service, user_id): def query_user(self, service, user_id):
if service.url is None:
defer.returnValue(False)
uri = service.url + ("/users/%s" % urllib.quote(user_id)) uri = service.url + ("/users/%s" % urllib.quote(user_id))
response = None response = None
try: try:
@ -86,6 +88,8 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def query_alias(self, service, alias): def query_alias(self, service, alias):
if service.url is None:
defer.returnValue(False)
uri = service.url + ("/rooms/%s" % urllib.quote(alias)) uri = service.url + ("/rooms/%s" % urllib.quote(alias))
response = None response = None
try: try:
@ -113,6 +117,8 @@ class ApplicationServiceApi(SimpleHttpClient):
raise ValueError( raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind "Unrecognised 'kind' argument %r to query_3pe()", kind
) )
if service.url is None:
defer.returnValue([])
uri = "%s%s/thirdparty/%s/%s" % ( uri = "%s%s/thirdparty/%s/%s" % (
service.url, service.url,
@ -145,6 +151,9 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue([]) defer.returnValue([])
def get_3pe_protocol(self, service, protocol): def get_3pe_protocol(self, service, protocol):
if service.url is None:
defer.returnValue({})
@defer.inlineCallbacks @defer.inlineCallbacks
def _get(): def _get():
uri = "%s%s/thirdparty/protocol/%s" % ( uri = "%s%s/thirdparty/protocol/%s" % (
@ -166,6 +175,9 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None): def push_bulk(self, service, events, txn_id=None):
if service.url is None:
defer.returnValue(True)
events = self._serialize(events) events = self._serialize(events)
if txn_id is None: if txn_id is None:

View File

@ -86,7 +86,7 @@ def load_appservices(hostname, config_files):
def _load_appservice(hostname, as_info, config_filename): def _load_appservice(hostname, as_info, config_filename):
required_string_fields = [ required_string_fields = [
"id", "url", "as_token", "hs_token", "sender_localpart" "id", "as_token", "hs_token", "sender_localpart"
] ]
for field in required_string_fields: for field in required_string_fields:
if not isinstance(as_info.get(field), basestring): if not isinstance(as_info.get(field), basestring):
@ -94,6 +94,14 @@ def _load_appservice(hostname, as_info, config_filename):
field, config_filename, field, config_filename,
)) ))
# 'url' must either be a string or explicitly null, not missing
# to avoid accidentally turning off push for ASes.
if (not isinstance(as_info.get("url"), basestring) and
as_info.get("url", "") is not None):
raise KeyError(
"Required string field or explicit null: 'url' (%s)" % (config_filename,)
)
localpart = as_info["sender_localpart"] localpart = as_info["sender_localpart"]
if urllib.quote(localpart) != localpart: if urllib.quote(localpart) != localpart:
raise ValueError( raise ValueError(
@ -132,6 +140,13 @@ def _load_appservice(hostname, as_info, config_filename):
for p in protocols: for p in protocols:
if not isinstance(p, str): if not isinstance(p, str):
raise KeyError("Bad value for 'protocols' item") raise KeyError("Bad value for 'protocols' item")
if as_info["url"] is None:
logger.info(
"(%s) Explicitly empty 'url' provided. This application service"
" will not receive events or queries.",
config_filename,
)
return ApplicationService( return ApplicationService(
token=as_info["as_token"], token=as_info["as_token"],
url=as_info["url"], url=as_info["url"],

View File

@ -15,8 +15,9 @@
class EventContext(object): class EventContext(object):
def __init__(self, current_state_ids=None): def __init__(self):
self.current_state_ids = current_state_ids self.current_state_ids = None
self.prev_state_ids = None
self.state_group = None self.state_group = None
self.rejected = False self.rejected = False
self.push_actions = [] self.push_actions = []

View File

@ -269,7 +269,7 @@ class FederationClient(FederationBase):
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
pdu = None signed_pdu = None
for destination in destinations: for destination in destinations:
now = self._clock.time_msec() now = self._clock.time_msec()
last_attempt = pdu_attempts.get(destination, 0) last_attempt = pdu_attempts.get(destination, 0)
@ -299,7 +299,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0] pdu = pdu_list[0]
# Check signatures are correct. # Check signatures are correct.
pdu = yield self._check_sigs_and_hashes([pdu])[0] signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
break break
@ -322,10 +322,10 @@ class FederationClient(FederationBase):
) )
continue continue
if self._get_pdu_cache is not None and pdu: if self._get_pdu_cache is not None and signed_pdu:
self._get_pdu_cache[event_id] = pdu self._get_pdu_cache[event_id] = signed_pdu
defer.returnValue(pdu) defer.returnValue(signed_pdu)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function

View File

@ -222,7 +222,7 @@ class FederationHandler(BaseHandler):
# joined the room. Don't bother if the user is just # joined the room. Don't bother if the user is just
# changing their profile info. # changing their profile info.
newly_joined = True newly_joined = True
prev_state_id = context.current_state_ids.get( prev_state_id = context.prev_state_ids.get(
(event.type, event.state_key) (event.type, event.state_key)
) )
if prev_state_id: if prev_state_id:
@ -835,12 +835,12 @@ class FederationHandler(BaseHandler):
self.replication_layer.send_pdu(new_pdu, destinations) self.replication_layer.send_pdu(new_pdu, destinations)
state_ids = context.current_state_ids.values() state_ids = context.prev_state_ids.values()
auth_chain = yield self.store.get_auth_chain(set( auth_chain = yield self.store.get_auth_chain(set(
[event.event_id] + state_ids [event.event_id] + state_ids
)) ))
state = yield self.store.get_events(context.current_state_ids.values()) state = yield self.store.get_events(context.prev_state_ids.values())
defer.returnValue({ defer.returnValue({
"state": state.values(), "state": state.values(),
@ -1333,7 +1333,7 @@ class FederationHandler(BaseHandler):
if not auth_events: if not auth_events:
auth_events_ids = yield self.auth.compute_auth_events( auth_events_ids = yield self.auth.compute_auth_events(
event, context.current_state_ids, for_verification=True, event, context.prev_state_ids, for_verification=True,
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = yield self.store.get_events(auth_events_ids)
auth_events = { auth_events = {
@ -1432,6 +1432,11 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values()) current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(e_id for e_id, _ in event.auth_events) event_auth_events = set(e_id for e_id, _ in event.auth_events)
if event.is_state():
event_key = (event.type, event.state_key)
else:
event_key = None
if event_auth_events - current_state: if event_auth_events - current_state:
have_events = yield self.store.have_events( have_events = yield self.store.have_events(
event_auth_events - current_state event_auth_events - current_state
@ -1537,8 +1542,12 @@ class FederationHandler(BaseHandler):
context.current_state_ids.update({ context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items() k: a.event_id for k, a in auth_events.items()
if k != event_key
}) })
context.state_group = None context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = self.store.get_next_state_group()
if different_auth and not event.internal_metadata.is_outlier(): if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth) logger.info("Different auth after resolution: %s", different_auth)
@ -1560,7 +1569,7 @@ class FederationHandler(BaseHandler):
if do_resolution: if do_resolution:
# 1. Get what we think is the auth chain. # 1. Get what we think is the auth chain.
auth_ids = yield self.auth.compute_auth_events( auth_ids = yield self.auth.compute_auth_events(
event, context.current_state_ids event, context.prev_state_ids
) )
local_auth_chain = yield self.store.get_auth_chain(auth_ids) local_auth_chain = yield self.store.get_auth_chain(auth_ids)
@ -1618,8 +1627,12 @@ class FederationHandler(BaseHandler):
context.current_state_ids.update({ context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items() k: a.event_id for k, a in auth_events.items()
if k != event_key
}) })
context.state_group = None context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = self.store.get_next_state_group()
try: try:
self.auth.check(event, auth_events=auth_events) self.auth.check(event, auth_events=auth_events)
@ -1855,7 +1868,7 @@ class FederationHandler(BaseHandler):
event.content["third_party_invite"]["signed"]["token"] event.content["third_party_invite"]["signed"]["token"]
) )
original_invite = None original_invite = None
original_invite_id = context.current_state_ids.get(key) original_invite_id = context.prev_state_ids.get(key)
if original_invite_id: if original_invite_id:
original_invite = yield self.store.get_event( original_invite = yield self.store.get_event(
original_invite_id, allow_none=True original_invite_id, allow_none=True
@ -1893,7 +1906,7 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"] signed = event.content["third_party_invite"]["signed"]
token = signed["token"] token = signed["token"]
invite_event_id = context.current_state_ids.get( invite_event_id = context.prev_state_ids.get(
(EventTypes.ThirdPartyInvite, token,) (EventTypes.ThirdPartyInvite, token,)
) )

View File

@ -272,7 +272,7 @@ class MessageHandler(BaseHandler):
If so, returns the version of the event in context. If so, returns the version of the event in context.
Otherwise, returns None. Otherwise, returns None.
""" """
prev_event_id = context.current_state_ids.get((event.type, event.state_key)) prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True) prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event: if not prev_event:
return return
@ -808,8 +808,8 @@ class MessageHandler(BaseHandler):
event = builder.build() event = builder.build()
logger.debug( logger.debug(
"Created event %s with current state: %s", "Created event %s with state: %s",
event.event_id, context.current_state_ids, event.event_id, context.prev_state_ids,
) )
defer.returnValue( defer.returnValue(
@ -904,7 +904,7 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
auth_events_ids = yield self.auth.compute_auth_events( auth_events_ids = yield self.auth.compute_auth_events(
event, context.current_state_ids, for_verification=True, event, context.prev_state_ids, for_verification=True,
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = yield self.store.get_events(auth_events_ids)
auth_events = { auth_events = {
@ -924,7 +924,7 @@ class MessageHandler(BaseHandler):
"You don't have permission to redact events" "You don't have permission to redact events"
) )
if event.type == EventTypes.Create and context.current_state_ids: if event.type == EventTypes.Create and context.prev_state_ids:
raise AuthError( raise AuthError(
403, 403,
"Changing the room create event is forbidden", "Changing the room create event is forbidden",

View File

@ -191,6 +191,13 @@ class PresenceHandler(object):
5000, 5000,
) )
self.clock.call_later(
60,
self.clock.looping_call,
self._persist_unpersisted_changes,
60 * 1000,
)
metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer)) metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -216,6 +223,27 @@ class PresenceHandler(object):
]) ])
logger.info("Finished _on_shutdown") logger.info("Finished _on_shutdown")
@defer.inlineCallbacks
def _persist_unpersisted_changes(self):
"""We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times.
"""
logger.info(
"Performing _persist_unpersisted_changes. Persiting %d unpersisted changes",
len(self.unpersisted_users_changes)
)
unpersisted = self.unpersisted_users_changes
self.unpersisted_users_changes = set()
if unpersisted:
yield self.store.update_presence([
self.user_to_current_state[user_id]
for user_id in unpersisted
])
logger.info("Finished _persist_unpersisted_changes")
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_states(self, new_states): def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes """Updates presence of users. Sets the appropriate timeouts. Pokes
@ -923,6 +951,11 @@ def should_notify(old_state, new_state):
return True return True
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Only notify about last active bumps if we're not currently acive
if not (old_state.currently_active and new_state.currently_active):
return True
elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Always notify for a transition where last active gets bumped. # Always notify for a transition where last active gets bumped.
return True return True

View File

@ -93,7 +93,7 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit, ratelimit=ratelimit,
) )
prev_member_event_id = context.current_state_ids.get( prev_member_event_id = context.prev_state_ids.get(
(EventTypes.Member, target.to_string()), (EventTypes.Member, target.to_string()),
None None
) )
@ -341,7 +341,7 @@ class RoomMemberHandler(BaseHandler):
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if requester.is_guest: if requester.is_guest:
guest_can_join = yield self._can_guest_join(context.current_state_ids) guest_can_join = yield self._can_guest_join(context.prev_state_ids)
if not guest_can_join: if not guest_can_join:
# This should be an auth check, but guests are a local concept, # This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
@ -355,7 +355,7 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit, ratelimit=ratelimit,
) )
prev_member_event_id = context.current_state_ids.get( prev_member_event_id = context.prev_state_ids.get(
(EventTypes.Member, event.state_key), (EventTypes.Member, event.state_key),
None None
) )

View File

@ -565,7 +565,10 @@ class SyncHandler(object):
if sync_result_builder.since_token is not None: if sync_result_builder.since_token is not None:
since_stream_id = int(sync_result_builder.since_token.to_device_key) since_stream_id = int(sync_result_builder.since_token.to_device_key)
if since_stream_id: if since_stream_id != int(now_token.to_device_key):
# We only delete messages when a new message comes in, but that's
# fine so long as we delete them at some point.
logger.debug("Deleting messages up to %d", since_stream_id) logger.debug("Deleting messages up to %d", since_stream_id)
yield self.store.delete_messages_for_device( yield self.store.delete_messages_for_device(
user_id, device_id, since_stream_id user_id, device_id, since_stream_id
@ -580,6 +583,8 @@ class SyncHandler(object):
"to_device_key", stream_id "to_device_key", stream_id
) )
sync_result_builder.to_device = messages sync_result_builder.to_device = messages
else:
sync_result_builder.to_device = []
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_sync_entry_for_account_data(self, sync_result_builder): def _generate_sync_entry_for_account_data(self, sync_result_builder):

View File

@ -87,7 +87,7 @@ class BulkPushRuleEvaluator:
) )
room_members = yield self.store.get_joined_users_from_context( room_members = yield self.store.get_joined_users_from_context(
event.room_id, context.state_group, context.current_state_ids event, context
) )
evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) evaluator = PushRuleEvaluatorForEvent(event, len(room_members))

View File

@ -338,7 +338,7 @@ class Mailer(object):
# want the generated-from-names one here otherwise we'll # want the generated-from-names one here otherwise we'll
# end up with, "new message from Bob in the Bob room" # end up with, "new message from Bob in the Bob room"
room_name = yield calculate_room_name( room_name = yield calculate_room_name(
state_by_room[room_id], user_id, fallback_to_members=False self.store, state_by_room[room_id], user_id, fallback_to_members=False
) )
my_member_event = state_by_room[room_id][("m.room.member", user_id)] my_member_event = state_by_room[room_id][("m.room.member", user_id)]

View File

@ -74,7 +74,7 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
alias_event = yield store.get_event( alias_event = yield store.get_event(
alias_id, allow_none=True alias_id, allow_none=True
) )
if alias_event and alias_event.content and alias_event.get("aliases"): if alias_event and alias_event.content.get("aliases"):
the_aliases = alias_event.content["aliases"] the_aliases = alias_event.content["aliases"]
if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]): if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
defer.returnValue(the_aliases[0]) defer.returnValue(the_aliases[0])

View File

@ -40,7 +40,6 @@ STREAM_NAMES = (
("backfill",), ("backfill",),
("push_rules",), ("push_rules",),
("pushers",), ("pushers",),
("state",),
("caches",), ("caches",),
("to_device",), ("to_device",),
) )
@ -131,7 +130,6 @@ class ReplicationResource(Resource):
backfill_token = yield self.store.get_current_backfill_token() backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token() pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token()
caches_token = self.store.get_cache_stream_token() caches_token = self.store.get_cache_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
@ -143,7 +141,7 @@ class ReplicationResource(Resource):
backfill_token, backfill_token,
push_rules_token, push_rules_token,
pushers_token, pushers_token,
state_token, 0, # State stream is no longer a thing
caches_token, caches_token,
int(stream_token.to_device_key), int(stream_token.to_device_key),
)) ))
@ -193,7 +191,6 @@ class ReplicationResource(Resource):
yield self.receipts(writer, current_token, limit, request_streams) yield self.receipts(writer, current_token, limit, request_streams)
yield self.push_rules(writer, current_token, limit, request_streams) yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams) yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams)
yield self.caches(writer, current_token, limit, request_streams) yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
@ -368,25 +365,6 @@ class ReplicationResource(Resource):
"position", "user_id", "app_id", "pushkey" "position", "user_id", "app_id", "pushkey"
)) ))
@defer.inlineCallbacks
def state(self, writer, current_token, limit, request_streams):
current_position = current_token.state
state = request_streams.get("state")
if state is not None:
state_groups, state_group_state = (
yield self.store.get_all_new_state_groups(
state, current_position, limit
)
)
writer.write_header_and_rows("state_groups", state_groups, (
"position", "room_id", "event_id"
))
writer.write_header_and_rows("state_group_state", state_group_state, (
"position", "type", "state_key", "event_id"
))
@defer.inlineCallbacks @defer.inlineCallbacks
def caches(self, writer, current_token, limit, request_streams): def caches(self, writer, current_token, limit, request_streams):
current_position = current_token.caches current_position = current_token.caches

View File

@ -123,6 +123,7 @@ class SlavedEventStore(BaseSlavedStore):
get_state_groups_ids = DataStore.get_state_groups_ids.__func__ get_state_groups_ids = DataStore.get_state_groups_ids.__func__
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__ get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__ get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__ get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
_get_joined_users_from_context = ( _get_joined_users_from_context = (
RoomMemberStore.__dict__["_get_joined_users_from_context"] RoomMemberStore.__dict__["_get_joined_users_from_context"]

View File

@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer
from collections import namedtuple from collections import namedtuple
@ -43,11 +44,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60 EVICTION_TIMEOUT_SECONDS = 60 * 60
_NEXT_STATE_ID = 1
def _gen_state_id():
global _NEXT_STATE_ID
s = "X%d" % (_NEXT_STATE_ID,)
_NEXT_STATE_ID += 1
return s
class _StateCacheEntry(object): class _StateCacheEntry(object):
def __init__(self, state, state_group, ts): __slots__ = ["state", "state_group", "state_id"]
def __init__(self, state, state_group):
self.state = state self.state = state
self.state_group = state_group self.state_group = state_group
# The `state_id` is a unique ID we generate that can be used as ID for
# this collection of state. Usually this would be the same as the
# state group, but on worker instances we can't generate a new state
# group each time we resolve state, so we generate a separate one that
# isn't persisted and is used solely for caches.
# `state_id` is either a state_group (and so an int) or a string. This
# ensures we don't accidentally persist a state_id as a stateg_group
if state_group:
self.state_id = state_group
else:
self.state_id = _gen_state_id()
class StateHandler(object): class StateHandler(object):
""" Responsible for doing state conflict resolution. """ Responsible for doing state conflict resolution.
@ -60,6 +85,7 @@ class StateHandler(object):
# dict of set of event_ids -> _StateCacheEntry. # dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None self._state_cache = None
self.resolve_linearizer = Linearizer()
def start_caching(self): def start_caching(self):
logger.debug("start_caching") logger.debug("start_caching")
@ -93,7 +119,8 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
_, state = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
if event_type: if event_type:
event_id = state.get((event_type, state_key)) event_id = state.get((event_type, state_key))
@ -116,7 +143,8 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
_, state = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
if event_type: if event_type:
defer.returnValue(state.get((event_type, state_key))) defer.returnValue(state.get((event_type, state_key)))
@ -127,9 +155,9 @@ class StateHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_user_in_room(self, room_id): def get_current_user_in_room(self, room_id):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_context( joined_users = yield self.store.get_joined_users_from_state(
room_id, group, state_ids room_id, entry.state_id, entry.state
) )
defer.returnValue(joined_users) defer.returnValue(joined_users)
@ -154,52 +182,73 @@ class StateHandler(object):
# state. Certainly store.get_current_state won't return any, and # state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group. # persisting the event won't store the state group.
if old_state: if old_state:
context.current_state_ids = { context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
if event.is_state():
context.current_state_events = dict(context.prev_state_ids)
key = (event.type, event.state_key)
context.current_state_events[key] = event.event_id
else:
context.current_state_events = context.prev_state_ids
else: else:
context.current_state_ids = {} context.current_state_ids = {}
context.prev_state_ids = {}
context.prev_state_events = [] context.prev_state_events = []
context.state_group = None context.state_group = self.store.get_next_state_group()
defer.returnValue(context) defer.returnValue(context)
if old_state: if old_state:
context.current_state_ids = { context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
context.state_group = None context.state_group = self.store.get_next_state_group()
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.current_state_ids: if key in context.prev_state_ids:
replaces = context.current_state_ids[key] replaces = context.prev_state_ids[key]
if replaces != event.event_id: # Paranoia check if replaces != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
else:
context.current_state_ids = context.prev_state_ids
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
if event.is_state(): if event.is_state():
ret = yield self.resolve_state_groups( entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
event_type=event.type, event_type=event.type,
state_key=event.state_key, state_key=event.state_key,
) )
else: else:
ret = yield self.resolve_state_groups( entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
group, curr_state = ret curr_state = entry.state
context.current_state_ids = curr_state context.prev_state_ids = curr_state
context.state_group = group if not event.is_state() else None if event.is_state():
context.state_group = self.store.get_next_state_group()
else:
if entry.state_group is None:
entry.state_group = self.store.get_next_state_group()
entry.state_id = entry.state_group
context.state_group = entry.state_group
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.current_state_ids: if key in context.prev_state_ids:
replaces = context.current_state_ids[key] replaces = context.prev_state_ids[key]
event.unsigned["replaces_state"] = replaces event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
else:
context.current_state_ids = context.prev_state_ids
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
@ -231,16 +280,16 @@ class StateHandler(object):
if len(group_names) == 1: if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop() name, state_list = state_groups_ids.items().pop()
defer.returnValue((name, state_list,)) defer.returnValue(_StateCacheEntry(
state=state_list,
state_group=name,
))
with (yield self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None: if self._state_cache is not None:
cache = self._state_cache.get(group_names, None) cache = self._state_cache.get(group_names, None)
if cache: if cache:
cache.ts = self.clock.time_msec() defer.returnValue(cache)
defer.returnValue(
(cache.state_group, cache.state,)
)
logger.info( logger.info(
"Resolving state for %s with %d groups", room_id, len(state_groups_ids) "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
@ -284,17 +333,22 @@ class StateHandler(object):
if new_state_event_ids == frozenset(e_id for e_id in events): if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg state_group = sg
break break
if state_group is None:
# Worker instances don't have access to this method, but we want
# to set the state_group on the main instance to increase cache
# hits.
if hasattr(self.store, "get_next_state_group"):
state_group = self.store.get_next_state_group()
if self._state_cache is not None:
cache = _StateCacheEntry( cache = _StateCacheEntry(
state=new_state, state=new_state,
state_group=state_group, state_group=state_group,
ts=self.clock.time_msec()
) )
if self._state_cache is not None:
self._state_cache[group_names] = cache self._state_cache[group_names] = cache
defer.returnValue((state_group, new_state,)) defer.returnValue(cache)
def resolve_events(self, state_sets, event): def resolve_events(self, state_sets, event):
logger.info( logger.info(

View File

@ -115,7 +115,7 @@ class DataStore(RoomMemberStore, RoomStore,
) )
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")

View File

@ -271,22 +271,11 @@ class EventsStore(SQLBaseStore):
len(events_and_contexts) len(events_and_contexts)
) )
state_group_id_manager = self._state_groups_id_gen.get_next_mult(
len(events_and_contexts)
)
with stream_ordering_manager as stream_orderings: with stream_ordering_manager as stream_orderings:
with state_group_id_manager as state_group_ids: for (event, context), stream, in zip(
for (event, context), stream, state_group_id in zip( events_and_contexts, stream_orderings
events_and_contexts, stream_orderings, state_group_ids
): ):
event.internal_metadata.stream_ordering = stream event.internal_metadata.stream_ordering = stream
# Assign a state group_id in case a new id is needed for
# this context. In theory we only need to assign this
# for contexts that have current_state and aren't outliers
# but that make the code more complicated. Assigning an ID
# per event only causes the state_group_ids to grow as fast
# as the stream_ordering so in practise shouldn't be a problem.
context.new_state_group_id = state_group_id
chunks = [ chunks = [
events_and_contexts[x:x + 100] events_and_contexts[x:x + 100]
@ -312,9 +301,7 @@ class EventsStore(SQLBaseStore):
delete_existing=False): delete_existing=False):
try: try:
with self._stream_id_gen.get_next() as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
with self._state_groups_id_gen.get_next() as state_group_id:
event.internal_metadata.stream_ordering = stream_ordering event.internal_metadata.stream_ordering = stream_ordering
context.new_state_group_id = state_group_id
yield self.runInteraction( yield self.runInteraction(
"persist_event", "persist_event",
self._persist_event_txn, self._persist_event_txn,
@ -528,7 +515,7 @@ class EventsStore(SQLBaseStore):
# Add an entry to the ex_outlier_stream table to replicate the # Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers. # change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id state_group_id = context.state_group
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="ex_outlier_stream", table="ex_outlier_stream",

View File

@ -145,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res]) defer.returnValue([ev for res in results.values() for ev in res])
@cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True) @cachedInlineCallbacks(num_args=3, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients. """Get receipts for a single room for sending to clients.

View File

@ -354,7 +354,8 @@ class RoomMemberStore(SQLBaseStore):
desc="who_forgot" desc="who_forgot"
) )
def get_joined_users_from_context(self, room_id, state_group, state_ids): def get_joined_users_from_context(self, event, context):
state_group = context.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group # state group, i.e. we need to make sure that calls with a state_group
@ -363,12 +364,24 @@ class RoomMemberStore(SQLBaseStore):
state_group = object() state_group = object()
return self._get_joined_users_from_context( return self._get_joined_users_from_context(
room_id, state_group, state_ids event.room_id, state_group, context.current_state_ids, event=event,
)
def get_joined_users_from_state(self, room_id, state_group, state_ids):
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._get_joined_users_from_context(
room_id, state_group, state_ids,
) )
@cachedInlineCallbacks(num_args=2, cache_context=True) @cachedInlineCallbacks(num_args=2, cache_context=True)
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
cache_context): cache_context, event=None):
# We don't use `state_group`, its there so that we can cache based # We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's # on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different. # with a state_group of None are likely to be different.
@ -393,7 +406,13 @@ class RoomMemberStore(SQLBaseStore):
desc="_get_joined_users_from_context", desc="_get_joined_users_from_context",
) )
defer.returnValue(set(row["user_id"] for row in rows)) users_in_room = set(row["user_id"] for row in rows)
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
if event.event_id in member_event_ids:
users_in_room.add(event.state_key)
defer.returnValue(users_in_room)
def is_host_joined(self, room_id, host, state_group, state_ids): def is_host_joined(self, room_id, host, state_group, state_ids):
if not state_group: if not state_group:

View File

@ -0,0 +1,32 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.engines import PostgresEngine
import logging
logger = logging.getLogger(__name__)
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
cur.execute("TRUNCATE sent_transactions")
else:
cur.execute("DELETE FROM sent_transactions")
cur.execute("CREATE INDEX sent_transactions_ts ON sent_transactions(ts)")
def run_upgrade(cur, database_engine, *args, **kwargs):
pass

View File

@ -83,6 +83,14 @@ class StateStore(SQLBaseStore):
for group, event_id_map in group_to_ids.items() for group, event_id_map in group_to_ids.items()
}) })
def _have_persisted_state_group_txn(self, txn, state_group):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts): def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {} state_groups = {}
for event, context in events_and_contexts: for event, context in events_and_contexts:
@ -92,22 +100,19 @@ class StateStore(SQLBaseStore):
if context.current_state_ids is None: if context.current_state_ids is None:
continue continue
if context.state_group is not None:
state_groups[event.event_id] = context.state_group state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
logger.info("Already persisted state_group: %r", context.state_group)
continue continue
state_event_ids = dict(context.current_state_ids) state_event_ids = dict(context.current_state_ids)
if event.is_state():
state_event_ids[(event.type, event.state_key)] = event.event_id
state_group = context.new_state_group_id
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="state_groups", table="state_groups",
values={ values={
"id": state_group, "id": context.state_group,
"room_id": event.room_id, "room_id": event.room_id,
"event_id": event.event_id, "event_id": event.event_id,
}, },
@ -118,7 +123,7 @@ class StateStore(SQLBaseStore):
table="state_groups_state", table="state_groups_state",
values=[ values=[
{ {
"state_group": state_group, "state_group": context.state_group,
"room_id": event.room_id, "room_id": event.room_id,
"type": key[0], "type": key[0],
"state_key": key[1], "state_key": key[1],
@ -127,7 +132,6 @@ class StateStore(SQLBaseStore):
for key, state_id in state_event_ids.items() for key, state_id in state_event_ids.items()
], ],
) )
state_groups[event.event_id] = state_group
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
@ -527,5 +531,5 @@ class StateStore(SQLBaseStore):
"get_all_new_state_groups", get_all_new_state_groups_txn "get_all_new_state_groups", get_all_new_state_groups_txn
) )
def get_state_stream_token(self): def get_next_state_group(self):
return self._state_groups_id_gen.get_current_token() return self._state_groups_id_gen.get_next()

View File

@ -387,8 +387,10 @@ class TransactionStore(SQLBaseStore):
def _cleanup_transactions(self): def _cleanup_transactions(self):
now = self._clock.time_msec() now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000 month_ago = now - 30 * 24 * 60 * 60 * 1000
six_hours_ago = now - 6 * 60 * 60 * 1000
def _cleanup_transactions_txn(txn): def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
txn.execute("DELETE FROM sent_transactions WHERE ts < ?", (six_hours_ago,))
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn) return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)

View File

@ -115,6 +115,53 @@ class PresenceUpdateTestCase(unittest.TestCase):
), ),
], any_order=True) ], any_order=True)
def test_online_to_online_last_active_noop(self):
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
prev_state = UserPresenceState.default(user_id)
prev_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE,
last_active_ts=now - LAST_ACTIVE_GRANULARITY - 10,
currently_active=True,
)
new_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE,
last_active_ts=now,
)
state, persist_and_notify, federation_ping = handle_update(
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
)
self.assertFalse(persist_and_notify)
self.assertTrue(federation_ping)
self.assertTrue(state.currently_active)
self.assertEquals(new_state.state, state.state)
self.assertEquals(new_state.status_msg, state.status_msg)
self.assertEquals(state.last_federation_update_ts, now)
self.assertEquals(wheel_timer.insert.call_count, 3)
wheel_timer.insert.assert_has_calls([
call(
now=now,
obj=user_id,
then=new_state.last_active_ts + IDLE_TIMER
),
call(
now=now,
obj=user_id,
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
),
call(
now=now,
obj=user_id,
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
),
], any_order=True)
def test_online_to_online_last_active(self): def test_online_to_online_last_active(self):
wheel_timer = Mock() wheel_timer = Mock()
user_id = "@foo:bar" user_id = "@foo:bar"

View File

@ -312,7 +312,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
else: else:
state_ids = None state_ids = None
context = EventContext(current_state_ids=state_ids) context = EventContext()
context.current_state_ids = state_ids
context.prev_state_ids = state_ids
context.push_actions = push_actions context.push_actions = push_actions
ordering = None ordering = None

View File

@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase):
self.assertEquals(body, {}) self.assertEquals(body, {})
@defer.inlineCallbacks @defer.inlineCallbacks
def test_events_and_state(self): def test_events(self):
get = self.get(events="-1", state="-1", timeout="0") get = self.get(events="-1", timeout="0")
yield self.hs.get_handlers().room_creation_handler.create_room( yield self.hs.get_handlers().room_creation_handler.create_room(
synapse.types.create_requester(self.user), {} synapse.types.create_requester(self.user), {}
) )
@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase):
self.assertEquals(body["events"]["field_names"], [ self.assertEquals(body["events"]["field_names"], [
"position", "internal", "json", "state_group" "position", "internal", "json", "state_group"
]) ])
self.assertEquals(body["state_groups"]["field_names"], [
"position", "room_id", "event_id"
])
self.assertEquals(body["state_group_state"]["field_names"], [
"position", "type", "state_key", "event_id"
])
@defer.inlineCallbacks @defer.inlineCallbacks
def test_presence(self): def test_presence(self):

View File

@ -86,17 +86,8 @@ class StateGroupStore(object):
state_events = dict(context.current_state_ids) state_events = dict(context.current_state_ids)
if event.is_state(): self._group_to_state[context.state_group] = state_events
state_events[(event.type, event.state_key)] = event.event_id self._event_to_state_group[event.event_id] = context.state_group
state_group = context.state_group
if not state_group:
state_group = self._next_group
self._next_group += 1
self._group_to_state[state_group] = state_events
self._event_to_state_group[event.event_id] = state_group
def get_events(self, event_ids, **kwargs): def get_events(self, event_ids, **kwargs):
return { return {
@ -151,6 +142,7 @@ class StateTestCase(unittest.TestCase):
"get_state_groups_ids", "get_state_groups_ids",
"add_event_hashes", "add_event_hashes",
"get_events", "get_events",
"get_next_state_group",
] ]
) )
hs = Mock(spec_set=[ hs = Mock(spec_set=[
@ -161,6 +153,8 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs) hs.get_auth.return_value = Auth(hs)
self.store.get_next_state_group.side_effect = Mock
self.state = StateHandler(hs) self.state = StateHandler(hs)
self.event_id = 0 self.event_id = 0
@ -209,7 +203,7 @@ class StateTestCase(unittest.TestCase):
store.store_state_groups(event, context) store.store_state_groups(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].current_state_ids)) self.assertEqual(2, len(context_store["D"].prev_state_ids))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_branch_basic_conflict(self): def test_branch_basic_conflict(self):
@ -265,7 +259,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "C"}, {"START", "A", "C"},
{e_id for e_id in context_store["D"].current_state_ids.values()} {e_id for e_id in context_store["D"].prev_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -331,7 +325,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "B", "C"}, {"START", "A", "B", "C"},
{e for e in context_store["E"].current_state_ids.values()} {e for e in context_store["E"].prev_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -414,7 +408,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"}, {"A1", "A2", "A3", "A5", "B"},
{e for e in context_store["D"].current_state_ids.values()} {e for e in context_store["D"].prev_state_ids.values()}
) )
def _add_depths(self, nodes, edges): def _add_depths(self, nodes, edges):
@ -447,7 +441,7 @@ class StateTestCase(unittest.TestCase):
set(e.event_id for e in old_state), set(context.current_state_ids.values()) set(e.event_id for e in old_state), set(context.current_state_ids.values())
) )
self.assertIsNone(context.state_group) self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_state(self): def test_annotate_with_old_state(self):
@ -464,11 +458,9 @@ class StateTestCase(unittest.TestCase):
) )
self.assertEqual( self.assertEqual(
set(e.event_id for e in old_state), set(context.current_state_ids.values()) set(e.event_id for e in old_state), set(context.prev_state_ids.values())
) )
self.assertIsNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_message(self): def test_trivial_annotate_message(self):
event = create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")
@ -514,10 +506,10 @@ class StateTestCase(unittest.TestCase):
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]),
set(context.current_state_ids.values()) set(context.prev_state_ids.values())
) )
self.assertIsNone(context.state_group) self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_message_conflict(self): def test_resolve_message_conflict(self):
@ -550,7 +542,7 @@ class StateTestCase(unittest.TestCase):
self.assertEqual(len(context.current_state_ids), 6) self.assertEqual(len(context.current_state_ids), 6)
self.assertIsNone(context.state_group) self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_state_conflict(self): def test_resolve_state_conflict(self):
@ -583,7 +575,7 @@ class StateTestCase(unittest.TestCase):
self.assertEqual(len(context.current_state_ids), 6) self.assertEqual(len(context.current_state_ids), 6)
self.assertIsNone(context.state_group) self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_standard_depth_conflict(self): def test_standard_depth_conflict(self):