Replace context.current_state with context.current_state_ids
This commit is contained in:
parent
17f4f14df7
commit
a3dc1e9cbe
|
@ -52,7 +52,7 @@ class Auth(object):
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
||||||
# Docs for these currently lives at
|
# Docs for these currently lives at
|
||||||
# https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
|
# github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
|
||||||
# In addition, we have type == delete_pusher which grants access only to
|
# In addition, we have type == delete_pusher which grants access only to
|
||||||
# delete pushers.
|
# delete pushers.
|
||||||
self._KNOWN_CAVEAT_PREFIXES = set([
|
self._KNOWN_CAVEAT_PREFIXES = set([
|
||||||
|
@ -63,6 +63,17 @@ class Auth(object):
|
||||||
"user_id = ",
|
"user_id = ",
|
||||||
])
|
])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def check_from_context(self, event, context, do_sig_check=True):
|
||||||
|
auth_events_ids = yield self.compute_auth_events(
|
||||||
|
event, context.current_state_ids, for_verification=True,
|
||||||
|
)
|
||||||
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
|
auth_events = {
|
||||||
|
(e.type, e.state_key): e for e in auth_events.values()
|
||||||
|
}
|
||||||
|
self.check(event, auth_events=auth_events, do_sig_check=False)
|
||||||
|
|
||||||
def check(self, event, auth_events, do_sig_check=True):
|
def check(self, event, auth_events, do_sig_check=True):
|
||||||
""" Checks if this event is correctly authed.
|
""" Checks if this event is correctly authed.
|
||||||
|
|
||||||
|
@ -847,7 +858,7 @@ class Auth(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_auth_events(self, builder, context):
|
def add_auth_events(self, builder, context):
|
||||||
auth_ids = self.compute_auth_events(builder, context.current_state)
|
auth_ids = yield self.compute_auth_events(builder, context.current_state_ids)
|
||||||
|
|
||||||
auth_events_entries = yield self.store.add_event_hashes(
|
auth_events_entries = yield self.store.add_event_hashes(
|
||||||
auth_ids
|
auth_ids
|
||||||
|
@ -855,30 +866,32 @@ class Auth(object):
|
||||||
|
|
||||||
builder.auth_events = auth_events_entries
|
builder.auth_events = auth_events_entries
|
||||||
|
|
||||||
def compute_auth_events(self, event, current_state):
|
@defer.inlineCallbacks
|
||||||
|
def compute_auth_events(self, event, current_state_ids, for_verification=False):
|
||||||
if event.type == EventTypes.Create:
|
if event.type == EventTypes.Create:
|
||||||
return []
|
defer.returnValue([])
|
||||||
|
|
||||||
auth_ids = []
|
auth_ids = []
|
||||||
|
|
||||||
key = (EventTypes.PowerLevels, "", )
|
key = (EventTypes.PowerLevels, "", )
|
||||||
power_level_event = current_state.get(key)
|
power_level_event_id = current_state_ids.get(key)
|
||||||
|
|
||||||
if power_level_event:
|
if power_level_event_id:
|
||||||
auth_ids.append(power_level_event.event_id)
|
auth_ids.append(power_level_event_id)
|
||||||
|
|
||||||
key = (EventTypes.JoinRules, "", )
|
key = (EventTypes.JoinRules, "", )
|
||||||
join_rule_event = current_state.get(key)
|
join_rule_event_id = current_state_ids.get(key)
|
||||||
|
|
||||||
key = (EventTypes.Member, event.user_id, )
|
key = (EventTypes.Member, event.user_id, )
|
||||||
member_event = current_state.get(key)
|
member_event_id = current_state_ids.get(key)
|
||||||
|
|
||||||
key = (EventTypes.Create, "", )
|
key = (EventTypes.Create, "", )
|
||||||
create_event = current_state.get(key)
|
create_event_id = current_state_ids.get(key)
|
||||||
if create_event:
|
if create_event_id:
|
||||||
auth_ids.append(create_event.event_id)
|
auth_ids.append(create_event_id)
|
||||||
|
|
||||||
if join_rule_event:
|
if join_rule_event_id:
|
||||||
|
join_rule_event = yield self.store.get_event(join_rule_event_id)
|
||||||
join_rule = join_rule_event.content.get("join_rule")
|
join_rule = join_rule_event.content.get("join_rule")
|
||||||
is_public = join_rule == JoinRules.PUBLIC if join_rule else False
|
is_public = join_rule == JoinRules.PUBLIC if join_rule else False
|
||||||
else:
|
else:
|
||||||
|
@ -887,15 +900,21 @@ class Auth(object):
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
e_type = event.content["membership"]
|
e_type = event.content["membership"]
|
||||||
if e_type in [Membership.JOIN, Membership.INVITE]:
|
if e_type in [Membership.JOIN, Membership.INVITE]:
|
||||||
if join_rule_event:
|
if join_rule_event_id:
|
||||||
auth_ids.append(join_rule_event.event_id)
|
auth_ids.append(join_rule_event_id)
|
||||||
|
|
||||||
if e_type == Membership.JOIN:
|
if e_type == Membership.JOIN:
|
||||||
if member_event and not is_public:
|
if member_event_id and not is_public:
|
||||||
auth_ids.append(member_event.event_id)
|
auth_ids.append(member_event_id)
|
||||||
else:
|
else:
|
||||||
if member_event:
|
if member_event_id:
|
||||||
auth_ids.append(member_event.event_id)
|
auth_ids.append(member_event_id)
|
||||||
|
|
||||||
|
if for_verification:
|
||||||
|
key = (EventTypes.Member, event.state_key, )
|
||||||
|
existing_event_id = current_state_ids.get(key)
|
||||||
|
if existing_event_id:
|
||||||
|
auth_ids.append(existing_event_id)
|
||||||
|
|
||||||
if e_type == Membership.INVITE:
|
if e_type == Membership.INVITE:
|
||||||
if "third_party_invite" in event.content:
|
if "third_party_invite" in event.content:
|
||||||
|
@ -903,14 +922,15 @@ class Auth(object):
|
||||||
EventTypes.ThirdPartyInvite,
|
EventTypes.ThirdPartyInvite,
|
||||||
event.content["third_party_invite"]["signed"]["token"]
|
event.content["third_party_invite"]["signed"]["token"]
|
||||||
)
|
)
|
||||||
third_party_invite = current_state.get(key)
|
third_party_invite_id = current_state_ids.get(key)
|
||||||
if third_party_invite:
|
if third_party_invite_id:
|
||||||
auth_ids.append(third_party_invite.event_id)
|
auth_ids.append(third_party_invite_id)
|
||||||
elif member_event:
|
elif member_event_id:
|
||||||
|
member_event = yield self.store.get_event(member_event_id)
|
||||||
if member_event.content["membership"] == Membership.JOIN:
|
if member_event.content["membership"] == Membership.JOIN:
|
||||||
auth_ids.append(member_event.event_id)
|
auth_ids.append(member_event.event_id)
|
||||||
|
|
||||||
return auth_ids
|
defer.returnValue(auth_ids)
|
||||||
|
|
||||||
def _get_send_level(self, etype, state_key, auth_events):
|
def _get_send_level(self, etype, state_key, auth_events):
|
||||||
key = (EventTypes.PowerLevels, "", )
|
key = (EventTypes.PowerLevels, "", )
|
||||||
|
|
|
@ -15,17 +15,8 @@
|
||||||
|
|
||||||
|
|
||||||
class EventContext(object):
|
class EventContext(object):
|
||||||
def _set_current_state(self, current_state):
|
def __init__(self, current_state_ids=None):
|
||||||
if current_state is not None:
|
self.current_state_ids = current_state_ids
|
||||||
self.current_state_ids = {k: e.event_id for k, e in current_state.items()}
|
|
||||||
else:
|
|
||||||
self.current_state_ids = None
|
|
||||||
self._current_state = current_state
|
|
||||||
|
|
||||||
current_state = property(lambda self: self._current_state, _set_current_state)
|
|
||||||
|
|
||||||
def __init__(self, current_state=None):
|
|
||||||
self.current_state = current_state
|
|
||||||
self.state_group = None
|
self.state_group = None
|
||||||
self.rejected = False
|
self.rejected = False
|
||||||
self.push_actions = []
|
self.push_actions = []
|
||||||
|
|
|
@ -65,33 +65,21 @@ class BaseHandler(object):
|
||||||
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_host_in_room(self, current_state):
|
|
||||||
room_members = [
|
|
||||||
(state_key, event.membership)
|
|
||||||
for ((event_type, state_key), event) in current_state.items()
|
|
||||||
if event_type == EventTypes.Member
|
|
||||||
]
|
|
||||||
if len(room_members) == 0:
|
|
||||||
# Have we just created the room, and is this about to be the very
|
|
||||||
# first member event?
|
|
||||||
create_event = current_state.get(("m.room.create", ""))
|
|
||||||
if create_event:
|
|
||||||
return True
|
|
||||||
for (state_key, membership) in room_members:
|
|
||||||
if (
|
|
||||||
self.hs.is_mine_id(state_key)
|
|
||||||
and membership == Membership.JOIN
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def maybe_kick_guest_users(self, event, current_state):
|
def maybe_kick_guest_users(self, event, context=None):
|
||||||
# Technically this function invalidates current_state by changing it.
|
# Technically this function invalidates current_state by changing it.
|
||||||
# Hopefully this isn't that important to the caller.
|
# Hopefully this isn't that important to the caller.
|
||||||
if event.type == EventTypes.GuestAccess:
|
if event.type == EventTypes.GuestAccess:
|
||||||
guest_access = event.content.get("guest_access", "forbidden")
|
guest_access = event.content.get("guest_access", "forbidden")
|
||||||
if guest_access != "can_join":
|
if guest_access != "can_join":
|
||||||
|
if context:
|
||||||
|
current_state = yield self.store.get_events(
|
||||||
|
context.current_state_ids.values()
|
||||||
|
)
|
||||||
|
current_state = current_state.values()
|
||||||
|
else:
|
||||||
|
current_state = yield self.store.get_current_state(event.room_id)
|
||||||
|
logger.info("maybe_kick_guest_users %r", current_state)
|
||||||
yield self.kick_guest_users(current_state)
|
yield self.kick_guest_users(current_state)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -217,11 +217,21 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
prev_state = context.current_state.get((event.type, event.state_key))
|
# Only fire user_joined_room if the user has acutally
|
||||||
if not prev_state or prev_state.membership != Membership.JOIN:
|
# joined the room. Don't bother if the user is just
|
||||||
# Only fire user_joined_room if the user has acutally
|
# changing their profile info.
|
||||||
# joined the room. Don't bother if the user is just
|
newly_joined = True
|
||||||
# changing their profile info.
|
prev_state_id = context.current_state_ids.get(
|
||||||
|
(event.type, event.state_key)
|
||||||
|
)
|
||||||
|
if prev_state_id:
|
||||||
|
prev_state = yield self.store.get_event(
|
||||||
|
prev_state_id, allow_none=True,
|
||||||
|
)
|
||||||
|
if prev_state and prev_state.membership == Membership.JOIN:
|
||||||
|
newly_joined = False
|
||||||
|
|
||||||
|
if newly_joined:
|
||||||
user = UserID.from_string(event.state_key)
|
user = UserID.from_string(event.state_key)
|
||||||
yield user_joined_room(self.distributor, user, event.room_id)
|
yield user_joined_room(self.distributor, user, event.room_id)
|
||||||
|
|
||||||
|
@ -734,7 +744,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||||
# when we get the event back in `on_send_join_request`
|
# when we get the event back in `on_send_join_request`
|
||||||
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
yield self.auth.check_from_context(event, context, do_sig_check=False)
|
||||||
|
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
|
|
||||||
|
@ -782,18 +792,11 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
new_pdu = event
|
new_pdu = event
|
||||||
|
|
||||||
destinations = set()
|
message_handler = self.hs.get_handlers().message_handler
|
||||||
|
destinations = yield message_handler.get_joined_hosts_for_room_from_state(
|
||||||
for k, s in context.current_state.items():
|
context
|
||||||
try:
|
)
|
||||||
if k[0] == EventTypes.Member:
|
destinations = set(destinations)
|
||||||
if s.content["membership"] == Membership.JOIN:
|
|
||||||
destinations.add(get_domain_from_id(s.state_key))
|
|
||||||
except:
|
|
||||||
logger.warn(
|
|
||||||
"Failed to get destination from event %s", s.event_id
|
|
||||||
)
|
|
||||||
|
|
||||||
destinations.discard(origin)
|
destinations.discard(origin)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -804,13 +807,15 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
self.replication_layer.send_pdu(new_pdu, destinations)
|
self.replication_layer.send_pdu(new_pdu, destinations)
|
||||||
|
|
||||||
state_ids = [e.event_id for e in context.current_state.values()]
|
state_ids = context.current_state_ids.values()
|
||||||
auth_chain = yield self.store.get_auth_chain(set(
|
auth_chain = yield self.store.get_auth_chain(set(
|
||||||
[event.event_id] + state_ids
|
[event.event_id] + state_ids
|
||||||
))
|
))
|
||||||
|
|
||||||
|
state = yield self.store.get_events(context.current_state_ids.values())
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"state": context.current_state.values(),
|
"state": state.values(),
|
||||||
"auth_chain": auth_chain,
|
"auth_chain": auth_chain,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -966,7 +971,7 @@ class FederationHandler(BaseHandler):
|
||||||
try:
|
try:
|
||||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||||
# when we get the event back in `on_send_leave_request`
|
# when we get the event back in `on_send_leave_request`
|
||||||
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
yield self.auth.check_from_context(event, context, do_sig_check=False)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warn("Failed to create new leave %r because %s", event, e)
|
logger.warn("Failed to create new leave %r because %s", event, e)
|
||||||
raise e
|
raise e
|
||||||
|
@ -1010,18 +1015,11 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
new_pdu = event
|
new_pdu = event
|
||||||
|
|
||||||
destinations = set()
|
message_handler = self.hs.get_handlers().message_handler
|
||||||
|
destinations = yield message_handler.get_joined_hosts_for_room_from_state(
|
||||||
for k, s in context.current_state.items():
|
context
|
||||||
try:
|
)
|
||||||
if k[0] == EventTypes.Member:
|
destinations = set(destinations)
|
||||||
if s.content["membership"] == Membership.LEAVE:
|
|
||||||
destinations.add(get_domain_from_id(s.state_key))
|
|
||||||
except:
|
|
||||||
logger.warn(
|
|
||||||
"Failed to get destination from event %s", s.event_id
|
|
||||||
)
|
|
||||||
|
|
||||||
destinations.discard(origin)
|
destinations.discard(origin)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -1306,7 +1304,13 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
if not auth_events:
|
if not auth_events:
|
||||||
auth_events = context.current_state
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
|
event, context.current_state_ids, for_verification=True,
|
||||||
|
)
|
||||||
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
|
auth_events = {
|
||||||
|
(e.type, e.state_key): e for e in auth_events.values()
|
||||||
|
}
|
||||||
|
|
||||||
# This is a hack to fix some old rooms where the initial join event
|
# This is a hack to fix some old rooms where the initial join event
|
||||||
# didn't reference the create event in its auth events.
|
# didn't reference the create event in its auth events.
|
||||||
|
@ -1332,8 +1336,7 @@ class FederationHandler(BaseHandler):
|
||||||
context.rejected = RejectedReason.AUTH_ERROR
|
context.rejected = RejectedReason.AUTH_ERROR
|
||||||
|
|
||||||
if event.type == EventTypes.GuestAccess:
|
if event.type == EventTypes.GuestAccess:
|
||||||
full_context = yield self.store.get_current_state(room_id=event.room_id)
|
yield self.maybe_kick_guest_users(event)
|
||||||
yield self.maybe_kick_guest_users(event, full_context)
|
|
||||||
|
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
|
@ -1504,7 +1507,9 @@ class FederationHandler(BaseHandler):
|
||||||
current_state = set(e.event_id for e in auth_events.values())
|
current_state = set(e.event_id for e in auth_events.values())
|
||||||
different_auth = event_auth_events - current_state
|
different_auth = event_auth_events - current_state
|
||||||
|
|
||||||
context.current_state.update(auth_events)
|
context.current_state_ids.update({
|
||||||
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
})
|
||||||
context.state_group = None
|
context.state_group = None
|
||||||
|
|
||||||
if different_auth and not event.internal_metadata.is_outlier():
|
if different_auth and not event.internal_metadata.is_outlier():
|
||||||
|
@ -1526,8 +1531,8 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
if do_resolution:
|
if do_resolution:
|
||||||
# 1. Get what we think is the auth chain.
|
# 1. Get what we think is the auth chain.
|
||||||
auth_ids = self.auth.compute_auth_events(
|
auth_ids = yield self.auth.compute_auth_events(
|
||||||
event, context.current_state
|
event, context.current_state_ids
|
||||||
)
|
)
|
||||||
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
|
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
|
||||||
|
|
||||||
|
@ -1583,7 +1588,9 @@ class FederationHandler(BaseHandler):
|
||||||
# 4. Look at rejects and their proofs.
|
# 4. Look at rejects and their proofs.
|
||||||
# TODO.
|
# TODO.
|
||||||
|
|
||||||
context.current_state.update(auth_events)
|
context.current_state_ids.update({
|
||||||
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
})
|
||||||
context.state_group = None
|
context.state_group = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -1770,12 +1777,12 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(event, context.current_state)
|
yield self.auth.check_from_context(event, context)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warn("Denying new third party invite %r because %s", event, e)
|
logger.warn("Denying new third party invite %r because %s", event, e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
yield self._check_signature(event, auth_events=context.current_state)
|
yield self._check_signature(event, context)
|
||||||
member_handler = self.hs.get_handlers().room_member_handler
|
member_handler = self.hs.get_handlers().room_member_handler
|
||||||
yield member_handler.send_membership_event(None, event, context)
|
yield member_handler.send_membership_event(None, event, context)
|
||||||
else:
|
else:
|
||||||
|
@ -1801,11 +1808,11 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(event, auth_events=context.current_state)
|
self.auth.check_from_context(event, context)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warn("Denying third party invite %r because %s", event, e)
|
logger.warn("Denying third party invite %r because %s", event, e)
|
||||||
raise e
|
raise e
|
||||||
yield self._check_signature(event, auth_events=context.current_state)
|
yield self._check_signature(event, context)
|
||||||
|
|
||||||
returned_invite = yield self.send_invite(origin, event)
|
returned_invite = yield self.send_invite(origin, event)
|
||||||
# TODO: Make sure the signatures actually are correct.
|
# TODO: Make sure the signatures actually are correct.
|
||||||
|
@ -1819,7 +1826,12 @@ class FederationHandler(BaseHandler):
|
||||||
EventTypes.ThirdPartyInvite,
|
EventTypes.ThirdPartyInvite,
|
||||||
event.content["third_party_invite"]["signed"]["token"]
|
event.content["third_party_invite"]["signed"]["token"]
|
||||||
)
|
)
|
||||||
original_invite = context.current_state.get(key)
|
original_invite = None
|
||||||
|
original_invite_id = context.current_state_ids.get(key)
|
||||||
|
if original_invite_id:
|
||||||
|
original_invite = yield self.store.get_event(
|
||||||
|
original_invite_id, allow_none=True
|
||||||
|
)
|
||||||
if not original_invite:
|
if not original_invite:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Could not find invite event for third_party_invite - "
|
"Could not find invite event for third_party_invite - "
|
||||||
|
@ -1836,13 +1848,13 @@ class FederationHandler(BaseHandler):
|
||||||
defer.returnValue((event, context))
|
defer.returnValue((event, context))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_signature(self, event, auth_events):
|
def _check_signature(self, event, context):
|
||||||
"""
|
"""
|
||||||
Checks that the signature in the event is consistent with its invite.
|
Checks that the signature in the event is consistent with its invite.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (Event): The m.room.member event to check
|
event (Event): The m.room.member event to check
|
||||||
auth_events (dict<(event type, state_key), event>):
|
context (EventContext):
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AuthError: if signature didn't match any keys, or key has been
|
AuthError: if signature didn't match any keys, or key has been
|
||||||
|
@ -1853,10 +1865,14 @@ class FederationHandler(BaseHandler):
|
||||||
signed = event.content["third_party_invite"]["signed"]
|
signed = event.content["third_party_invite"]["signed"]
|
||||||
token = signed["token"]
|
token = signed["token"]
|
||||||
|
|
||||||
invite_event = auth_events.get(
|
invite_event_id = context.current_state_ids.get(
|
||||||
(EventTypes.ThirdPartyInvite, token,)
|
(EventTypes.ThirdPartyInvite, token,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
invite_event = None
|
||||||
|
if invite_event_id:
|
||||||
|
invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
|
||||||
|
|
||||||
if not invite_event:
|
if not invite_event:
|
||||||
raise AuthError(403, "Could not find invite")
|
raise AuthError(403, "Could not find invite")
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo
|
||||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
@ -248,7 +249,7 @@ class MessageHandler(BaseHandler):
|
||||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
prev_state = self.deduplicate_state_event(event, context)
|
prev_state = yield self.deduplicate_state_event(event, context)
|
||||||
if prev_state is not None:
|
if prev_state is not None:
|
||||||
defer.returnValue(prev_state)
|
defer.returnValue(prev_state)
|
||||||
|
|
||||||
|
@ -263,6 +264,7 @@ class MessageHandler(BaseHandler):
|
||||||
presence = self.hs.get_presence_handler()
|
presence = self.hs.get_presence_handler()
|
||||||
yield presence.bump_presence_active_time(user)
|
yield presence.bump_presence_active_time(user)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def deduplicate_state_event(self, event, context):
|
def deduplicate_state_event(self, event, context):
|
||||||
"""
|
"""
|
||||||
Checks whether event is in the latest resolved state in context.
|
Checks whether event is in the latest resolved state in context.
|
||||||
|
@ -270,13 +272,17 @@ class MessageHandler(BaseHandler):
|
||||||
If so, returns the version of the event in context.
|
If so, returns the version of the event in context.
|
||||||
Otherwise, returns None.
|
Otherwise, returns None.
|
||||||
"""
|
"""
|
||||||
prev_event = context.current_state.get((event.type, event.state_key))
|
prev_event_id = context.current_state_ids.get((event.type, event.state_key))
|
||||||
|
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
||||||
|
if not prev_event:
|
||||||
|
return
|
||||||
|
|
||||||
if prev_event and event.user_id == prev_event.user_id:
|
if prev_event and event.user_id == prev_event.user_id:
|
||||||
prev_content = encode_canonical_json(prev_event.content)
|
prev_content = encode_canonical_json(prev_event.content)
|
||||||
next_content = encode_canonical_json(event.content)
|
next_content = encode_canonical_json(event.content)
|
||||||
if prev_content == next_content:
|
if prev_content == next_content:
|
||||||
return prev_event
|
defer.returnValue(prev_event)
|
||||||
return None
|
return
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_and_send_nonmember_event(
|
def create_and_send_nonmember_event(
|
||||||
|
@ -803,7 +809,7 @@ class MessageHandler(BaseHandler):
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Created event %s with current state: %s",
|
"Created event %s with current state: %s",
|
||||||
event.event_id, context.current_state,
|
event.event_id, context.current_state_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
|
@ -826,12 +832,12 @@ class MessageHandler(BaseHandler):
|
||||||
self.ratelimit(requester)
|
self.ratelimit(requester)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(event, auth_events=context.current_state)
|
yield self.auth.check_from_context(event, context)
|
||||||
except AuthError as err:
|
except AuthError as err:
|
||||||
logger.warn("Denying new event %r because %s", event, err)
|
logger.warn("Denying new event %r because %s", event, err)
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
yield self.maybe_kick_guest_users(event, context.current_state.values())
|
yield self.maybe_kick_guest_users(event, context)
|
||||||
|
|
||||||
if event.type == EventTypes.CanonicalAlias:
|
if event.type == EventTypes.CanonicalAlias:
|
||||||
# Check the alias is acually valid (at this time at least)
|
# Check the alias is acually valid (at this time at least)
|
||||||
|
@ -859,6 +865,15 @@ class MessageHandler(BaseHandler):
|
||||||
e.sender == event.sender
|
e.sender == event.sender
|
||||||
)
|
)
|
||||||
|
|
||||||
|
state_to_include_ids = [
|
||||||
|
e_id
|
||||||
|
for k, e_id in context.current_state_ids.items()
|
||||||
|
if k[0] in self.hs.config.room_invite_state_types
|
||||||
|
or k[0] == EventTypes.Member and k[1] == event.sender
|
||||||
|
]
|
||||||
|
|
||||||
|
state_to_include = yield self.store.get_events(state_to_include_ids)
|
||||||
|
|
||||||
event.unsigned["invite_room_state"] = [
|
event.unsigned["invite_room_state"] = [
|
||||||
{
|
{
|
||||||
"type": e.type,
|
"type": e.type,
|
||||||
|
@ -866,9 +881,7 @@ class MessageHandler(BaseHandler):
|
||||||
"content": e.content,
|
"content": e.content,
|
||||||
"sender": e.sender,
|
"sender": e.sender,
|
||||||
}
|
}
|
||||||
for k, e in context.current_state.items()
|
for e in state_to_include.values()
|
||||||
if e.type in self.hs.config.room_invite_state_types
|
|
||||||
or is_inviter_member_event(e)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
invitee = UserID.from_string(event.state_key)
|
invitee = UserID.from_string(event.state_key)
|
||||||
|
@ -890,7 +903,14 @@ class MessageHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type == EventTypes.Redaction:
|
if event.type == EventTypes.Redaction:
|
||||||
if self.auth.check_redaction(event, auth_events=context.current_state):
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
|
event, context.current_state_ids, for_verification=True,
|
||||||
|
)
|
||||||
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
|
auth_events = {
|
||||||
|
(e.type, e.state_key): e for e in auth_events.values()
|
||||||
|
}
|
||||||
|
if self.auth.check_redaction(event, auth_events=auth_events):
|
||||||
original_event = yield self.store.get_event(
|
original_event = yield self.store.get_event(
|
||||||
event.redacts,
|
event.redacts,
|
||||||
check_redacted=False,
|
check_redacted=False,
|
||||||
|
@ -904,7 +924,7 @@ class MessageHandler(BaseHandler):
|
||||||
"You don't have permission to redact events"
|
"You don't have permission to redact events"
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type == EventTypes.Create and context.current_state:
|
if event.type == EventTypes.Create and context.current_state_ids:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
"Changing the room create event is forbidden",
|
"Changing the room create event is forbidden",
|
||||||
|
@ -925,16 +945,7 @@ class MessageHandler(BaseHandler):
|
||||||
event_stream_id, max_stream_id
|
event_stream_id, max_stream_id
|
||||||
)
|
)
|
||||||
|
|
||||||
destinations = set()
|
destinations = yield self.get_joined_hosts_for_room_from_state(context)
|
||||||
for k, s in context.current_state.items():
|
|
||||||
try:
|
|
||||||
if k[0] == EventTypes.Member:
|
|
||||||
if s.content["membership"] == Membership.JOIN:
|
|
||||||
destinations.add(get_domain_from_id(s.state_key))
|
|
||||||
except SynapseError:
|
|
||||||
logger.warn(
|
|
||||||
"Failed to get destination from event %s", s.event_id
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _notify():
|
def _notify():
|
||||||
|
@ -952,3 +963,39 @@ class MessageHandler(BaseHandler):
|
||||||
preserve_fn(federation_handler.handle_new_event)(
|
preserve_fn(federation_handler.handle_new_event)(
|
||||||
event, destinations=destinations,
|
event, destinations=destinations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_joined_hosts_for_room_from_state(self, context):
|
||||||
|
state_group = context.state_group
|
||||||
|
if not state_group:
|
||||||
|
# If state_group is None it means it has yet to be assigned a
|
||||||
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
# of None don't hit previous cached calls with a None state_group.
|
||||||
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
|
state_group = object()
|
||||||
|
|
||||||
|
return self._get_joined_hosts_for_room_from_state(
|
||||||
|
state_group, context.current_state_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(num_args=1, cache_context=True)
|
||||||
|
def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids,
|
||||||
|
cache_context):
|
||||||
|
|
||||||
|
# Don't bother getting state for people on the same HS
|
||||||
|
current_state = yield self.store.get_events([
|
||||||
|
e_id for key, e_id in current_state_ids.items()
|
||||||
|
if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1])
|
||||||
|
])
|
||||||
|
|
||||||
|
destinations = set()
|
||||||
|
for e in current_state.itervalues():
|
||||||
|
try:
|
||||||
|
if e.type == EventTypes.Member:
|
||||||
|
if e.content["membership"] == Membership.JOIN:
|
||||||
|
destinations.add(get_domain_from_id(e.state_key))
|
||||||
|
except SynapseError:
|
||||||
|
logger.warn(
|
||||||
|
"Failed to get destination from event %s", e.event_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(destinations)
|
||||||
|
|
|
@ -93,20 +93,26 @@ class RoomMemberHandler(BaseHandler):
|
||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_member_event = context.current_state.get(
|
prev_member_event_id = context.current_state_ids.get(
|
||||||
(EventTypes.Member, target.to_string()),
|
(EventTypes.Member, target.to_string()),
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
|
# Only fire user_joined_room if the user has acutally joined the
|
||||||
# Only fire user_joined_room if the user has acutally joined the
|
# room. Don't bother if the user is just changing their profile
|
||||||
# room. Don't bother if the user is just changing their profile
|
# info.
|
||||||
# info.
|
newly_joined = True
|
||||||
|
if prev_member_event_id:
|
||||||
|
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||||
|
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||||
|
if newly_joined:
|
||||||
yield user_joined_room(self.distributor, target, room_id)
|
yield user_joined_room(self.distributor, target, room_id)
|
||||||
elif event.membership == Membership.LEAVE:
|
elif event.membership == Membership.LEAVE:
|
||||||
if prev_member_event and prev_member_event.membership == Membership.JOIN:
|
if prev_member_event_id:
|
||||||
user_left_room(self.distributor, target, room_id)
|
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||||
|
if prev_member_event.membership == Membership.JOIN:
|
||||||
|
user_left_room(self.distributor, target, room_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def remote_join(self, remote_room_hosts, room_id, user, content):
|
def remote_join(self, remote_room_hosts, room_id, user, content):
|
||||||
|
@ -195,29 +201,32 @@ class RoomMemberHandler(BaseHandler):
|
||||||
remote_room_hosts = []
|
remote_room_hosts = []
|
||||||
|
|
||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
current_state = yield self.state_handler.get_current_state(
|
current_state_ids = yield self.state_handler.get_current_state_ids(
|
||||||
room_id, latest_event_ids=latest_event_ids,
|
room_id, latest_event_ids=latest_event_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
old_state = current_state.get((EventTypes.Member, target.to_string()))
|
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
|
||||||
old_membership = old_state.content.get("membership") if old_state else None
|
if old_state_id:
|
||||||
if action == "unban" and old_membership != "ban":
|
old_state = yield self.store.get_event(old_state_id, allow_none=True)
|
||||||
raise SynapseError(
|
old_membership = old_state.content.get("membership") if old_state else None
|
||||||
403,
|
if action == "unban" and old_membership != "ban":
|
||||||
"Cannot unban user who was not banned (membership=%s)" % old_membership,
|
raise SynapseError(
|
||||||
errcode=Codes.BAD_STATE
|
403,
|
||||||
)
|
"Cannot unban user who was not banned"
|
||||||
if old_membership == "ban" and action != "unban":
|
" (membership=%s)" % old_membership,
|
||||||
raise SynapseError(
|
errcode=Codes.BAD_STATE
|
||||||
403,
|
)
|
||||||
"Cannot %s user who was banned" % (action,),
|
if old_membership == "ban" and action != "unban":
|
||||||
errcode=Codes.BAD_STATE
|
raise SynapseError(
|
||||||
)
|
403,
|
||||||
|
"Cannot %s user who was banned" % (action,),
|
||||||
|
errcode=Codes.BAD_STATE
|
||||||
|
)
|
||||||
|
|
||||||
is_host_in_room = self.is_host_in_room(current_state)
|
is_host_in_room = yield self._is_host_in_room(current_state_ids)
|
||||||
|
|
||||||
if effective_membership_state == Membership.JOIN:
|
if effective_membership_state == Membership.JOIN:
|
||||||
if requester.is_guest and not self._can_guest_join(current_state):
|
if requester.is_guest and not self._can_guest_join(current_state_ids):
|
||||||
# This should be an auth check, but guests are a local concept,
|
# This should be an auth check, but guests are a local concept,
|
||||||
# so don't really fit into the general auth process.
|
# so don't really fit into the general auth process.
|
||||||
raise AuthError(403, "Guest access not allowed")
|
raise AuthError(403, "Guest access not allowed")
|
||||||
|
@ -326,15 +335,17 @@ class RoomMemberHandler(BaseHandler):
|
||||||
requester = synapse.types.create_requester(target_user)
|
requester = synapse.types.create_requester(target_user)
|
||||||
|
|
||||||
message_handler = self.hs.get_handlers().message_handler
|
message_handler = self.hs.get_handlers().message_handler
|
||||||
prev_event = message_handler.deduplicate_state_event(event, context)
|
prev_event = yield message_handler.deduplicate_state_event(event, context)
|
||||||
if prev_event is not None:
|
if prev_event is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
if requester.is_guest and not self._can_guest_join(context.current_state):
|
if requester.is_guest:
|
||||||
# This should be an auth check, but guests are a local concept,
|
guest_can_join = yield self._can_guest_join(context.current_state_ids)
|
||||||
# so don't really fit into the general auth process.
|
if not guest_can_join:
|
||||||
raise AuthError(403, "Guest access not allowed")
|
# This should be an auth check, but guests are a local concept,
|
||||||
|
# so don't really fit into the general auth process.
|
||||||
|
raise AuthError(403, "Guest access not allowed")
|
||||||
|
|
||||||
yield message_handler.handle_new_client_event(
|
yield message_handler.handle_new_client_event(
|
||||||
requester,
|
requester,
|
||||||
|
@ -344,27 +355,39 @@ class RoomMemberHandler(BaseHandler):
|
||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_member_event = context.current_state.get(
|
prev_member_event_id = context.current_state_ids.get(
|
||||||
(EventTypes.Member, target_user.to_string()),
|
(EventTypes.Member, event.state_key),
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
|
# Only fire user_joined_room if the user has acutally joined the
|
||||||
# Only fire user_joined_room if the user has acutally joined the
|
# room. Don't bother if the user is just changing their profile
|
||||||
# room. Don't bother if the user is just changing their profile
|
# info.
|
||||||
# info.
|
newly_joined = True
|
||||||
|
if prev_member_event_id:
|
||||||
|
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||||
|
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||||
|
if newly_joined:
|
||||||
yield user_joined_room(self.distributor, target_user, room_id)
|
yield user_joined_room(self.distributor, target_user, room_id)
|
||||||
elif event.membership == Membership.LEAVE:
|
elif event.membership == Membership.LEAVE:
|
||||||
if prev_member_event and prev_member_event.membership == Membership.JOIN:
|
if prev_member_event_id:
|
||||||
user_left_room(self.distributor, target_user, room_id)
|
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||||
|
if prev_member_event.membership == Membership.JOIN:
|
||||||
|
user_left_room(self.distributor, target_user, room_id)
|
||||||
|
|
||||||
def _can_guest_join(self, current_state):
|
@defer.inlineCallbacks
|
||||||
|
def _can_guest_join(self, current_state_ids):
|
||||||
"""
|
"""
|
||||||
Returns whether a guest can join a room based on its current state.
|
Returns whether a guest can join a room based on its current state.
|
||||||
"""
|
"""
|
||||||
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
|
guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
|
||||||
return (
|
if not guest_access_id:
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
guest_access = yield self.store.get_event(guest_access_id)
|
||||||
|
|
||||||
|
defer.returnValue(
|
||||||
guest_access
|
guest_access
|
||||||
and guest_access.content
|
and guest_access.content
|
||||||
and "guest_access" in guest_access.content
|
and "guest_access" in guest_access.content
|
||||||
|
@ -683,3 +706,24 @@ class RoomMemberHandler(BaseHandler):
|
||||||
|
|
||||||
if membership:
|
if membership:
|
||||||
yield self.store.forget(user_id, room_id)
|
yield self.store.forget(user_id, room_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _is_host_in_room(self, current_state_ids):
|
||||||
|
# Have we just created the room, and is this about to be the very
|
||||||
|
# first member event?
|
||||||
|
create_event_id = current_state_ids.get(("m.room.create", ""))
|
||||||
|
if len(current_state_ids) == 1 and create_event_id:
|
||||||
|
defer.returnValue(self.hs.is_mine_id(create_event_id))
|
||||||
|
|
||||||
|
for (etype, state_key), event_id in current_state_ids.items():
|
||||||
|
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
|
||||||
|
continue
|
||||||
|
|
||||||
|
event = yield self.store.get_event(event_id, allow_none=True)
|
||||||
|
if not event:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event.membership == Membership.JOIN:
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
|
@ -40,12 +40,12 @@ class ActionGenerator:
|
||||||
def handle_push_actions_for_event(self, event, context):
|
def handle_push_actions_for_event(self, event, context):
|
||||||
with Measure(self.clock, "evaluator_for_event"):
|
with Measure(self.clock, "evaluator_for_event"):
|
||||||
bulk_evaluator = yield evaluator_for_event(
|
bulk_evaluator = yield evaluator_for_event(
|
||||||
event, self.hs, self.store, context.state_group, context.current_state
|
event, self.hs, self.store, context
|
||||||
)
|
)
|
||||||
|
|
||||||
with Measure(self.clock, "action_for_event_by_user"):
|
with Measure(self.clock, "action_for_event_by_user"):
|
||||||
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
|
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
|
||||||
event, context.current_state
|
event, context
|
||||||
)
|
)
|
||||||
|
|
||||||
context.push_actions = [
|
context.push_actions = [
|
||||||
|
|
|
@ -19,8 +19,8 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.visibility import filter_events_for_clients
|
from synapse.visibility import filter_events_for_clients_context
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -36,9 +36,9 @@ def _get_rules(room_id, user_ids, store):
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def evaluator_for_event(event, hs, store, state_group, current_state):
|
def evaluator_for_event(event, hs, store, context):
|
||||||
rules_by_user = yield store.bulk_get_push_rules_for_room(
|
rules_by_user = yield store.bulk_get_push_rules_for_room(
|
||||||
event.room_id, state_group, current_state
|
event.room_id, context
|
||||||
)
|
)
|
||||||
|
|
||||||
# if this event is an invite event, we may need to run rules for the user
|
# if this event is an invite event, we may need to run rules for the user
|
||||||
|
@ -72,7 +72,7 @@ class BulkPushRuleEvaluator:
|
||||||
self.store = store
|
self.store = store
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def action_for_event_by_user(self, event, current_state):
|
def action_for_event_by_user(self, event, context):
|
||||||
actions_by_user = {}
|
actions_by_user = {}
|
||||||
|
|
||||||
# None of these users can be peeking since this list of users comes
|
# None of these users can be peeking since this list of users comes
|
||||||
|
@ -82,27 +82,25 @@ class BulkPushRuleEvaluator:
|
||||||
(u, False) for u in self.rules_by_user.keys()
|
(u, False) for u in self.rules_by_user.keys()
|
||||||
]
|
]
|
||||||
|
|
||||||
filtered_by_user = yield filter_events_for_clients(
|
filtered_by_user = yield filter_events_for_clients_context(
|
||||||
self.store, user_tuples, [event], {event.event_id: current_state}
|
self.store, user_tuples, [event], {event.event_id: context}
|
||||||
)
|
)
|
||||||
|
|
||||||
room_members = set(
|
room_members = yield self.store.get_joined_users_from_context(
|
||||||
e.state_key for e in current_state.values()
|
event.room_id, context,
|
||||||
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
||||||
|
|
||||||
condition_cache = {}
|
condition_cache = {}
|
||||||
|
|
||||||
display_names = {}
|
|
||||||
for ev in current_state.values():
|
|
||||||
nm = ev.content.get("displayname", None)
|
|
||||||
if nm and ev.type == EventTypes.Member:
|
|
||||||
display_names[ev.state_key] = nm
|
|
||||||
|
|
||||||
for uid, rules in self.rules_by_user.items():
|
for uid, rules in self.rules_by_user.items():
|
||||||
display_name = display_names.get(uid, None)
|
display_name = None
|
||||||
|
member_ev_id = context.current_state_ids.get((EventTypes.Member, uid))
|
||||||
|
if member_ev_id:
|
||||||
|
member_ev = yield self.store.get_event(member_ev_id, allow_none=True)
|
||||||
|
if member_ev:
|
||||||
|
display_name = member_ev.content.get("displayname", None)
|
||||||
|
|
||||||
filtered = filtered_by_user[uid]
|
filtered = filtered_by_user[uid]
|
||||||
if len(filtered) == 0:
|
if len(filtered) == 0:
|
||||||
|
|
|
@ -106,6 +106,20 @@ class StateHandler(object):
|
||||||
|
|
||||||
defer.returnValue(state)
|
defer.returnValue(state)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_current_state_ids(self, room_id, event_type=None, state_key="",
|
||||||
|
latest_event_ids=None):
|
||||||
|
if not latest_event_ids:
|
||||||
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
|
|
||||||
|
_, state = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||||
|
|
||||||
|
if event_type:
|
||||||
|
defer.returnValue(state.get((event_type, state_key)))
|
||||||
|
return
|
||||||
|
|
||||||
|
defer.returnValue(state)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def compute_event_context(self, event, old_state=None):
|
def compute_event_context(self, event, old_state=None):
|
||||||
""" Fills out the context with the `current state` of the graph. The
|
""" Fills out the context with the `current state` of the graph. The
|
||||||
|
@ -127,27 +141,27 @@ class StateHandler(object):
|
||||||
# state. Certainly store.get_current_state won't return any, and
|
# state. Certainly store.get_current_state won't return any, and
|
||||||
# persisting the event won't store the state group.
|
# persisting the event won't store the state group.
|
||||||
if old_state:
|
if old_state:
|
||||||
context.current_state = {
|
context.current_state_ids = {
|
||||||
(s.type, s.state_key): s for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
context.current_state = {}
|
context.current_state_ids = {}
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
context.state_group = None
|
context.state_group = None
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
if old_state:
|
if old_state:
|
||||||
context.current_state = {
|
context.current_state_ids = {
|
||||||
(s.type, s.state_key): s for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
context.state_group = None
|
context.state_group = None
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
if key in context.current_state:
|
if key in context.current_state_ids:
|
||||||
replaces = context.current_state[key]
|
replaces = context.current_state_ids[key]
|
||||||
if replaces.event_id != event.event_id: # Paranoia check
|
if replaces != event.event_id: # Paranoia check
|
||||||
event.unsigned["replaces_state"] = replaces.event_id
|
event.unsigned["replaces_state"] = replaces
|
||||||
|
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
@ -165,22 +179,14 @@ class StateHandler(object):
|
||||||
|
|
||||||
group, curr_state = ret
|
group, curr_state = ret
|
||||||
|
|
||||||
state_map = yield self.store.get_events(
|
context.current_state_ids = curr_state
|
||||||
curr_state.values(),
|
|
||||||
get_prev_content=False
|
|
||||||
)
|
|
||||||
curr_state = {
|
|
||||||
key: state_map[e_id] for key, e_id in curr_state.items() if e_id in state_map
|
|
||||||
}
|
|
||||||
|
|
||||||
context.current_state = curr_state
|
|
||||||
context.state_group = group if not event.is_state() else None
|
context.state_group = group if not event.is_state() else None
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
if key in context.current_state:
|
if key in context.current_state_ids:
|
||||||
replaces = context.current_state[key]
|
replaces = context.current_state_ids[key]
|
||||||
event.unsigned["replaces_state"] = replaces.event_id
|
event.unsigned["replaces_state"] = replaces
|
||||||
|
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
|
@ -124,7 +124,8 @@ class PushRuleStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
|
def bulk_get_push_rules_for_room(self, room_id, context):
|
||||||
|
state_group = context.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
# If state_group is None it means it has yet to be assigned a
|
# If state_group is None it means it has yet to be assigned a
|
||||||
# state group, i.e. we need to make sure that calls with a state_group
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
@ -132,10 +133,12 @@ class PushRuleStore(SQLBaseStore):
|
||||||
# To do this we set the state_group to a new object as object() != object()
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
|
return self._bulk_get_push_rules_for_room(
|
||||||
|
room_id, state_group, context.current_state_ids
|
||||||
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||||
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
|
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
|
||||||
cache_context):
|
cache_context):
|
||||||
# We don't use `state_group`, its there so that we can cache based
|
# We don't use `state_group`, its there so that we can cache based
|
||||||
# on it. However, its important that its never None, since two current_state's
|
# on it. However, its important that its never None, since two current_state's
|
||||||
|
@ -147,10 +150,16 @@ class PushRuleStore(SQLBaseStore):
|
||||||
# their unread countss are correct in the event stream, but to avoid
|
# their unread countss are correct in the event stream, but to avoid
|
||||||
# generating them for bot / AS users etc, we only do so for people who've
|
# generating them for bot / AS users etc, we only do so for people who've
|
||||||
# sent a read receipt into the room.
|
# sent a read receipt into the room.
|
||||||
|
local_user_member_ids = [
|
||||||
|
e_id for (etype, state_key), e_id in current_state_ids.iteritems()
|
||||||
|
if etype == EventTypes.Member and self.hs.is_mine_id(state_key)
|
||||||
|
]
|
||||||
|
|
||||||
|
local_member_events = yield self._get_events(local_user_member_ids)
|
||||||
|
|
||||||
local_users_in_room = set(
|
local_users_in_room = set(
|
||||||
e.state_key for e in current_state.values()
|
member_event.state_key for member_event in local_member_events
|
||||||
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
if member_event.membership == Membership.JOIN
|
||||||
and self.hs.is_mine_id(e.state_key)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# users in the room who have pushers need to get push rules run because
|
# users in the room who have pushers need to get push rules run because
|
||||||
|
|
|
@ -20,7 +20,7 @@ from collections import namedtuple
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
|
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership, EventTypes
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -325,7 +325,8 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=3)
|
@cachedInlineCallbacks(num_args=3)
|
||||||
def was_forgotten_at(self, user_id, room_id, event_id):
|
def was_forgotten_at(self, user_id, room_id, event_id):
|
||||||
"""Returns whether user_id has elected to discard history for room_id at event_id.
|
"""Returns whether user_id has elected to discard history for room_id at
|
||||||
|
event_id.
|
||||||
|
|
||||||
event_id must be a membership event."""
|
event_id must be a membership event."""
|
||||||
def f(txn):
|
def f(txn):
|
||||||
|
@ -358,3 +359,43 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
},
|
},
|
||||||
desc="who_forgot"
|
desc="who_forgot"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_joined_users_from_context(self, room_id, context):
|
||||||
|
state_group = context.state_group
|
||||||
|
if not state_group:
|
||||||
|
# If state_group is None it means it has yet to be assigned a
|
||||||
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
# of None don't hit previous cached calls with a None state_group.
|
||||||
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
|
state_group = object()
|
||||||
|
|
||||||
|
return self._get_joined_users_from_context(
|
||||||
|
room_id, state_group, context.current_state_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||||
|
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
|
||||||
|
cache_context):
|
||||||
|
# We don't use `state_group`, its there so that we can cache based
|
||||||
|
# on it. However, its important that its never None, since two current_state's
|
||||||
|
# with a state_group of None are likely to be different.
|
||||||
|
# See bulk_get_push_rules_for_room for how we work around this.
|
||||||
|
assert state_group is not None
|
||||||
|
|
||||||
|
member_event_ids = [
|
||||||
|
e_id
|
||||||
|
for key, e_id in current_state_ids.iteritems()
|
||||||
|
if key[0] == EventTypes.Member
|
||||||
|
]
|
||||||
|
|
||||||
|
rows = yield self._simple_select_many_batch(
|
||||||
|
table="room_memberships",
|
||||||
|
column="event_id",
|
||||||
|
iterable=member_event_ids,
|
||||||
|
retcols=['user_id'],
|
||||||
|
keyvalues={
|
||||||
|
"membership": Membership.JOIN,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(set(row["user_id"] for row in rows))
|
||||||
|
|
|
@ -89,17 +89,17 @@ class StateStore(SQLBaseStore):
|
||||||
if event.internal_metadata.is_outlier():
|
if event.internal_metadata.is_outlier():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if context.current_state is None:
|
if context.current_state_ids is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if context.state_group is not None:
|
if context.state_group is not None:
|
||||||
state_groups[event.event_id] = context.state_group
|
state_groups[event.event_id] = context.state_group
|
||||||
continue
|
continue
|
||||||
|
|
||||||
state_events = dict(context.current_state)
|
state_event_ids = dict(context.current_state_ids)
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
state_events[(event.type, event.state_key)] = event
|
state_event_ids[(event.type, event.state_key)] = event.event_id
|
||||||
|
|
||||||
state_group = context.new_state_group_id
|
state_group = context.new_state_group_id
|
||||||
|
|
||||||
|
@ -119,12 +119,12 @@ class StateStore(SQLBaseStore):
|
||||||
values=[
|
values=[
|
||||||
{
|
{
|
||||||
"state_group": state_group,
|
"state_group": state_group,
|
||||||
"room_id": state.room_id,
|
"room_id": event.room_id,
|
||||||
"type": state.type,
|
"type": key[0],
|
||||||
"state_key": state.state_key,
|
"state_key": key[1],
|
||||||
"event_id": state.event_id,
|
"event_id": state_id,
|
||||||
}
|
}
|
||||||
for state in state_events.values()
|
for key, state_id in state_event_ids.items()
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
state_groups[event.event_id] = state_group
|
state_groups[event.event_id] = state_group
|
||||||
|
|
|
@ -180,6 +180,25 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context):
|
||||||
|
user_ids = set(u[0] for u in user_tuples)
|
||||||
|
event_id_to_state = {}
|
||||||
|
for event_id, context in event_id_to_context.items():
|
||||||
|
state = yield store.get_events([
|
||||||
|
e_id
|
||||||
|
for key, e_id in context.current_state_ids.iteritems()
|
||||||
|
if key == (EventTypes.RoomHistoryVisibility, "")
|
||||||
|
or (key[0] == EventTypes.Member and key[1] in user_ids)
|
||||||
|
])
|
||||||
|
event_id_to_state[event_id] = state
|
||||||
|
|
||||||
|
res = yield filter_events_for_clients(
|
||||||
|
store, user_tuples, events, event_id_to_state
|
||||||
|
)
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def filter_events_for_client(store, user_id, events, is_peeking=False):
|
def filter_events_for_client(store, user_id, events, is_peeking=False):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -305,7 +305,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
|
|
||||||
self.event_id += 1
|
self.event_id += 1
|
||||||
|
|
||||||
context = EventContext(current_state=state)
|
if state is not None:
|
||||||
|
state_ids = {
|
||||||
|
key: e.event_id for key, e in state.items()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
state_ids = None
|
||||||
|
|
||||||
|
context = EventContext(current_state_ids=state_ids)
|
||||||
context.push_actions = push_actions
|
context.push_actions = push_actions
|
||||||
|
|
||||||
ordering = None
|
ordering = None
|
||||||
|
|
|
@ -69,7 +69,7 @@ class StateGroupStore(object):
|
||||||
|
|
||||||
self._next_group = 1
|
self._next_group = 1
|
||||||
|
|
||||||
def get_state_groups(self, room_id, event_ids):
|
def get_state_groups_ids(self, room_id, event_ids):
|
||||||
groups = {}
|
groups = {}
|
||||||
for event_id in event_ids:
|
for event_id in event_ids:
|
||||||
group = self._event_to_state_group.get(event_id)
|
group = self._event_to_state_group.get(event_id)
|
||||||
|
@ -79,20 +79,20 @@ class StateGroupStore(object):
|
||||||
return defer.succeed(groups)
|
return defer.succeed(groups)
|
||||||
|
|
||||||
def store_state_groups(self, event, context):
|
def store_state_groups(self, event, context):
|
||||||
if context.current_state is None:
|
if context.current_state_ids is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
state_events = context.current_state
|
state_events = dict(context.current_state_ids)
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
state_events[(event.type, event.state_key)] = event
|
state_events[(event.type, event.state_key)] = event.event_id
|
||||||
|
|
||||||
state_group = context.state_group
|
state_group = context.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
state_group = self._next_group
|
state_group = self._next_group
|
||||||
self._next_group += 1
|
self._next_group += 1
|
||||||
|
|
||||||
self._group_to_state[state_group] = state_events.values()
|
self._group_to_state[state_group] = state_events
|
||||||
|
|
||||||
self._event_to_state_group[event.event_id] = state_group
|
self._event_to_state_group[event.event_id] = state_group
|
||||||
|
|
||||||
|
@ -136,7 +136,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.store = Mock(
|
self.store = Mock(
|
||||||
spec_set=[
|
spec_set=[
|
||||||
"get_state_groups",
|
"get_state_groups_ids",
|
||||||
"add_event_hashes",
|
"add_event_hashes",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -187,7 +187,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
store = StateGroupStore()
|
store = StateGroupStore()
|
||||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
|
@ -196,7 +196,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
store.store_state_groups(event, context)
|
store.store_state_groups(event, context)
|
||||||
context_store[event.event_id] = context
|
context_store[event.event_id] = context
|
||||||
|
|
||||||
self.assertEqual(2, len(context_store["D"].current_state))
|
self.assertEqual(2, len(context_store["D"].current_state_ids))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_branch_basic_conflict(self):
|
def test_branch_basic_conflict(self):
|
||||||
|
@ -239,7 +239,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
store = StateGroupStore()
|
store = StateGroupStore()
|
||||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
|
@ -303,7 +303,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
store = StateGroupStore()
|
store = StateGroupStore()
|
||||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
|
@ -384,7 +384,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
graph = Graph(nodes, edges)
|
graph = Graph(nodes, edges)
|
||||||
|
|
||||||
store = StateGroupStore()
|
store = StateGroupStore()
|
||||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
|
@ -424,13 +424,8 @@ class StateTestCase(unittest.TestCase):
|
||||||
event, old_state=old_state
|
event, old_state=old_state
|
||||||
)
|
)
|
||||||
|
|
||||||
for k, v in context.current_state.items():
|
|
||||||
type, state_key = k
|
|
||||||
self.assertEqual(type, v.type)
|
|
||||||
self.assertEqual(state_key, v.state_key)
|
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(old_state), set(context.current_state.values())
|
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNone(context.state_group)
|
||||||
|
@ -449,14 +444,8 @@ class StateTestCase(unittest.TestCase):
|
||||||
event, old_state=old_state
|
event, old_state=old_state
|
||||||
)
|
)
|
||||||
|
|
||||||
for k, v in context.current_state.items():
|
|
||||||
type, state_key = k
|
|
||||||
self.assertEqual(type, v.type)
|
|
||||||
self.assertEqual(state_key, v.state_key)
|
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(old_state),
|
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
||||||
set(context.current_state.values())
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNone(context.state_group)
|
||||||
|
@ -473,20 +462,15 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
group_name = "group_name_1"
|
group_name = "group_name_1"
|
||||||
|
|
||||||
self.store.get_state_groups.return_value = {
|
self.store.get_state_groups_ids.return_value = {
|
||||||
group_name: old_state,
|
group_name: {(e.type, e.state_key): e.event_id for e in old_state},
|
||||||
}
|
}
|
||||||
|
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
for k, v in context.current_state.items():
|
|
||||||
type, state_key = k
|
|
||||||
self.assertEqual(type, v.type)
|
|
||||||
self.assertEqual(state_key, v.state_key)
|
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set([e.event_id for e in old_state]),
|
set([e.event_id for e in old_state]),
|
||||||
set([e.event_id for e in context.current_state.values()])
|
set(context.current_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(group_name, context.state_group)
|
self.assertEqual(group_name, context.state_group)
|
||||||
|
@ -503,20 +487,15 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
group_name = "group_name_1"
|
group_name = "group_name_1"
|
||||||
|
|
||||||
self.store.get_state_groups.return_value = {
|
self.store.get_state_groups_ids.return_value = {
|
||||||
group_name: old_state,
|
group_name: {(e.type, e.state_key): e.event_id for e in old_state},
|
||||||
}
|
}
|
||||||
|
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
for k, v in context.current_state.items():
|
|
||||||
type, state_key = k
|
|
||||||
self.assertEqual(type, v.type)
|
|
||||||
self.assertEqual(state_key, v.state_key)
|
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set([e.event_id for e in old_state]),
|
set([e.event_id for e in old_state]),
|
||||||
set([e.event_id for e in context.current_state.values()])
|
set(context.current_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNone(context.state_group)
|
||||||
|
@ -545,7 +524,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state), 6)
|
self.assertEqual(len(context.current_state_ids), 6)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNone(context.state_group)
|
||||||
|
|
||||||
|
@ -573,7 +552,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state), 6)
|
self.assertEqual(len(context.current_state_ids), 6)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNone(context.state_group)
|
||||||
|
|
||||||
|
@ -608,7 +587,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||||
|
|
||||||
self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
|
self.assertEqual(old_state_2[2].event.id, context.current_state_ids[("test1", "1")])
|
||||||
|
|
||||||
# Reverse the depth to make sure we are actually using the depths
|
# Reverse the depth to make sure we are actually using the depths
|
||||||
# during state resolution.
|
# during state resolution.
|
||||||
|
@ -627,15 +606,15 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||||
|
|
||||||
self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
|
self.assertEqual(old_state_1[2].event_id, context.current_state_ids[("test1", "1")])
|
||||||
|
|
||||||
def _get_context(self, event, old_state_1, old_state_2):
|
def _get_context(self, event, old_state_1, old_state_2):
|
||||||
group_name_1 = "group_name_1"
|
group_name_1 = "group_name_1"
|
||||||
group_name_2 = "group_name_2"
|
group_name_2 = "group_name_2"
|
||||||
|
|
||||||
self.store.get_state_groups.return_value = {
|
self.store.get_state_groups_ids.return_value = {
|
||||||
group_name_1: old_state_1,
|
group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
|
||||||
group_name_2: old_state_2,
|
group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
|
||||||
}
|
}
|
||||||
|
|
||||||
return self.state.compute_event_context(event)
|
return self.state.compute_event_context(event)
|
||||||
|
|
Loading…
Reference in New Issue