Merge pull request #502 from matrix-org/erikj/push_notif_perf
Unread notification performance.
This commit is contained in:
commit
47e7963e50
|
@ -117,6 +117,15 @@ class EventBase(object):
|
||||||
def __set__(self, instance, value):
|
def __set__(self, instance, value):
|
||||||
raise AttributeError("Unrecognized attribute %s" % (instance,))
|
raise AttributeError("Unrecognized attribute %s" % (instance,))
|
||||||
|
|
||||||
|
def __getitem__(self, field):
|
||||||
|
return self._event_dict[field]
|
||||||
|
|
||||||
|
def __contains__(self, field):
|
||||||
|
return field in self._event_dict
|
||||||
|
|
||||||
|
def items(self):
|
||||||
|
return self._event_dict.items()
|
||||||
|
|
||||||
|
|
||||||
class FrozenEvent(EventBase):
|
class FrozenEvent(EventBase):
|
||||||
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
|
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
|
||||||
|
|
|
@ -53,16 +53,54 @@ class BaseHandler(object):
|
||||||
self.event_builder_factory = hs.get_event_builder_factory()
|
self.event_builder_factory = hs.get_event_builder_factory()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _filter_events_for_client(self, user_id, events, is_guest=False):
|
def _filter_events_for_clients(self, users, events):
|
||||||
# Assumes that user has at some point joined the room if not is_guest.
|
""" Returns dict of user_id -> list of events that user is allowed to
|
||||||
|
see.
|
||||||
|
"""
|
||||||
|
event_id_to_state = yield self.store.get_state_for_events(
|
||||||
|
frozenset(e.event_id for e in events),
|
||||||
|
types=(
|
||||||
|
(EventTypes.RoomHistoryVisibility, ""),
|
||||||
|
(EventTypes.Member, None),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
forgotten = yield defer.gatherResults([
|
||||||
|
self.store.who_forgot_in_room(
|
||||||
|
room_id,
|
||||||
|
)
|
||||||
|
for room_id in frozenset(e.room_id for e in events)
|
||||||
|
], consumeErrors=True)
|
||||||
|
|
||||||
|
# Set of membership event_ids that have been forgotten
|
||||||
|
event_id_forgotten = frozenset(
|
||||||
|
row["event_id"] for rows in forgotten for row in rows
|
||||||
|
)
|
||||||
|
|
||||||
|
def allowed(event, user_id, is_guest):
|
||||||
|
state = event_id_to_state[event.event_id]
|
||||||
|
|
||||||
|
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
|
||||||
|
if visibility_event:
|
||||||
|
visibility = visibility_event.content.get("history_visibility", "shared")
|
||||||
|
else:
|
||||||
|
visibility = "shared"
|
||||||
|
|
||||||
def allowed(event, membership, visibility):
|
|
||||||
if visibility == "world_readable":
|
if visibility == "world_readable":
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if is_guest:
|
if is_guest:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
membership_event = state.get((EventTypes.Member, user_id), None)
|
||||||
|
if membership_event:
|
||||||
|
if membership_event.event_id in event_id_forgotten:
|
||||||
|
membership = None
|
||||||
|
else:
|
||||||
|
membership = membership_event.membership
|
||||||
|
else:
|
||||||
|
membership = None
|
||||||
|
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -78,43 +116,20 @@ class BaseHandler(object):
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
event_id_to_state = yield self.store.get_state_for_events(
|
defer.returnValue({
|
||||||
frozenset(e.event_id for e in events),
|
user_id: [
|
||||||
types=(
|
event
|
||||||
(EventTypes.RoomHistoryVisibility, ""),
|
for event in events
|
||||||
(EventTypes.Member, user_id),
|
if allowed(event, user_id, is_guest)
|
||||||
)
|
]
|
||||||
)
|
for user_id, is_guest in users
|
||||||
|
})
|
||||||
|
|
||||||
events_to_return = []
|
@defer.inlineCallbacks
|
||||||
for event in events:
|
def _filter_events_for_client(self, user_id, events, is_guest=False):
|
||||||
state = event_id_to_state[event.event_id]
|
# Assumes that user has at some point joined the room if not is_guest.
|
||||||
|
res = yield self._filter_events_for_clients([(user_id, is_guest)], events)
|
||||||
membership_event = state.get((EventTypes.Member, user_id), None)
|
defer.returnValue(res.get(user_id, []))
|
||||||
if membership_event:
|
|
||||||
was_forgotten_at_event = yield self.store.was_forgotten_at(
|
|
||||||
membership_event.state_key,
|
|
||||||
membership_event.room_id,
|
|
||||||
membership_event.event_id
|
|
||||||
)
|
|
||||||
if was_forgotten_at_event:
|
|
||||||
membership = None
|
|
||||||
else:
|
|
||||||
membership = membership_event.membership
|
|
||||||
else:
|
|
||||||
membership = None
|
|
||||||
|
|
||||||
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
|
|
||||||
if visibility_event:
|
|
||||||
visibility = visibility_event.content.get("history_visibility", "shared")
|
|
||||||
else:
|
|
||||||
visibility = "shared"
|
|
||||||
|
|
||||||
should_include = allowed(event, membership, visibility)
|
|
||||||
if should_include:
|
|
||||||
events_to_return.append(event)
|
|
||||||
|
|
||||||
defer.returnValue(events_to_return)
|
|
||||||
|
|
||||||
def ratelimit(self, user_id):
|
def ratelimit(self, user_id):
|
||||||
time_now = self.clock.time()
|
time_now = self.clock.time()
|
||||||
|
|
|
@ -36,7 +36,7 @@ from synapse.events.utils import prune_event
|
||||||
|
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
# from synapse.push.action_generator import ActionGenerator
|
from synapse.push.action_generator import ActionGenerator
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -244,12 +244,11 @@ class FederationHandler(BaseHandler):
|
||||||
user = UserID.from_string(event.state_key)
|
user = UserID.from_string(event.state_key)
|
||||||
yield user_joined_room(self.distributor, user, event.room_id)
|
yield user_joined_room(self.distributor, user, event.room_id)
|
||||||
|
|
||||||
# Temporarily disable notifications due to performance concerns.
|
if not backfilled and not event.internal_metadata.is_outlier():
|
||||||
# if not backfilled and not event.internal_metadata.is_outlier():
|
action_generator = ActionGenerator(self.store)
|
||||||
# action_generator = ActionGenerator(self.store)
|
yield action_generator.handle_push_actions_for_event(
|
||||||
# yield action_generator.handle_push_actions_for_event(
|
event, self
|
||||||
# event, self
|
)
|
||||||
# )
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _filter_events_for_server(self, server_name, room_id, events):
|
def _filter_events_for_server(self, server_name, room_id, events):
|
||||||
|
|
|
@ -841,9 +841,6 @@ class SyncHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room):
|
def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room):
|
||||||
# Temporarily disable notifications due to performance concerns.
|
|
||||||
defer.returnValue([])
|
|
||||||
|
|
||||||
last_unread_event_id = self.last_read_event_id_for_room_and_user(
|
last_unread_event_id = self.last_read_event_id_for_room_and_user(
|
||||||
room_id, sync_config.user.to_string(), ephemeral_by_room
|
room_id, sync_config.user.to_string(), ephemeral_by_room
|
||||||
)
|
)
|
||||||
|
|
|
@ -36,9 +36,6 @@ class ActionGenerator:
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_push_actions_for_event(self, event, handler):
|
def handle_push_actions_for_event(self, event, handler):
|
||||||
# Temporarily disable notifications due to performance concerns.
|
|
||||||
return
|
|
||||||
|
|
||||||
if event.type == EventTypes.Redaction and event.redacts is not None:
|
if event.type == EventTypes.Redaction and event.redacts is not None:
|
||||||
yield self.store.remove_push_actions_for_event_id(
|
yield self.store.remove_push_actions_for_event_id(
|
||||||
event.room_id, event.redacts
|
event.room_id, event.redacts
|
||||||
|
|
|
@ -15,27 +15,25 @@
|
||||||
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
|
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
|
||||||
|
|
||||||
|
|
||||||
def list_with_base_rules(rawrules, user_id):
|
def list_with_base_rules(rawrules):
|
||||||
ruleslist = []
|
ruleslist = []
|
||||||
|
|
||||||
# shove the server default rules for each kind onto the end of each
|
# shove the server default rules for each kind onto the end of each
|
||||||
current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
|
current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
|
||||||
|
|
||||||
ruleslist.extend(make_base_prepend_rules(
|
ruleslist.extend(make_base_prepend_rules(
|
||||||
user_id, PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
||||||
))
|
))
|
||||||
|
|
||||||
for r in rawrules:
|
for r in rawrules:
|
||||||
if r['priority_class'] < current_prio_class:
|
if r['priority_class'] < current_prio_class:
|
||||||
while r['priority_class'] < current_prio_class:
|
while r['priority_class'] < current_prio_class:
|
||||||
ruleslist.extend(make_base_append_rules(
|
ruleslist.extend(make_base_append_rules(
|
||||||
user_id,
|
|
||||||
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
||||||
))
|
))
|
||||||
current_prio_class -= 1
|
current_prio_class -= 1
|
||||||
if current_prio_class > 0:
|
if current_prio_class > 0:
|
||||||
ruleslist.extend(make_base_prepend_rules(
|
ruleslist.extend(make_base_prepend_rules(
|
||||||
user_id,
|
|
||||||
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@ -43,58 +41,47 @@ def list_with_base_rules(rawrules, user_id):
|
||||||
|
|
||||||
while current_prio_class > 0:
|
while current_prio_class > 0:
|
||||||
ruleslist.extend(make_base_append_rules(
|
ruleslist.extend(make_base_append_rules(
|
||||||
user_id,
|
|
||||||
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
||||||
))
|
))
|
||||||
current_prio_class -= 1
|
current_prio_class -= 1
|
||||||
if current_prio_class > 0:
|
if current_prio_class > 0:
|
||||||
ruleslist.extend(make_base_prepend_rules(
|
ruleslist.extend(make_base_prepend_rules(
|
||||||
user_id,
|
|
||||||
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
||||||
))
|
))
|
||||||
|
|
||||||
return ruleslist
|
return ruleslist
|
||||||
|
|
||||||
|
|
||||||
def make_base_append_rules(user, kind):
|
def make_base_append_rules(kind):
|
||||||
rules = []
|
rules = []
|
||||||
|
|
||||||
if kind == 'override':
|
if kind == 'override':
|
||||||
rules = make_base_append_override_rules()
|
rules = BASE_APPEND_OVRRIDE_RULES
|
||||||
elif kind == 'underride':
|
elif kind == 'underride':
|
||||||
rules = make_base_append_underride_rules(user)
|
rules = BASE_APPEND_UNDERRIDE_RULES
|
||||||
elif kind == 'content':
|
elif kind == 'content':
|
||||||
rules = make_base_append_content_rules(user)
|
rules = BASE_APPEND_CONTENT_RULES
|
||||||
|
|
||||||
for r in rules:
|
|
||||||
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
|
|
||||||
r['default'] = True # Deprecated, left for backwards compat
|
|
||||||
|
|
||||||
return rules
|
return rules
|
||||||
|
|
||||||
|
|
||||||
def make_base_prepend_rules(user, kind):
|
def make_base_prepend_rules(kind):
|
||||||
rules = []
|
rules = []
|
||||||
|
|
||||||
if kind == 'override':
|
if kind == 'override':
|
||||||
rules = make_base_prepend_override_rules()
|
rules = BASE_PREPEND_OVERRIDE_RULES
|
||||||
|
|
||||||
for r in rules:
|
|
||||||
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
|
|
||||||
r['default'] = True # Deprecated, left for backwards compat
|
|
||||||
|
|
||||||
return rules
|
return rules
|
||||||
|
|
||||||
|
|
||||||
def make_base_append_content_rules(user):
|
BASE_APPEND_CONTENT_RULES = [
|
||||||
return [
|
|
||||||
{
|
{
|
||||||
'rule_id': 'global/content/.m.rule.contains_user_name',
|
'rule_id': 'global/content/.m.rule.contains_user_name',
|
||||||
'conditions': [
|
'conditions': [
|
||||||
{
|
{
|
||||||
'kind': 'event_match',
|
'kind': 'event_match',
|
||||||
'key': 'content.body',
|
'key': 'content.body',
|
||||||
'pattern': user.localpart, # Matrix ID match
|
'pattern_type': 'user_localpart'
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
'actions': [
|
'actions': [
|
||||||
|
@ -110,8 +97,7 @@ def make_base_append_content_rules(user):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def make_base_prepend_override_rules():
|
BASE_PREPEND_OVERRIDE_RULES = [
|
||||||
return [
|
|
||||||
{
|
{
|
||||||
'rule_id': 'global/override/.m.rule.master',
|
'rule_id': 'global/override/.m.rule.master',
|
||||||
'enabled': False,
|
'enabled': False,
|
||||||
|
@ -123,8 +109,7 @@ def make_base_prepend_override_rules():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def make_base_append_override_rules():
|
BASE_APPEND_OVRRIDE_RULES = [
|
||||||
return [
|
|
||||||
{
|
{
|
||||||
'rule_id': 'global/override/.m.rule.suppress_notices',
|
'rule_id': 'global/override/.m.rule.suppress_notices',
|
||||||
'conditions': [
|
'conditions': [
|
||||||
|
@ -132,6 +117,7 @@ def make_base_append_override_rules():
|
||||||
'kind': 'event_match',
|
'kind': 'event_match',
|
||||||
'key': 'content.msgtype',
|
'key': 'content.msgtype',
|
||||||
'pattern': 'm.notice',
|
'pattern': 'm.notice',
|
||||||
|
'_id': '_suppress_notices',
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
'actions': [
|
'actions': [
|
||||||
|
@ -141,8 +127,7 @@ def make_base_append_override_rules():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def make_base_append_underride_rules(user):
|
BASE_APPEND_UNDERRIDE_RULES = [
|
||||||
return [
|
|
||||||
{
|
{
|
||||||
'rule_id': 'global/underride/.m.rule.call',
|
'rule_id': 'global/underride/.m.rule.call',
|
||||||
'conditions': [
|
'conditions': [
|
||||||
|
@ -150,6 +135,7 @@ def make_base_append_underride_rules(user):
|
||||||
'kind': 'event_match',
|
'kind': 'event_match',
|
||||||
'key': 'type',
|
'key': 'type',
|
||||||
'pattern': 'm.call.invite',
|
'pattern': 'm.call.invite',
|
||||||
|
'_id': '_call',
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
'actions': [
|
'actions': [
|
||||||
|
@ -185,7 +171,8 @@ def make_base_append_underride_rules(user):
|
||||||
'conditions': [
|
'conditions': [
|
||||||
{
|
{
|
||||||
'kind': 'room_member_count',
|
'kind': 'room_member_count',
|
||||||
'is': '2'
|
'is': '2',
|
||||||
|
'_id': 'member_count',
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
'actions': [
|
'actions': [
|
||||||
|
@ -206,16 +193,18 @@ def make_base_append_underride_rules(user):
|
||||||
'kind': 'event_match',
|
'kind': 'event_match',
|
||||||
'key': 'type',
|
'key': 'type',
|
||||||
'pattern': 'm.room.member',
|
'pattern': 'm.room.member',
|
||||||
|
'_id': '_member',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'kind': 'event_match',
|
'kind': 'event_match',
|
||||||
'key': 'content.membership',
|
'key': 'content.membership',
|
||||||
'pattern': 'invite',
|
'pattern': 'invite',
|
||||||
|
'_id': '_invite_member',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'kind': 'event_match',
|
'kind': 'event_match',
|
||||||
'key': 'state_key',
|
'key': 'state_key',
|
||||||
'pattern': user.to_string(),
|
'pattern_type': 'user_id'
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
'actions': [
|
'actions': [
|
||||||
|
@ -236,6 +225,7 @@ def make_base_append_underride_rules(user):
|
||||||
'kind': 'event_match',
|
'kind': 'event_match',
|
||||||
'key': 'type',
|
'key': 'type',
|
||||||
'pattern': 'm.room.member',
|
'pattern': 'm.room.member',
|
||||||
|
'_id': '_member',
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
'actions': [
|
'actions': [
|
||||||
|
@ -253,6 +243,7 @@ def make_base_append_underride_rules(user):
|
||||||
'kind': 'event_match',
|
'kind': 'event_match',
|
||||||
'key': 'type',
|
'key': 'type',
|
||||||
'pattern': 'm.room.message',
|
'pattern': 'm.room.message',
|
||||||
|
'_id': '_message',
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
'actions': [
|
'actions': [
|
||||||
|
@ -263,3 +254,20 @@ def make_base_append_underride_rules(user):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
for r in BASE_APPEND_CONTENT_RULES:
|
||||||
|
r['priority_class'] = PRIORITY_CLASS_MAP['content']
|
||||||
|
r['default'] = True
|
||||||
|
|
||||||
|
for r in BASE_PREPEND_OVERRIDE_RULES:
|
||||||
|
r['priority_class'] = PRIORITY_CLASS_MAP['override']
|
||||||
|
r['default'] = True
|
||||||
|
|
||||||
|
for r in BASE_APPEND_OVRRIDE_RULES:
|
||||||
|
r['priority_class'] = PRIORITY_CLASS_MAP['override']
|
||||||
|
r['default'] = True
|
||||||
|
|
||||||
|
for r in BASE_APPEND_UNDERRIDE_RULES:
|
||||||
|
r['priority_class'] = PRIORITY_CLASS_MAP['underride']
|
||||||
|
r['default'] = True
|
||||||
|
|
|
@ -14,16 +14,15 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import simplejson as json
|
import ujson as json
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.types import UserID
|
|
||||||
|
|
||||||
import baserules
|
import baserules
|
||||||
from push_rule_evaluator import PushRuleEvaluator
|
from push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||||
|
|
||||||
|
from synapse.api.constants import EventTypes
|
||||||
|
|
||||||
from synapse.events.utils import serialize_event
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -34,29 +33,26 @@ def decode_rule_json(rule):
|
||||||
return rule
|
return rule
|
||||||
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_rules(room_id, user_ids, store):
|
||||||
|
rules_by_user = yield store.bulk_get_push_rules(user_ids)
|
||||||
|
rules_by_user = {
|
||||||
|
uid: baserules.list_with_base_rules([
|
||||||
|
decode_rule_json(rule_list)
|
||||||
|
for rule_list in rules_by_user.get(uid, [])
|
||||||
|
])
|
||||||
|
for uid in user_ids
|
||||||
|
}
|
||||||
|
defer.returnValue(rules_by_user)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def evaluator_for_room_id(room_id, store):
|
def evaluator_for_room_id(room_id, store):
|
||||||
users = yield store.get_users_in_room(room_id)
|
users = yield store.get_users_in_room(room_id)
|
||||||
rules_by_user = yield store.bulk_get_push_rules(users)
|
rules_by_user = yield _get_rules(room_id, users, store)
|
||||||
rules_by_user = {
|
|
||||||
uid: baserules.list_with_base_rules(
|
|
||||||
[decode_rule_json(rule_list) for rule_list in rules_by_user[uid]]
|
|
||||||
if uid in rules_by_user else [],
|
|
||||||
UserID.from_string(uid),
|
|
||||||
)
|
|
||||||
for uid in users
|
|
||||||
}
|
|
||||||
member_events = yield store.get_current_state(
|
|
||||||
room_id=room_id,
|
|
||||||
event_type='m.room.member',
|
|
||||||
)
|
|
||||||
display_names = {}
|
|
||||||
for ev in member_events:
|
|
||||||
if ev.content.get("displayname"):
|
|
||||||
display_names[ev.state_key] = ev.content.get("displayname")
|
|
||||||
|
|
||||||
defer.returnValue(BulkPushRuleEvaluator(
|
defer.returnValue(BulkPushRuleEvaluator(
|
||||||
room_id, rules_by_user, display_names, users, store
|
room_id, rules_by_user, users, store
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
|
@ -69,10 +65,9 @@ class BulkPushRuleEvaluator:
|
||||||
the same logic to run the actual rules, but could be optimised further
|
the same logic to run the actual rules, but could be optimised further
|
||||||
(see https://matrix.org/jira/browse/SYN-562)
|
(see https://matrix.org/jira/browse/SYN-562)
|
||||||
"""
|
"""
|
||||||
def __init__(self, room_id, rules_by_user, display_names, users_in_room, store):
|
def __init__(self, room_id, rules_by_user, users_in_room, store):
|
||||||
self.room_id = room_id
|
self.room_id = room_id
|
||||||
self.rules_by_user = rules_by_user
|
self.rules_by_user = rules_by_user
|
||||||
self.display_names = display_names
|
|
||||||
self.users_in_room = users_in_room
|
self.users_in_room = users_in_room
|
||||||
self.store = store
|
self.store = store
|
||||||
|
|
||||||
|
@ -80,15 +75,30 @@ class BulkPushRuleEvaluator:
|
||||||
def action_for_event_by_user(self, event, handler):
|
def action_for_event_by_user(self, event, handler):
|
||||||
actions_by_user = {}
|
actions_by_user = {}
|
||||||
|
|
||||||
for uid, rules in self.rules_by_user.items():
|
users_dict = yield self.store.are_guests(self.rules_by_user.keys())
|
||||||
display_name = None
|
|
||||||
if uid in self.display_names:
|
|
||||||
display_name = self.display_names[uid]
|
|
||||||
|
|
||||||
is_guest = yield self.store.is_guest(UserID.from_string(uid))
|
filtered_by_user = yield handler._filter_events_for_clients(
|
||||||
filtered = yield handler._filter_events_for_client(
|
users_dict.items(), [event]
|
||||||
uid, [event], is_guest=is_guest
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
|
||||||
|
|
||||||
|
condition_cache = {}
|
||||||
|
|
||||||
|
member_state = yield self.store.get_state_for_event(
|
||||||
|
event.event_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
display_names = {}
|
||||||
|
for ev in member_state.values():
|
||||||
|
nm = ev.content.get("displayname", None)
|
||||||
|
if nm and ev.type == EventTypes.Member:
|
||||||
|
display_names[ev.state_key] = nm
|
||||||
|
|
||||||
|
for uid, rules in self.rules_by_user.items():
|
||||||
|
display_name = display_names.get(uid, None)
|
||||||
|
|
||||||
|
filtered = filtered_by_user[uid]
|
||||||
if len(filtered) == 0:
|
if len(filtered) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -96,29 +106,32 @@ class BulkPushRuleEvaluator:
|
||||||
if 'enabled' in rule and not rule['enabled']:
|
if 'enabled' in rule and not rule['enabled']:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# XXX: profile tags
|
matches = _condition_checker(
|
||||||
if BulkPushRuleEvaluator.event_matches_rule(
|
evaluator, rule['conditions'], uid, display_name, condition_cache
|
||||||
event, rule,
|
)
|
||||||
display_name, len(self.users_in_room), None
|
if matches:
|
||||||
):
|
|
||||||
actions = [x for x in rule['actions'] if x != 'dont_notify']
|
actions = [x for x in rule['actions'] if x != 'dont_notify']
|
||||||
if len(actions) > 0:
|
if actions:
|
||||||
actions_by_user[uid] = actions
|
actions_by_user[uid] = actions
|
||||||
break
|
break
|
||||||
defer.returnValue(actions_by_user)
|
defer.returnValue(actions_by_user)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def event_matches_rule(event, rule,
|
|
||||||
display_name, room_member_count, profile_tag):
|
|
||||||
matches = True
|
|
||||||
|
|
||||||
# passing the clock all the way into here is extremely awkward and push
|
def _condition_checker(evaluator, conditions, uid, display_name, cache):
|
||||||
# rules do not care about any of the relative timestamps, so we just
|
for cond in conditions:
|
||||||
# pass 0 for the current time.
|
_id = cond.get("_id", None)
|
||||||
client_event = serialize_event(event, 0)
|
if _id:
|
||||||
|
res = cache.get(_id, None)
|
||||||
|
if res is False:
|
||||||
|
break
|
||||||
|
elif res is True:
|
||||||
|
continue
|
||||||
|
|
||||||
for cond in rule['conditions']:
|
res = evaluator.matches(cond, uid, display_name, None)
|
||||||
matches &= PushRuleEvaluator._event_fulfills_condition(
|
if _id:
|
||||||
client_event, cond, display_name, room_member_count, profile_tag
|
cache[_id] = res
|
||||||
)
|
|
||||||
return matches
|
if res is False:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
|
@ -15,17 +15,22 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.types import UserID
|
|
||||||
|
|
||||||
import baserules
|
import baserules
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from synapse.types import UserID
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]')
|
||||||
|
IS_GLOB = re.compile(r'[\?\*\[\]]')
|
||||||
|
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
|
def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
|
||||||
rawrules = yield store.get_push_rules_for_user(user_id)
|
rawrules = yield store.get_push_rules_for_user(user_id)
|
||||||
|
@ -42,9 +47,34 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
|
def _room_member_count(ev, condition, room_member_count):
|
||||||
|
if 'is' not in condition:
|
||||||
|
return False
|
||||||
|
m = INEQUALITY_EXPR.match(condition['is'])
|
||||||
|
if not m:
|
||||||
|
return False
|
||||||
|
ineq = m.group(1)
|
||||||
|
rhs = m.group(2)
|
||||||
|
if not rhs.isdigit():
|
||||||
|
return False
|
||||||
|
rhs = int(rhs)
|
||||||
|
|
||||||
|
if ineq == '' or ineq == '==':
|
||||||
|
return room_member_count == rhs
|
||||||
|
elif ineq == '<':
|
||||||
|
return room_member_count < rhs
|
||||||
|
elif ineq == '>':
|
||||||
|
return room_member_count > rhs
|
||||||
|
elif ineq == '>=':
|
||||||
|
return room_member_count >= rhs
|
||||||
|
elif ineq == '<=':
|
||||||
|
return room_member_count <= rhs
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class PushRuleEvaluator:
|
class PushRuleEvaluator:
|
||||||
DEFAULT_ACTIONS = []
|
DEFAULT_ACTIONS = []
|
||||||
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
|
|
||||||
|
|
||||||
def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id,
|
def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id,
|
||||||
our_member_event, store):
|
our_member_event, store):
|
||||||
|
@ -61,8 +91,7 @@ class PushRuleEvaluator:
|
||||||
rule['actions'] = json.loads(raw_rule['actions'])
|
rule['actions'] = json.loads(raw_rule['actions'])
|
||||||
rules.append(rule)
|
rules.append(rule)
|
||||||
|
|
||||||
user = UserID.from_string(self.user_id)
|
self.rules = baserules.list_with_base_rules(rules)
|
||||||
self.rules = baserules.list_with_base_rules(rules, user)
|
|
||||||
|
|
||||||
self.enabled_map = enabled_map
|
self.enabled_map = enabled_map
|
||||||
|
|
||||||
|
@ -98,28 +127,19 @@ class PushRuleEvaluator:
|
||||||
room_members = yield self.store.get_users_in_room(room_id)
|
room_members = yield self.store.get_users_in_room(room_id)
|
||||||
room_member_count = len(room_members)
|
room_member_count = len(room_members)
|
||||||
|
|
||||||
|
evaluator = PushRuleEvaluatorForEvent(ev, room_member_count)
|
||||||
|
|
||||||
for r in self.rules:
|
for r in self.rules:
|
||||||
if r['rule_id'] in self.enabled_map:
|
enabled = self.enabled_map.get(r['rule_id'], None)
|
||||||
r['enabled'] = self.enabled_map[r['rule_id']]
|
if enabled is not None and not enabled:
|
||||||
elif 'enabled' not in r:
|
continue
|
||||||
r['enabled'] = True
|
|
||||||
if not r['enabled']:
|
if not r.get("enabled", True):
|
||||||
continue
|
continue
|
||||||
matches = True
|
|
||||||
|
|
||||||
conditions = r['conditions']
|
conditions = r['conditions']
|
||||||
actions = r['actions']
|
actions = r['actions']
|
||||||
|
|
||||||
for c in conditions:
|
|
||||||
matches &= self._event_fulfills_condition(
|
|
||||||
ev, c, display_name=my_display_name,
|
|
||||||
room_member_count=room_member_count,
|
|
||||||
profile_tag=self.profile_tag
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Rule %s %s",
|
|
||||||
r['rule_id'], "matches" if matches else "doesn't match"
|
|
||||||
)
|
|
||||||
# ignore rules with no actions (we have an explict 'dont_notify')
|
# ignore rules with no actions (we have an explict 'dont_notify')
|
||||||
if len(actions) == 0:
|
if len(actions) == 0:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
|
@ -127,8 +147,22 @@ class PushRuleEvaluator:
|
||||||
r['rule_id'], self.user_id
|
r['rule_id'], self.user_id
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
matches = True
|
||||||
|
for c in conditions:
|
||||||
|
matches = evaluator.matches(
|
||||||
|
c, self.user_id, my_display_name, self.profile_tag
|
||||||
|
)
|
||||||
|
if not matches:
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Rule %s %s",
|
||||||
|
r['rule_id'], "matches" if matches else "doesn't match"
|
||||||
|
)
|
||||||
|
|
||||||
if matches:
|
if matches:
|
||||||
logger.info(
|
logger.debug(
|
||||||
"%s matches for user %s, event %s",
|
"%s matches for user %s, event %s",
|
||||||
r['rule_id'], self.user_id, ev['event_id']
|
r['rule_id'], self.user_id, ev['event_id']
|
||||||
)
|
)
|
||||||
|
@ -139,94 +173,132 @@ class PushRuleEvaluator:
|
||||||
|
|
||||||
defer.returnValue(actions)
|
defer.returnValue(actions)
|
||||||
|
|
||||||
logger.info(
|
logger.debug(
|
||||||
"No rules match for user %s, event %s",
|
"No rules match for user %s, event %s",
|
||||||
self.user_id, ev['event_id']
|
self.user_id, ev['event_id']
|
||||||
)
|
)
|
||||||
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
|
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _glob_to_regexp(glob):
|
|
||||||
r = re.escape(glob)
|
|
||||||
r = re.sub(r'\\\*', r'.*?', r)
|
|
||||||
r = re.sub(r'\\\?', r'.', r)
|
|
||||||
|
|
||||||
# handle [abc], [a-z] and [!a-z] style ranges.
|
class PushRuleEvaluatorForEvent(object):
|
||||||
r = re.sub(r'\\\[(\\\!|)(.*)\\\]',
|
def __init__(self, event, room_member_count):
|
||||||
lambda x: ('[%s%s]' % (x.group(1) and '^' or '',
|
self._event = event
|
||||||
re.sub(r'\\\-', '-', x.group(2)))), r)
|
self._room_member_count = room_member_count
|
||||||
return r
|
|
||||||
|
|
||||||
@staticmethod
|
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
||||||
def _event_fulfills_condition(ev, condition,
|
self._value_cache = _flatten_dict(event)
|
||||||
display_name, room_member_count, profile_tag):
|
|
||||||
|
def matches(self, condition, user_id, display_name, profile_tag):
|
||||||
if condition['kind'] == 'event_match':
|
if condition['kind'] == 'event_match':
|
||||||
if 'pattern' not in condition:
|
return self._event_match(condition, user_id)
|
||||||
logger.warn("event_match condition with no pattern")
|
|
||||||
return False
|
|
||||||
# XXX: optimisation: cache our pattern regexps
|
|
||||||
if condition['key'] == 'content.body':
|
|
||||||
r = r'\b%s\b' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
|
|
||||||
else:
|
|
||||||
r = r'^%s$' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
|
|
||||||
val = _value_for_dotted_key(condition['key'], ev)
|
|
||||||
if val is None:
|
|
||||||
return False
|
|
||||||
return re.search(r, val, flags=re.IGNORECASE) is not None
|
|
||||||
|
|
||||||
elif condition['kind'] == 'device':
|
elif condition['kind'] == 'device':
|
||||||
if 'profile_tag' not in condition:
|
if 'profile_tag' not in condition:
|
||||||
return True
|
return True
|
||||||
return condition['profile_tag'] == profile_tag
|
return condition['profile_tag'] == profile_tag
|
||||||
|
|
||||||
elif condition['kind'] == 'contains_display_name':
|
elif condition['kind'] == 'contains_display_name':
|
||||||
# This is special because display names can be different
|
return self._contains_display_name(display_name)
|
||||||
# between rooms and so you can't really hard code it in a rule.
|
|
||||||
# Optimisation: we should cache these names and update them from
|
|
||||||
# the event stream.
|
|
||||||
if 'content' not in ev or 'body' not in ev['content']:
|
|
||||||
return False
|
|
||||||
if not display_name:
|
|
||||||
return False
|
|
||||||
return re.search(
|
|
||||||
r"\b%s\b" % re.escape(display_name), ev['content']['body'],
|
|
||||||
flags=re.IGNORECASE
|
|
||||||
) is not None
|
|
||||||
|
|
||||||
elif condition['kind'] == 'room_member_count':
|
elif condition['kind'] == 'room_member_count':
|
||||||
if 'is' not in condition:
|
return _room_member_count(
|
||||||
return False
|
self._event, condition, self._room_member_count
|
||||||
m = PushRuleEvaluator.INEQUALITY_EXPR.match(condition['is'])
|
)
|
||||||
if not m:
|
|
||||||
return False
|
|
||||||
ineq = m.group(1)
|
|
||||||
rhs = m.group(2)
|
|
||||||
if not rhs.isdigit():
|
|
||||||
return False
|
|
||||||
rhs = int(rhs)
|
|
||||||
|
|
||||||
if ineq == '' or ineq == '==':
|
|
||||||
return room_member_count == rhs
|
|
||||||
elif ineq == '<':
|
|
||||||
return room_member_count < rhs
|
|
||||||
elif ineq == '>':
|
|
||||||
return room_member_count > rhs
|
|
||||||
elif ineq == '>=':
|
|
||||||
return room_member_count >= rhs
|
|
||||||
elif ineq == '<=':
|
|
||||||
return room_member_count <= rhs
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _event_match(self, condition, user_id):
|
||||||
|
pattern = condition.get('pattern', None)
|
||||||
|
|
||||||
def _value_for_dotted_key(dotted_key, event):
|
if not pattern:
|
||||||
parts = dotted_key.split(".")
|
pattern_type = condition.get('pattern_type', None)
|
||||||
val = event
|
if pattern_type == "user_id":
|
||||||
while len(parts) > 0:
|
pattern = user_id
|
||||||
if parts[0] not in val:
|
elif pattern_type == "user_localpart":
|
||||||
return None
|
pattern = UserID.from_string(user_id).localpart
|
||||||
val = val[parts[0]]
|
|
||||||
parts = parts[1:]
|
if not pattern:
|
||||||
return val
|
logger.warn("event_match condition with no pattern")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# XXX: optimisation: cache our pattern regexps
|
||||||
|
if condition['key'] == 'content.body':
|
||||||
|
body = self._event["content"].get("body", None)
|
||||||
|
if not body:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return _glob_matches(pattern, body, word_boundary=True)
|
||||||
|
else:
|
||||||
|
haystack = self._get_value(condition['key'])
|
||||||
|
if haystack is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return _glob_matches(pattern, haystack)
|
||||||
|
|
||||||
|
def _contains_display_name(self, display_name):
|
||||||
|
if not display_name:
|
||||||
|
return False
|
||||||
|
|
||||||
|
body = self._event["content"].get("body", None)
|
||||||
|
if not body:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return _glob_matches(display_name, body, word_boundary=True)
|
||||||
|
|
||||||
|
def _get_value(self, dotted_key):
|
||||||
|
return self._value_cache.get(dotted_key, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _glob_matches(glob, value, word_boundary=False):
|
||||||
|
"""Tests if value matches glob.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
glob (string)
|
||||||
|
value (string): String to test against glob.
|
||||||
|
word_boundary (bool): Whether to match against word boundaries or entire
|
||||||
|
string. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool
|
||||||
|
"""
|
||||||
|
if IS_GLOB.search(glob):
|
||||||
|
r = re.escape(glob)
|
||||||
|
|
||||||
|
r = r.replace(r'\*', '.*?')
|
||||||
|
r = r.replace(r'\?', '.')
|
||||||
|
|
||||||
|
# handle [abc], [a-z] and [!a-z] style ranges.
|
||||||
|
r = GLOB_REGEX.sub(
|
||||||
|
lambda x: (
|
||||||
|
'[%s%s]' % (
|
||||||
|
x.group(1) and '^' or '',
|
||||||
|
x.group(2).replace(r'\\\-', '-')
|
||||||
|
)
|
||||||
|
),
|
||||||
|
r,
|
||||||
|
)
|
||||||
|
if word_boundary:
|
||||||
|
r = r"\b%s\b" % (r,)
|
||||||
|
r = re.compile(r, flags=re.IGNORECASE)
|
||||||
|
|
||||||
|
return r.search(value)
|
||||||
|
else:
|
||||||
|
r = r + "$"
|
||||||
|
r = re.compile(r, flags=re.IGNORECASE)
|
||||||
|
|
||||||
|
return r.match(value)
|
||||||
|
elif word_boundary:
|
||||||
|
r = re.escape(glob)
|
||||||
|
r = r"\b%s\b" % (r,)
|
||||||
|
r = re.compile(r, flags=re.IGNORECASE)
|
||||||
|
|
||||||
|
return r.search(value)
|
||||||
|
else:
|
||||||
|
return value.lower() == glob.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _flatten_dict(d, prefix=[], result={}):
|
||||||
|
for key, value in d.items():
|
||||||
|
if isinstance(value, basestring):
|
||||||
|
result[".".join(prefix + [key])] = value.lower()
|
||||||
|
elif hasattr(value, "items"):
|
||||||
|
_flatten_dict(value, prefix=(prefix+[key]), result=result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
|
@ -27,6 +27,7 @@ from synapse.push.rulekinds import (
|
||||||
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
|
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import copy
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
|
|
||||||
|
|
||||||
|
@ -126,7 +127,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
rule["actions"] = json.loads(rawrule["actions"])
|
rule["actions"] = json.loads(rawrule["actions"])
|
||||||
ruleslist.append(rule)
|
ruleslist.append(rule)
|
||||||
|
|
||||||
ruleslist = baserules.list_with_base_rules(ruleslist, user)
|
# We're going to be mutating this a lot, so do a deep copy
|
||||||
|
ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
|
||||||
|
|
||||||
rules = {'global': {}, 'device': {}}
|
rules = {'global': {}, 'device': {}}
|
||||||
|
|
||||||
|
@ -140,6 +142,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
template_name = _priority_class_to_template_name(r['priority_class'])
|
template_name = _priority_class_to_template_name(r['priority_class'])
|
||||||
|
|
||||||
|
# Remove internal stuff.
|
||||||
|
for c in r["conditions"]:
|
||||||
|
c.pop("_id", None)
|
||||||
|
|
||||||
|
pattern_type = c.pop("pattern_type", None)
|
||||||
|
if pattern_type == "user_id":
|
||||||
|
c["pattern"] = user.to_string()
|
||||||
|
elif pattern_type == "user_localpart":
|
||||||
|
c["pattern"] = user.localpart
|
||||||
|
|
||||||
if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
|
if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
|
||||||
# per-device rule
|
# per-device rule
|
||||||
profile_tag = _profile_tag_from_conditions(r["conditions"])
|
profile_tag = _profile_tag_from_conditions(r["conditions"])
|
||||||
|
|
|
@ -18,7 +18,7 @@ from twisted.internet import defer
|
||||||
from synapse.api.errors import StoreError, Codes
|
from synapse.api.errors import StoreError, Codes
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
||||||
|
|
||||||
|
|
||||||
class RegistrationStore(SQLBaseStore):
|
class RegistrationStore(SQLBaseStore):
|
||||||
|
@ -256,10 +256,10 @@ class RegistrationStore(SQLBaseStore):
|
||||||
defer.returnValue(res if res else False)
|
defer.returnValue(res if res else False)
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
@cachedInlineCallbacks()
|
||||||
def is_guest(self, user):
|
def is_guest(self, user_id):
|
||||||
res = yield self._simple_select_one_onecol(
|
res = yield self._simple_select_one_onecol(
|
||||||
table="users",
|
table="users",
|
||||||
keyvalues={"name": user.to_string()},
|
keyvalues={"name": user_id},
|
||||||
retcol="is_guest",
|
retcol="is_guest",
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
desc="is_guest",
|
desc="is_guest",
|
||||||
|
@ -267,6 +267,26 @@ class RegistrationStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue(res if res else False)
|
defer.returnValue(res if res else False)
|
||||||
|
|
||||||
|
@cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
|
||||||
|
inlineCallbacks=True)
|
||||||
|
def are_guests(self, user_ids):
|
||||||
|
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
|
||||||
|
",".join("?" for _ in user_ids),
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = yield self._execute(
|
||||||
|
"are_guests", self.cursor_to_dict, sql, *user_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
result = {user_id: False for user_id in user_ids}
|
||||||
|
|
||||||
|
result.update({
|
||||||
|
row["name"]: bool(row["is_guest"])
|
||||||
|
for row in rows
|
||||||
|
})
|
||||||
|
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
def _query_for_auth(self, txn, token):
|
def _query_for_auth(self, txn, token):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT users.name, users.is_guest, access_tokens.id as token_id"
|
"SELECT users.name, users.is_guest, access_tokens.id as token_id"
|
||||||
|
|
|
@ -287,6 +287,7 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
txn.execute(sql, (user_id, room_id))
|
txn.execute(sql, (user_id, room_id))
|
||||||
yield self.runInteraction("forget_membership", f)
|
yield self.runInteraction("forget_membership", f)
|
||||||
self.was_forgotten_at.invalidate_all()
|
self.was_forgotten_at.invalidate_all()
|
||||||
|
self.who_forgot_in_room.invalidate_all()
|
||||||
self.did_forget.invalidate((user_id, room_id))
|
self.did_forget.invalidate((user_id, room_id))
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2)
|
@cachedInlineCallbacks(num_args=2)
|
||||||
|
@ -336,3 +337,15 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
return rows[0][0]
|
return rows[0][0]
|
||||||
forgot = yield self.runInteraction("did_forget_membership_at", f)
|
forgot = yield self.runInteraction("did_forget_membership_at", f)
|
||||||
defer.returnValue(forgot == 1)
|
defer.returnValue(forgot == 1)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
def who_forgot_in_room(self, room_id):
|
||||||
|
return self._simple_select_list(
|
||||||
|
table="room_memberships",
|
||||||
|
retcols=("user_id", "event_id"),
|
||||||
|
keyvalues={
|
||||||
|
"room_id": room_id,
|
||||||
|
"forgotten": 1,
|
||||||
|
},
|
||||||
|
desc="who_forgot"
|
||||||
|
)
|
||||||
|
|
|
@ -1,141 +0,0 @@
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from tests import unittest
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
|
||||||
from synapse.events import FrozenEvent
|
|
||||||
from synapse.handlers.federation import FederationHandler
|
|
||||||
|
|
||||||
from mock import NonCallableMock, ANY, Mock
|
|
||||||
|
|
||||||
from ..utils import setup_test_homeserver
|
|
||||||
|
|
||||||
|
|
||||||
class FederationTestCase(unittest.TestCase):
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def setUp(self):
|
|
||||||
|
|
||||||
self.state_handler = NonCallableMock(spec_set=[
|
|
||||||
"compute_event_context",
|
|
||||||
])
|
|
||||||
|
|
||||||
self.auth = NonCallableMock(spec_set=[
|
|
||||||
"check",
|
|
||||||
"check_host_in_room",
|
|
||||||
])
|
|
||||||
|
|
||||||
self.hostname = "test"
|
|
||||||
hs = yield setup_test_homeserver(
|
|
||||||
self.hostname,
|
|
||||||
datastore=NonCallableMock(spec_set=[
|
|
||||||
"persist_event",
|
|
||||||
"store_room",
|
|
||||||
"get_room",
|
|
||||||
"get_destination_retry_timings",
|
|
||||||
"set_destination_retry_timings",
|
|
||||||
"have_events",
|
|
||||||
"get_users_in_room",
|
|
||||||
"bulk_get_push_rules",
|
|
||||||
"get_current_state",
|
|
||||||
"set_push_actions_for_event_and_users",
|
|
||||||
"is_guest",
|
|
||||||
"get_state_for_events",
|
|
||||||
]),
|
|
||||||
resource_for_federation=NonCallableMock(),
|
|
||||||
http_client=NonCallableMock(spec_set=[]),
|
|
||||||
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
|
|
||||||
handlers=NonCallableMock(spec_set=[
|
|
||||||
"room_member_handler",
|
|
||||||
"federation_handler",
|
|
||||||
]),
|
|
||||||
auth=self.auth,
|
|
||||||
state_handler=self.state_handler,
|
|
||||||
keyring=Mock(),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.datastore = hs.get_datastore()
|
|
||||||
self.handlers = hs.get_handlers()
|
|
||||||
self.notifier = hs.get_notifier()
|
|
||||||
self.hs = hs
|
|
||||||
|
|
||||||
self.handlers.federation_handler = FederationHandler(self.hs)
|
|
||||||
|
|
||||||
self.datastore.get_state_for_events.return_value = {"$a:b": {}}
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_msg(self):
|
|
||||||
pdu = FrozenEvent({
|
|
||||||
"type": EventTypes.Message,
|
|
||||||
"room_id": "foo",
|
|
||||||
"content": {"msgtype": u"fooo"},
|
|
||||||
"origin_server_ts": 0,
|
|
||||||
"event_id": "$a:b",
|
|
||||||
"user_id":"@a:b",
|
|
||||||
"origin": "b",
|
|
||||||
"auth_events": [],
|
|
||||||
"hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
|
|
||||||
})
|
|
||||||
|
|
||||||
self.datastore.persist_event.return_value = defer.succeed((1,1))
|
|
||||||
self.datastore.get_room.return_value = defer.succeed(True)
|
|
||||||
self.datastore.get_users_in_room.return_value = ["@a:b"]
|
|
||||||
self.datastore.bulk_get_push_rules.return_value = {}
|
|
||||||
self.datastore.get_current_state.return_value = {}
|
|
||||||
self.auth.check_host_in_room.return_value = defer.succeed(True)
|
|
||||||
|
|
||||||
retry_timings_res = {
|
|
||||||
"destination": "",
|
|
||||||
"retry_last_ts": 0,
|
|
||||||
"retry_interval": 0,
|
|
||||||
}
|
|
||||||
self.datastore.get_destination_retry_timings.return_value = (
|
|
||||||
defer.succeed(retry_timings_res)
|
|
||||||
)
|
|
||||||
|
|
||||||
def have_events(event_ids):
|
|
||||||
return defer.succeed({})
|
|
||||||
self.datastore.have_events.side_effect = have_events
|
|
||||||
|
|
||||||
def annotate(ev, old_state=None, outlier=False):
|
|
||||||
context = Mock()
|
|
||||||
context.current_state = {}
|
|
||||||
context.auth_events = {}
|
|
||||||
return defer.succeed(context)
|
|
||||||
self.state_handler.compute_event_context.side_effect = annotate
|
|
||||||
|
|
||||||
yield self.handlers.federation_handler.on_receive_pdu(
|
|
||||||
"fo", pdu, False
|
|
||||||
)
|
|
||||||
|
|
||||||
self.datastore.persist_event.assert_called_once_with(
|
|
||||||
ANY,
|
|
||||||
is_new_state=True,
|
|
||||||
backfilled=False,
|
|
||||||
current_state=None,
|
|
||||||
context=ANY,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.state_handler.compute_event_context.assert_called_once_with(
|
|
||||||
ANY, old_state=None, outlier=False
|
|
||||||
)
|
|
||||||
|
|
||||||
self.auth.check.assert_called_once_with(ANY, auth_events={})
|
|
||||||
|
|
||||||
self.notifier.on_new_room_event.assert_called_once_with(
|
|
||||||
ANY, 1, 1, extra_users=[]
|
|
||||||
)
|
|
|
@ -1,418 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from .. import unittest
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
|
||||||
from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
|
|
||||||
from synapse.handlers.profile import ProfileHandler
|
|
||||||
from synapse.types import UserID
|
|
||||||
from ..utils import setup_test_homeserver
|
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
|
||||||
|
|
||||||
|
|
||||||
class RoomMemberHandlerTestCase(unittest.TestCase):
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def setUp(self):
|
|
||||||
self.hostname = "red"
|
|
||||||
hs = yield setup_test_homeserver(
|
|
||||||
self.hostname,
|
|
||||||
ratelimiter=NonCallableMock(spec_set=[
|
|
||||||
"send_message",
|
|
||||||
]),
|
|
||||||
datastore=NonCallableMock(spec_set=[
|
|
||||||
"persist_event",
|
|
||||||
"get_room_member",
|
|
||||||
"get_room",
|
|
||||||
"store_room",
|
|
||||||
"get_latest_events_in_room",
|
|
||||||
"add_event_hashes",
|
|
||||||
"get_users_in_room",
|
|
||||||
"bulk_get_push_rules",
|
|
||||||
"get_current_state",
|
|
||||||
"set_push_actions_for_event_and_users",
|
|
||||||
"get_state_for_events",
|
|
||||||
"is_guest",
|
|
||||||
]),
|
|
||||||
resource_for_federation=NonCallableMock(),
|
|
||||||
http_client=NonCallableMock(spec_set=[]),
|
|
||||||
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
|
|
||||||
handlers=NonCallableMock(spec_set=[
|
|
||||||
"room_member_handler",
|
|
||||||
"profile_handler",
|
|
||||||
"federation_handler",
|
|
||||||
]),
|
|
||||||
auth=NonCallableMock(spec_set=[
|
|
||||||
"check",
|
|
||||||
"add_auth_events",
|
|
||||||
"check_host_in_room",
|
|
||||||
]),
|
|
||||||
state_handler=NonCallableMock(spec_set=[
|
|
||||||
"compute_event_context",
|
|
||||||
"get_current_state",
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.federation = NonCallableMock(spec_set=[
|
|
||||||
"handle_new_event",
|
|
||||||
"send_invite",
|
|
||||||
"get_state_for_room",
|
|
||||||
])
|
|
||||||
|
|
||||||
self.datastore = hs.get_datastore()
|
|
||||||
self.handlers = hs.get_handlers()
|
|
||||||
self.notifier = hs.get_notifier()
|
|
||||||
self.state_handler = hs.get_state_handler()
|
|
||||||
self.distributor = hs.get_distributor()
|
|
||||||
self.auth = hs.get_auth()
|
|
||||||
self.hs = hs
|
|
||||||
|
|
||||||
self.handlers.federation_handler = self.federation
|
|
||||||
|
|
||||||
self.distributor.declare("collect_presencelike_data")
|
|
||||||
|
|
||||||
self.handlers.room_member_handler = RoomMemberHandler(self.hs)
|
|
||||||
self.handlers.profile_handler = ProfileHandler(self.hs)
|
|
||||||
self.room_member_handler = self.handlers.room_member_handler
|
|
||||||
|
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
|
||||||
|
|
||||||
self.datastore.persist_event.return_value = (1,1)
|
|
||||||
self.datastore.add_event_hashes.return_value = []
|
|
||||||
self.datastore.get_users_in_room.return_value = ["@bob:red"]
|
|
||||||
self.datastore.bulk_get_push_rules.return_value = {}
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_invite(self):
|
|
||||||
room_id = "!foo:red"
|
|
||||||
user_id = "@bob:red"
|
|
||||||
target_user_id = "@red:blue"
|
|
||||||
content = {"membership": Membership.INVITE}
|
|
||||||
|
|
||||||
builder = self.hs.get_event_builder_factory().new({
|
|
||||||
"type": EventTypes.Member,
|
|
||||||
"sender": user_id,
|
|
||||||
"state_key": target_user_id,
|
|
||||||
"room_id": room_id,
|
|
||||||
"content": content,
|
|
||||||
})
|
|
||||||
|
|
||||||
self.datastore.get_latest_events_in_room.return_value = (
|
|
||||||
defer.succeed([])
|
|
||||||
)
|
|
||||||
self.datastore.get_current_state.return_value = {}
|
|
||||||
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
|
|
||||||
|
|
||||||
def annotate(_):
|
|
||||||
ctx = Mock()
|
|
||||||
ctx.current_state = {
|
|
||||||
(EventTypes.Member, "@alice:green"): self._create_member(
|
|
||||||
user_id="@alice:green",
|
|
||||||
room_id=room_id,
|
|
||||||
),
|
|
||||||
(EventTypes.Member, "@bob:red"): self._create_member(
|
|
||||||
user_id="@bob:red",
|
|
||||||
room_id=room_id,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
ctx.prev_state_events = []
|
|
||||||
|
|
||||||
return defer.succeed(ctx)
|
|
||||||
|
|
||||||
self.state_handler.compute_event_context.side_effect = annotate
|
|
||||||
|
|
||||||
def add_auth(_, ctx):
|
|
||||||
ctx.auth_events = ctx.current_state[
|
|
||||||
(EventTypes.Member, "@bob:red")
|
|
||||||
]
|
|
||||||
|
|
||||||
return defer.succeed(True)
|
|
||||||
self.auth.add_auth_events.side_effect = add_auth
|
|
||||||
|
|
||||||
def send_invite(domain, event):
|
|
||||||
return defer.succeed(event)
|
|
||||||
|
|
||||||
self.federation.send_invite.side_effect = send_invite
|
|
||||||
|
|
||||||
room_handler = self.room_member_handler
|
|
||||||
event, context = yield room_handler._create_new_client_event(
|
|
||||||
builder
|
|
||||||
)
|
|
||||||
|
|
||||||
yield room_handler.send_membership_event(event, context)
|
|
||||||
|
|
||||||
self.state_handler.compute_event_context.assert_called_once_with(
|
|
||||||
builder
|
|
||||||
)
|
|
||||||
|
|
||||||
self.auth.add_auth_events.assert_called_once_with(
|
|
||||||
builder, context
|
|
||||||
)
|
|
||||||
|
|
||||||
self.federation.send_invite.assert_called_once_with(
|
|
||||||
"blue", event,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.datastore.persist_event.assert_called_once_with(
|
|
||||||
event, context=context,
|
|
||||||
)
|
|
||||||
self.notifier.on_new_room_event.assert_called_once_with(
|
|
||||||
event, 1, 1, extra_users=[UserID.from_string(target_user_id)]
|
|
||||||
)
|
|
||||||
self.assertFalse(self.datastore.get_room.called)
|
|
||||||
self.assertFalse(self.datastore.store_room.called)
|
|
||||||
self.assertFalse(self.federation.get_state_for_room.called)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_simple_join(self):
|
|
||||||
room_id = "!foo:red"
|
|
||||||
user_id = "@bob:red"
|
|
||||||
user = UserID.from_string(user_id)
|
|
||||||
|
|
||||||
join_signal_observer = Mock()
|
|
||||||
self.distributor.observe("user_joined_room", join_signal_observer)
|
|
||||||
|
|
||||||
builder = self.hs.get_event_builder_factory().new({
|
|
||||||
"type": EventTypes.Member,
|
|
||||||
"sender": user_id,
|
|
||||||
"state_key": user_id,
|
|
||||||
"room_id": room_id,
|
|
||||||
"content": {"membership": Membership.JOIN},
|
|
||||||
})
|
|
||||||
|
|
||||||
self.datastore.get_latest_events_in_room.return_value = (
|
|
||||||
defer.succeed([])
|
|
||||||
)
|
|
||||||
self.datastore.get_current_state.return_value = {}
|
|
||||||
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
|
|
||||||
|
|
||||||
def annotate(_):
|
|
||||||
ctx = Mock()
|
|
||||||
ctx.current_state = {
|
|
||||||
(EventTypes.Member, "@bob:red"): self._create_member(
|
|
||||||
user_id="@bob:red",
|
|
||||||
room_id=room_id,
|
|
||||||
membership=Membership.INVITE
|
|
||||||
),
|
|
||||||
}
|
|
||||||
ctx.prev_state_events = []
|
|
||||||
|
|
||||||
return defer.succeed(ctx)
|
|
||||||
|
|
||||||
self.state_handler.compute_event_context.side_effect = annotate
|
|
||||||
|
|
||||||
def add_auth(_, ctx):
|
|
||||||
ctx.auth_events = ctx.current_state[
|
|
||||||
(EventTypes.Member, "@bob:red")
|
|
||||||
]
|
|
||||||
|
|
||||||
return defer.succeed(True)
|
|
||||||
self.auth.add_auth_events.side_effect = add_auth
|
|
||||||
|
|
||||||
room_handler = self.room_member_handler
|
|
||||||
event, context = yield room_handler._create_new_client_event(
|
|
||||||
builder
|
|
||||||
)
|
|
||||||
|
|
||||||
# Actual invocation
|
|
||||||
yield room_handler.send_membership_event(event, context)
|
|
||||||
|
|
||||||
self.federation.handle_new_event.assert_called_once_with(
|
|
||||||
event, destinations=set()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.datastore.persist_event.assert_called_once_with(
|
|
||||||
event, context=context
|
|
||||||
)
|
|
||||||
self.notifier.on_new_room_event.assert_called_once_with(
|
|
||||||
event, 1, 1, extra_users=[user]
|
|
||||||
)
|
|
||||||
|
|
||||||
join_signal_observer.assert_called_with(
|
|
||||||
user=user, room_id=room_id
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_member(self, user_id, room_id, membership=Membership.JOIN):
|
|
||||||
builder = self.hs.get_event_builder_factory().new({
|
|
||||||
"type": EventTypes.Member,
|
|
||||||
"sender": user_id,
|
|
||||||
"state_key": user_id,
|
|
||||||
"room_id": room_id,
|
|
||||||
"content": {"membership": membership},
|
|
||||||
})
|
|
||||||
|
|
||||||
return builder.build()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_simple_leave(self):
|
|
||||||
room_id = "!foo:red"
|
|
||||||
user_id = "@bob:red"
|
|
||||||
user = UserID.from_string(user_id)
|
|
||||||
|
|
||||||
builder = self.hs.get_event_builder_factory().new({
|
|
||||||
"type": EventTypes.Member,
|
|
||||||
"sender": user_id,
|
|
||||||
"state_key": user_id,
|
|
||||||
"room_id": room_id,
|
|
||||||
"content": {"membership": Membership.LEAVE},
|
|
||||||
})
|
|
||||||
|
|
||||||
self.datastore.get_latest_events_in_room.return_value = (
|
|
||||||
defer.succeed([])
|
|
||||||
)
|
|
||||||
self.datastore.get_current_state.return_value = {}
|
|
||||||
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
|
|
||||||
|
|
||||||
def annotate(_):
|
|
||||||
ctx = Mock()
|
|
||||||
ctx.current_state = {
|
|
||||||
(EventTypes.Member, "@bob:red"): self._create_member(
|
|
||||||
user_id="@bob:red",
|
|
||||||
room_id=room_id,
|
|
||||||
membership=Membership.JOIN
|
|
||||||
),
|
|
||||||
}
|
|
||||||
ctx.prev_state_events = []
|
|
||||||
|
|
||||||
return defer.succeed(ctx)
|
|
||||||
|
|
||||||
self.state_handler.compute_event_context.side_effect = annotate
|
|
||||||
|
|
||||||
def add_auth(_, ctx):
|
|
||||||
ctx.auth_events = ctx.current_state[
|
|
||||||
(EventTypes.Member, "@bob:red")
|
|
||||||
]
|
|
||||||
|
|
||||||
return defer.succeed(True)
|
|
||||||
self.auth.add_auth_events.side_effect = add_auth
|
|
||||||
|
|
||||||
room_handler = self.room_member_handler
|
|
||||||
event, context = yield room_handler._create_new_client_event(
|
|
||||||
builder
|
|
||||||
)
|
|
||||||
|
|
||||||
leave_signal_observer = Mock()
|
|
||||||
self.distributor.observe("user_left_room", leave_signal_observer)
|
|
||||||
|
|
||||||
# Actual invocation
|
|
||||||
yield room_handler.send_membership_event(event, context)
|
|
||||||
|
|
||||||
self.federation.handle_new_event.assert_called_once_with(
|
|
||||||
event, destinations=set(['red'])
|
|
||||||
)
|
|
||||||
|
|
||||||
self.datastore.persist_event.assert_called_once_with(
|
|
||||||
event, context=context
|
|
||||||
)
|
|
||||||
self.notifier.on_new_room_event.assert_called_once_with(
|
|
||||||
event, 1, 1, extra_users=[user]
|
|
||||||
)
|
|
||||||
|
|
||||||
leave_signal_observer.assert_called_with(
|
|
||||||
user=user, room_id=room_id
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RoomCreationTest(unittest.TestCase):
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def setUp(self):
|
|
||||||
self.hostname = "red"
|
|
||||||
|
|
||||||
hs = yield setup_test_homeserver(
|
|
||||||
self.hostname,
|
|
||||||
datastore=NonCallableMock(spec_set=[
|
|
||||||
"store_room",
|
|
||||||
"snapshot_room",
|
|
||||||
"persist_event",
|
|
||||||
"get_joined_hosts_for_room",
|
|
||||||
]),
|
|
||||||
http_client=NonCallableMock(spec_set=[]),
|
|
||||||
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
|
|
||||||
handlers=NonCallableMock(spec_set=[
|
|
||||||
"room_creation_handler",
|
|
||||||
"message_handler",
|
|
||||||
]),
|
|
||||||
auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
|
|
||||||
ratelimiter=NonCallableMock(spec_set=[
|
|
||||||
"send_message",
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.federation = NonCallableMock(spec_set=[
|
|
||||||
"handle_new_event",
|
|
||||||
])
|
|
||||||
|
|
||||||
self.handlers = hs.get_handlers()
|
|
||||||
|
|
||||||
self.handlers.room_creation_handler = RoomCreationHandler(hs)
|
|
||||||
self.room_creation_handler = self.handlers.room_creation_handler
|
|
||||||
|
|
||||||
self.message_handler = self.handlers.message_handler
|
|
||||||
|
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_room_creation(self):
|
|
||||||
user_id = "@foo:red"
|
|
||||||
room_id = "!bobs_room:red"
|
|
||||||
config = {"visibility": "private"}
|
|
||||||
|
|
||||||
yield self.room_creation_handler.create_room(
|
|
||||||
user_id=user_id,
|
|
||||||
room_id=room_id,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertTrue(self.message_handler.create_and_send_event.called)
|
|
||||||
|
|
||||||
event_dicts = [
|
|
||||||
e[0][0]
|
|
||||||
for e in self.message_handler.create_and_send_event.call_args_list
|
|
||||||
]
|
|
||||||
|
|
||||||
self.assertTrue(len(event_dicts) > 3)
|
|
||||||
|
|
||||||
self.assertDictContainsSubset(
|
|
||||||
{
|
|
||||||
"type": EventTypes.Create,
|
|
||||||
"sender": user_id,
|
|
||||||
"room_id": room_id,
|
|
||||||
},
|
|
||||||
event_dicts[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(user_id, event_dicts[0]["content"]["creator"])
|
|
||||||
|
|
||||||
self.assertDictContainsSubset(
|
|
||||||
{
|
|
||||||
"type": EventTypes.Member,
|
|
||||||
"sender": user_id,
|
|
||||||
"room_id": room_id,
|
|
||||||
"state_key": user_id,
|
|
||||||
},
|
|
||||||
event_dicts[1]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
Membership.JOIN,
|
|
||||||
event_dicts[1]["content"]["membership"]
|
|
||||||
)
|
|
Loading…
Reference in New Issue