Hook up the push rules to the notifier

This commit is contained in:
Mark Haines 2016-03-03 14:57:45 +00:00
parent 2223204eba
commit ddf9e7b302
5 changed files with 43 additions and 18 deletions

View File

@ -647,8 +647,8 @@ class MessageHandler(BaseHandler):
user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken(token[0], 0, 0, 0, 0)
end_token = StreamToken(token[1], 0, 0, 0, 0)
start_token = StreamToken.START.copy_and_replace("room_key", token[0])
end_token = StreamToken.START.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()

View File

@ -284,7 +284,7 @@ class Notifier(object):
@defer.inlineCallbacks
def wait_for_events(self, user_id, timeout, callback, room_ids=None,
from_token=StreamToken("s0", "0", "0", "0", "0")):
from_token=StreamToken.START):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""

View File

@ -36,6 +36,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash")
def __init__(self, hs):
super(PushRuleRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def on_PUT(self, request):
spec = _rule_spec_from_path(request.postpath)
@ -51,8 +56,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
content = _parse_json(request)
user_id = requester.user.to_string()
if 'attr' in spec:
yield self.set_rule_attr(requester.user.to_string(), spec, content)
yield self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
defer.returnValue((200, {}))
if spec['rule_id'].startswith('.'):
@ -77,8 +85,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
after = _namespaced_rule_id(spec, after[0])
try:
yield self.hs.get_datastore().add_push_rule(
user_id=requester.user.to_string(),
yield self.store.add_push_rule(
user_id=user_id,
rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class,
conditions=conditions,
@ -86,6 +94,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
before=before,
after=after
)
self.notify_user(user_id)
except InconsistentRuleException as e:
raise SynapseError(400, e.message)
except RuleNotFoundException as e:
@ -98,13 +107,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
spec = _rule_spec_from_path(request.postpath)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try:
yield self.hs.get_datastore().delete_push_rule(
requester.user.to_string(), namespaced_rule_id
yield self.store.delete_push_rule(
user_id, namespaced_rule_id
)
self.notify_user(user_id)
defer.returnValue((200, {}))
except StoreError as e:
if e.code == 404:
@ -115,14 +126,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
user = requester.user
user_id = requester.user.to_string()
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
rawrules = yield self.hs.get_datastore().get_push_rules_for_user(
user.to_string()
)
rawrules = yield self.store.get_push_rules_for_user(user_id)
ruleslist = []
for rawrule in rawrules:
@ -138,8 +147,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
enabled_map = yield self.hs.get_datastore().\
get_push_rules_enabled_for_user(user.to_string())
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
for r in ruleslist:
rulearray = None
@ -152,9 +160,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":
c["pattern"] = user.to_string()
c["pattern"] = user_id
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
c["pattern"] = requester.user.localpart
rulearray = rules['global'][template_name]
@ -188,6 +196,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_OPTIONS(self, _):
return 200, {}
def notify_user(self, user_id):
stream_id = self.store.get_push_rules_stream_token()
self.notifier.on_new_event(
"push_rules_key", stream_id, users=[user_id]
)
def set_rule_attr(self, user_id, spec, val):
if spec['attr'] == 'enabled':
if isinstance(val, dict) and "enabled" in val:
@ -198,7 +212,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
return self.hs.get_datastore().set_push_rule_enabled(
return self.store.set_push_rule_enabled(
user_id, namespaced_rule_id, val
)
elif spec['attr'] == 'actions':
@ -210,7 +224,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
return self.hs.get_datastore().set_push_rule_actions(
return self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:

View File

@ -38,9 +38,12 @@ class EventSources(object):
name: cls(hs)
for name, cls in EventSources.SOURCE_TYPES.items()
}
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_current_token(self, direction='f'):
push_rules_key, _ = self.store.get_push_rules_stream_token()
token = StreamToken(
room_key=(
yield self.sources["room"].get_current_key(direction)
@ -57,5 +60,6 @@ class EventSources(object):
account_data_key=(
yield self.sources["account_data"].get_current_key()
),
push_rules_key=push_rules_key,
)
defer.returnValue(token)

View File

@ -115,6 +115,7 @@ class StreamToken(
"typing_key",
"receipt_key",
"account_data_key",
"push_rules_key",
))
):
_SEPARATOR = "_"
@ -150,6 +151,7 @@ class StreamToken(
or (int(other.typing_key) < int(self.typing_key))
or (int(other.receipt_key) < int(self.receipt_key))
or (int(other.account_data_key) < int(self.account_data_key))
or (int(other.push_rules_key) < int(self.push_rules_key))
)
def copy_and_advance(self, key, new_value):
@ -174,6 +176,11 @@ class StreamToken(
return StreamToken(**d)
StreamToken.START = StreamToken(
*(["s0"] + ["0"] * (len(StreamToken._fields) - 1))
)
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
"""Tokens are positions between events. The token "s1" comes after event 1.