Hook up the push rules to the notifier
This commit is contained in:
parent
2223204eba
commit
ddf9e7b302
|
@ -647,8 +647,8 @@ class MessageHandler(BaseHandler):
|
||||||
user_id, messages, is_peeking=is_peeking
|
user_id, messages, is_peeking=is_peeking
|
||||||
)
|
)
|
||||||
|
|
||||||
start_token = StreamToken(token[0], 0, 0, 0, 0)
|
start_token = StreamToken.START.copy_and_replace("room_key", token[0])
|
||||||
end_token = StreamToken(token[1], 0, 0, 0, 0)
|
end_token = StreamToken.START.copy_and_replace("room_key", token[1])
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
|
|
|
@ -284,7 +284,7 @@ class Notifier(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def wait_for_events(self, user_id, timeout, callback, room_ids=None,
|
def wait_for_events(self, user_id, timeout, callback, room_ids=None,
|
||||||
from_token=StreamToken("s0", "0", "0", "0", "0")):
|
from_token=StreamToken.START):
|
||||||
"""Wait until the callback returns a non empty response or the
|
"""Wait until the callback returns a non empty response or the
|
||||||
timeout fires.
|
timeout fires.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -36,6 +36,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
|
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
|
||||||
"Unrecognised request: You probably wanted a trailing slash")
|
"Unrecognised request: You probably wanted a trailing slash")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(PushRuleRestServlet, self).__init__(hs)
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.notifier = hs.get_notifier()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request):
|
def on_PUT(self, request):
|
||||||
spec = _rule_spec_from_path(request.postpath)
|
spec = _rule_spec_from_path(request.postpath)
|
||||||
|
@ -51,8 +56,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
if 'attr' in spec:
|
if 'attr' in spec:
|
||||||
yield self.set_rule_attr(requester.user.to_string(), spec, content)
|
yield self.set_rule_attr(user_id, spec, content)
|
||||||
|
self.notify_user(user_id)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
if spec['rule_id'].startswith('.'):
|
if spec['rule_id'].startswith('.'):
|
||||||
|
@ -77,8 +85,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
after = _namespaced_rule_id(spec, after[0])
|
after = _namespaced_rule_id(spec, after[0])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.hs.get_datastore().add_push_rule(
|
yield self.store.add_push_rule(
|
||||||
user_id=requester.user.to_string(),
|
user_id=user_id,
|
||||||
rule_id=_namespaced_rule_id_from_spec(spec),
|
rule_id=_namespaced_rule_id_from_spec(spec),
|
||||||
priority_class=priority_class,
|
priority_class=priority_class,
|
||||||
conditions=conditions,
|
conditions=conditions,
|
||||||
|
@ -86,6 +94,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
before=before,
|
before=before,
|
||||||
after=after
|
after=after
|
||||||
)
|
)
|
||||||
|
self.notify_user(user_id)
|
||||||
except InconsistentRuleException as e:
|
except InconsistentRuleException as e:
|
||||||
raise SynapseError(400, e.message)
|
raise SynapseError(400, e.message)
|
||||||
except RuleNotFoundException as e:
|
except RuleNotFoundException as e:
|
||||||
|
@ -98,13 +107,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
spec = _rule_spec_from_path(request.postpath)
|
spec = _rule_spec_from_path(request.postpath)
|
||||||
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.hs.get_datastore().delete_push_rule(
|
yield self.store.delete_push_rule(
|
||||||
requester.user.to_string(), namespaced_rule_id
|
user_id, namespaced_rule_id
|
||||||
)
|
)
|
||||||
|
self.notify_user(user_id)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
|
@ -115,14 +126,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user = requester.user
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
# we build up the full structure and then decide which bits of it
|
# we build up the full structure and then decide which bits of it
|
||||||
# to send which means doing unnecessary work sometimes but is
|
# to send which means doing unnecessary work sometimes but is
|
||||||
# is probably not going to make a whole lot of difference
|
# is probably not going to make a whole lot of difference
|
||||||
rawrules = yield self.hs.get_datastore().get_push_rules_for_user(
|
rawrules = yield self.store.get_push_rules_for_user(user_id)
|
||||||
user.to_string()
|
|
||||||
)
|
|
||||||
|
|
||||||
ruleslist = []
|
ruleslist = []
|
||||||
for rawrule in rawrules:
|
for rawrule in rawrules:
|
||||||
|
@ -138,8 +147,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
|
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
|
||||||
|
|
||||||
enabled_map = yield self.hs.get_datastore().\
|
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
|
||||||
get_push_rules_enabled_for_user(user.to_string())
|
|
||||||
|
|
||||||
for r in ruleslist:
|
for r in ruleslist:
|
||||||
rulearray = None
|
rulearray = None
|
||||||
|
@ -152,9 +160,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
pattern_type = c.pop("pattern_type", None)
|
pattern_type = c.pop("pattern_type", None)
|
||||||
if pattern_type == "user_id":
|
if pattern_type == "user_id":
|
||||||
c["pattern"] = user.to_string()
|
c["pattern"] = user_id
|
||||||
elif pattern_type == "user_localpart":
|
elif pattern_type == "user_localpart":
|
||||||
c["pattern"] = user.localpart
|
c["pattern"] = requester.user.localpart
|
||||||
|
|
||||||
rulearray = rules['global'][template_name]
|
rulearray = rules['global'][template_name]
|
||||||
|
|
||||||
|
@ -188,6 +196,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
def on_OPTIONS(self, _):
|
def on_OPTIONS(self, _):
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
|
def notify_user(self, user_id):
|
||||||
|
stream_id = self.store.get_push_rules_stream_token()
|
||||||
|
self.notifier.on_new_event(
|
||||||
|
"push_rules_key", stream_id, users=[user_id]
|
||||||
|
)
|
||||||
|
|
||||||
def set_rule_attr(self, user_id, spec, val):
|
def set_rule_attr(self, user_id, spec, val):
|
||||||
if spec['attr'] == 'enabled':
|
if spec['attr'] == 'enabled':
|
||||||
if isinstance(val, dict) and "enabled" in val:
|
if isinstance(val, dict) and "enabled" in val:
|
||||||
|
@ -198,7 +212,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
# bools directly, so let's not break them.
|
# bools directly, so let's not break them.
|
||||||
raise SynapseError(400, "Value for 'enabled' must be boolean")
|
raise SynapseError(400, "Value for 'enabled' must be boolean")
|
||||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||||
return self.hs.get_datastore().set_push_rule_enabled(
|
return self.store.set_push_rule_enabled(
|
||||||
user_id, namespaced_rule_id, val
|
user_id, namespaced_rule_id, val
|
||||||
)
|
)
|
||||||
elif spec['attr'] == 'actions':
|
elif spec['attr'] == 'actions':
|
||||||
|
@ -210,7 +224,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
if is_default_rule:
|
if is_default_rule:
|
||||||
if namespaced_rule_id not in BASE_RULE_IDS:
|
if namespaced_rule_id not in BASE_RULE_IDS:
|
||||||
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
|
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
|
||||||
return self.hs.get_datastore().set_push_rule_actions(
|
return self.store.set_push_rule_actions(
|
||||||
user_id, namespaced_rule_id, actions, is_default_rule
|
user_id, namespaced_rule_id, actions, is_default_rule
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -38,9 +38,12 @@ class EventSources(object):
|
||||||
name: cls(hs)
|
name: cls(hs)
|
||||||
for name, cls in EventSources.SOURCE_TYPES.items()
|
for name, cls in EventSources.SOURCE_TYPES.items()
|
||||||
}
|
}
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_current_token(self, direction='f'):
|
def get_current_token(self, direction='f'):
|
||||||
|
push_rules_key, _ = self.store.get_push_rules_stream_token()
|
||||||
|
|
||||||
token = StreamToken(
|
token = StreamToken(
|
||||||
room_key=(
|
room_key=(
|
||||||
yield self.sources["room"].get_current_key(direction)
|
yield self.sources["room"].get_current_key(direction)
|
||||||
|
@ -57,5 +60,6 @@ class EventSources(object):
|
||||||
account_data_key=(
|
account_data_key=(
|
||||||
yield self.sources["account_data"].get_current_key()
|
yield self.sources["account_data"].get_current_key()
|
||||||
),
|
),
|
||||||
|
push_rules_key=push_rules_key,
|
||||||
)
|
)
|
||||||
defer.returnValue(token)
|
defer.returnValue(token)
|
||||||
|
|
|
@ -115,6 +115,7 @@ class StreamToken(
|
||||||
"typing_key",
|
"typing_key",
|
||||||
"receipt_key",
|
"receipt_key",
|
||||||
"account_data_key",
|
"account_data_key",
|
||||||
|
"push_rules_key",
|
||||||
))
|
))
|
||||||
):
|
):
|
||||||
_SEPARATOR = "_"
|
_SEPARATOR = "_"
|
||||||
|
@ -150,6 +151,7 @@ class StreamToken(
|
||||||
or (int(other.typing_key) < int(self.typing_key))
|
or (int(other.typing_key) < int(self.typing_key))
|
||||||
or (int(other.receipt_key) < int(self.receipt_key))
|
or (int(other.receipt_key) < int(self.receipt_key))
|
||||||
or (int(other.account_data_key) < int(self.account_data_key))
|
or (int(other.account_data_key) < int(self.account_data_key))
|
||||||
|
or (int(other.push_rules_key) < int(self.push_rules_key))
|
||||||
)
|
)
|
||||||
|
|
||||||
def copy_and_advance(self, key, new_value):
|
def copy_and_advance(self, key, new_value):
|
||||||
|
@ -174,6 +176,11 @@ class StreamToken(
|
||||||
return StreamToken(**d)
|
return StreamToken(**d)
|
||||||
|
|
||||||
|
|
||||||
|
StreamToken.START = StreamToken(
|
||||||
|
*(["s0"] + ["0"] * (len(StreamToken._fields) - 1))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
||||||
"""Tokens are positions between events. The token "s1" comes after event 1.
|
"""Tokens are positions between events. The token "s1" comes after event 1.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue