Merge pull request #825 from matrix-org/erikj/cache_push_rules
Load push rules in storage layer so that they get cached
This commit is contained in:
commit
722472b48c
|
@ -198,9 +198,8 @@ class SyncHandler(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def push_rules_for_user(self, user):
|
def push_rules_for_user(self, user):
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
rawrules = yield self.store.get_push_rules_for_user(user_id)
|
rules = yield self.store.get_push_rules_for_user(user_id)
|
||||||
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
|
rules = format_push_rules_for_user(user, rules)
|
||||||
rules = format_push_rules_for_user(user, rawrules, enabled_map)
|
|
||||||
defer.returnValue(rules)
|
defer.returnValue(rules)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -18,7 +18,6 @@ import ujson as json
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from .baserules import list_with_base_rules
|
|
||||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
|
@ -38,36 +37,9 @@ def decode_rule_json(rule):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_rules(room_id, user_ids, store):
|
def _get_rules(room_id, user_ids, store):
|
||||||
rules_by_user = yield store.bulk_get_push_rules(user_ids)
|
rules_by_user = yield store.bulk_get_push_rules(user_ids)
|
||||||
rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
|
|
||||||
|
|
||||||
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
|
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
|
||||||
|
|
||||||
rules_by_user = {
|
|
||||||
uid: list_with_base_rules([
|
|
||||||
decode_rule_json(rule_list)
|
|
||||||
for rule_list in rules_by_user.get(uid, [])
|
|
||||||
])
|
|
||||||
for uid in user_ids
|
|
||||||
}
|
|
||||||
|
|
||||||
# We apply the rules-enabled map here: bulk_get_push_rules doesn't
|
|
||||||
# fetch disabled rules, but this won't account for any server default
|
|
||||||
# rules the user has disabled, so we need to do this too.
|
|
||||||
for uid in user_ids:
|
|
||||||
user_enabled_map = rules_enabled_by_user.get(uid)
|
|
||||||
if not user_enabled_map:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for i, rule in enumerate(rules_by_user[uid]):
|
|
||||||
rule_id = rule['rule_id']
|
|
||||||
|
|
||||||
if rule_id in user_enabled_map:
|
|
||||||
if rule.get('enabled', True) != bool(user_enabled_map[rule_id]):
|
|
||||||
# Rules are cached across users.
|
|
||||||
rule = dict(rule)
|
|
||||||
rule['enabled'] = bool(user_enabled_map[rule_id])
|
|
||||||
rules_by_user[uid][i] = rule
|
|
||||||
|
|
||||||
defer.returnValue(rules_by_user)
|
defer.returnValue(rules_by_user)
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,24 +51,26 @@ def evaluator_for_event(event, hs, store, current_state):
|
||||||
# generating them for bot / AS users etc, we only do so for people who've
|
# generating them for bot / AS users etc, we only do so for people who've
|
||||||
# sent a read receipt into the room.
|
# sent a read receipt into the room.
|
||||||
|
|
||||||
all_in_room = set(
|
local_users_in_room = set(
|
||||||
e.state_key for e in current_state.values()
|
e.state_key for e in current_state.values()
|
||||||
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
||||||
|
and hs.is_mine_id(e.state_key)
|
||||||
)
|
)
|
||||||
|
|
||||||
# users in the room who have pushers need to get push rules run because
|
# users in the room who have pushers need to get push rules run because
|
||||||
# that's how their pushers work
|
# that's how their pushers work
|
||||||
if_users_with_pushers = yield store.get_if_users_have_pushers(all_in_room)
|
if_users_with_pushers = yield store.get_if_users_have_pushers(
|
||||||
users_with_pushers = set(
|
local_users_in_room
|
||||||
|
)
|
||||||
|
user_ids = set(
|
||||||
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
||||||
)
|
)
|
||||||
|
|
||||||
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
|
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
|
||||||
|
|
||||||
# any users with pushers must be ours: they have pushers
|
# any users with pushers must be ours: they have pushers
|
||||||
user_ids = set(users_with_pushers)
|
|
||||||
for uid in users_with_receipts:
|
for uid in users_with_receipts:
|
||||||
if hs.is_mine_id(uid) and uid in all_in_room:
|
if uid in local_users_in_room:
|
||||||
user_ids.add(uid)
|
user_ids.add(uid)
|
||||||
|
|
||||||
# if this event is an invite event, we may need to run rules for the user
|
# if this event is an invite event, we may need to run rules for the user
|
||||||
|
@ -108,8 +82,6 @@ def evaluator_for_event(event, hs, store, current_state):
|
||||||
if has_pusher:
|
if has_pusher:
|
||||||
user_ids.add(invited_user)
|
user_ids.add(invited_user)
|
||||||
|
|
||||||
user_ids = list(user_ids)
|
|
||||||
|
|
||||||
rules_by_user = yield _get_rules(room_id, user_ids, store)
|
rules_by_user = yield _get_rules(room_id, user_ids, store)
|
||||||
|
|
||||||
defer.returnValue(BulkPushRuleEvaluator(
|
defer.returnValue(BulkPushRuleEvaluator(
|
||||||
|
|
|
@ -23,10 +23,7 @@ import copy
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
|
|
||||||
|
|
||||||
def format_push_rules_for_user(user, rawrules, enabled_map):
|
def load_rules_for_user(user, rawrules, enabled_map):
|
||||||
"""Converts a list of rawrules and a enabled map into nested dictionaries
|
|
||||||
to match the Matrix client-server format for push rules"""
|
|
||||||
|
|
||||||
ruleslist = []
|
ruleslist = []
|
||||||
for rawrule in rawrules:
|
for rawrule in rawrules:
|
||||||
rule = dict(rawrule)
|
rule = dict(rawrule)
|
||||||
|
@ -35,7 +32,26 @@ def format_push_rules_for_user(user, rawrules, enabled_map):
|
||||||
ruleslist.append(rule)
|
ruleslist.append(rule)
|
||||||
|
|
||||||
# We're going to be mutating this a lot, so do a deep copy
|
# We're going to be mutating this a lot, so do a deep copy
|
||||||
ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
|
rules = list(list_with_base_rules(ruleslist))
|
||||||
|
|
||||||
|
for i, rule in enumerate(rules):
|
||||||
|
rule_id = rule['rule_id']
|
||||||
|
if rule_id in enabled_map:
|
||||||
|
if rule.get('enabled', True) != bool(enabled_map[rule_id]):
|
||||||
|
# Rules are cached across users.
|
||||||
|
rule = dict(rule)
|
||||||
|
rule['enabled'] = bool(enabled_map[rule_id])
|
||||||
|
rules[i] = rule
|
||||||
|
|
||||||
|
return rules
|
||||||
|
|
||||||
|
|
||||||
|
def format_push_rules_for_user(user, ruleslist):
|
||||||
|
"""Converts a list of rawrules and a enabled map into nested dictionaries
|
||||||
|
to match the Matrix client-server format for push rules"""
|
||||||
|
|
||||||
|
# We're going to be mutating this a lot, so do a deep copy
|
||||||
|
ruleslist = copy.deepcopy(ruleslist)
|
||||||
|
|
||||||
rules = {'global': {}, 'device': {}}
|
rules = {'global': {}, 'device': {}}
|
||||||
|
|
||||||
|
@ -60,9 +76,7 @@ def format_push_rules_for_user(user, rawrules, enabled_map):
|
||||||
|
|
||||||
template_rule = _rule_to_template(r)
|
template_rule = _rule_to_template(r)
|
||||||
if template_rule:
|
if template_rule:
|
||||||
if r['rule_id'] in enabled_map:
|
if 'enabled' in r:
|
||||||
template_rule['enabled'] = enabled_map[r['rule_id']]
|
|
||||||
elif 'enabled' in r:
|
|
||||||
template_rule['enabled'] = r['enabled']
|
template_rule['enabled'] = r['enabled']
|
||||||
else:
|
else:
|
||||||
template_rule['enabled'] = True
|
template_rule['enabled'] = True
|
||||||
|
|
|
@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
# we build up the full structure and then decide which bits of it
|
# we build up the full structure and then decide which bits of it
|
||||||
# to send which means doing unnecessary work sometimes but is
|
# to send which means doing unnecessary work sometimes but is
|
||||||
# is probably not going to make a whole lot of difference
|
# is probably not going to make a whole lot of difference
|
||||||
rawrules = yield self.store.get_push_rules_for_user(user_id)
|
rules = yield self.store.get_push_rules_for_user(user_id)
|
||||||
|
|
||||||
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
|
rules = format_push_rules_for_user(requester.user, rules)
|
||||||
|
|
||||||
rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
|
|
||||||
|
|
||||||
path = request.postpath[1:]
|
path = request.postpath[1:]
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
|
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
|
||||||
|
from synapse.push.baserules import list_with_base_rules
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -23,6 +24,29 @@ import simplejson as json
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_rules(rawrules, enabled_map):
|
||||||
|
ruleslist = []
|
||||||
|
for rawrule in rawrules:
|
||||||
|
rule = dict(rawrule)
|
||||||
|
rule["conditions"] = json.loads(rawrule["conditions"])
|
||||||
|
rule["actions"] = json.loads(rawrule["actions"])
|
||||||
|
ruleslist.append(rule)
|
||||||
|
|
||||||
|
# We're going to be mutating this a lot, so do a deep copy
|
||||||
|
rules = list(list_with_base_rules(ruleslist))
|
||||||
|
|
||||||
|
for i, rule in enumerate(rules):
|
||||||
|
rule_id = rule['rule_id']
|
||||||
|
if rule_id in enabled_map:
|
||||||
|
if rule.get('enabled', True) != bool(enabled_map[rule_id]):
|
||||||
|
# Rules are cached across users.
|
||||||
|
rule = dict(rule)
|
||||||
|
rule['enabled'] = bool(enabled_map[rule_id])
|
||||||
|
rules[i] = rule
|
||||||
|
|
||||||
|
return rules
|
||||||
|
|
||||||
|
|
||||||
class PushRuleStore(SQLBaseStore):
|
class PushRuleStore(SQLBaseStore):
|
||||||
@cachedInlineCallbacks(lru=True)
|
@cachedInlineCallbacks(lru=True)
|
||||||
def get_push_rules_for_user(self, user_id):
|
def get_push_rules_for_user(self, user_id):
|
||||||
|
@ -42,7 +66,11 @@ class PushRuleStore(SQLBaseStore):
|
||||||
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
|
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(rows)
|
enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
|
||||||
|
|
||||||
|
rules = _load_rules(rows, enabled_map)
|
||||||
|
|
||||||
|
defer.returnValue(rules)
|
||||||
|
|
||||||
@cachedInlineCallbacks(lru=True)
|
@cachedInlineCallbacks(lru=True)
|
||||||
def get_push_rules_enabled_for_user(self, user_id):
|
def get_push_rules_enabled_for_user(self, user_id):
|
||||||
|
@ -85,6 +113,14 @@ class PushRuleStore(SQLBaseStore):
|
||||||
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
results.setdefault(row['user_name'], []).append(row)
|
results.setdefault(row['user_name'], []).append(row)
|
||||||
|
|
||||||
|
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
|
||||||
|
|
||||||
|
for user_id, rules in results.items():
|
||||||
|
results[user_id] = _load_rules(
|
||||||
|
rules, enabled_map_by_user.get(user_id, {})
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
@cachedList(cached_method_name="get_push_rules_enabled_for_user",
|
@cachedList(cached_method_name="get_push_rules_enabled_for_user",
|
||||||
|
|
|
@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore):
|
||||||
"get_all_updated_pushers", get_all_updated_pushers_txn
|
"get_all_updated_pushers", get_all_updated_pushers_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(lru=True, num_args=1)
|
@cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
|
||||||
def get_if_user_has_pusher(self, user_id):
|
def get_if_user_has_pusher(self, user_id):
|
||||||
result = yield self._simple_select_many_batch(
|
result = yield self._simple_select_many_batch(
|
||||||
table='pushers',
|
table='pushers',
|
||||||
|
|
Loading…
Reference in New Issue