Add caches to bulk_get_push_rules*

This commit is contained in:
Erik Johnston 2016-05-31 13:46:58 +01:00
parent b007ee4606
commit e5b0bbcd33
2 changed files with 15 additions and 9 deletions

View File

@ -29,6 +29,7 @@ logger = logging.getLogger(__name__)
def decode_rule_json(rule): def decode_rule_json(rule):
rule = dict(rule)
rule['conditions'] = json.loads(rule['conditions']) rule['conditions'] = json.loads(rule['conditions'])
rule['actions'] = json.loads(rule['actions']) rule['actions'] = json.loads(rule['actions'])
return rule return rule
@ -39,6 +40,8 @@ def _get_rules(room_id, user_ids, store):
rules_by_user = yield store.bulk_get_push_rules(user_ids) rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids) rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
rules_by_user = { rules_by_user = {
uid: list_with_base_rules([ uid: list_with_base_rules([
decode_rule_json(rule_list) decode_rule_json(rule_list)
@ -51,11 +54,10 @@ def _get_rules(room_id, user_ids, store):
# fetch disabled rules, but this won't account for any server default # fetch disabled rules, but this won't account for any server default
# rules the user has disabled, so we need to do this too. # rules the user has disabled, so we need to do this too.
for uid in user_ids: for uid in user_ids:
if uid not in rules_enabled_by_user: user_enabled_map = rules_enabled_by_user.get(uid)
if not user_enabled_map:
continue continue
user_enabled_map = rules_enabled_by_user[uid]
for i, rule in enumerate(rules_by_user[uid]): for i, rule in enumerate(rules_by_user[uid]):
rule_id = rule['rule_id'] rule_id = rule['rule_id']

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks() @cachedInlineCallbacks(lru=True)
def get_push_rules_for_user(self, user_id): def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="push_rules", table="push_rules",
@ -44,7 +44,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows) defer.returnValue(rows)
@cachedInlineCallbacks() @cachedInlineCallbacks(lru=True)
def get_push_rules_enabled_for_user(self, user_id): def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list( results = yield self._simple_select_list(
table="push_rules_enable", table="push_rules_enable",
@ -60,7 +60,8 @@ class PushRuleStore(SQLBaseStore):
r['rule_id']: False if r['enabled'] == 0 else True for r in results r['rule_id']: False if r['enabled'] == 0 else True for r in results
}) })
@defer.inlineCallbacks @cachedList(cached_method_name="get_push_rules_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules(self, user_ids): def bulk_get_push_rules(self, user_ids):
if not user_ids: if not user_ids:
defer.returnValue({}) defer.returnValue({})
@ -75,13 +76,16 @@ class PushRuleStore(SQLBaseStore):
desc="bulk_get_push_rules", desc="bulk_get_push_rules",
) )
rows.sort(key=lambda e: (-e["priority_class"], -e["priority"])) rows.sort(
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
)
for row in rows: for row in rows:
results.setdefault(row['user_name'], []).append(row) results.setdefault(row['user_name'], []).append(row)
defer.returnValue(results) defer.returnValue(results)
@defer.inlineCallbacks @cachedList(cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules_enabled(self, user_ids): def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids: if not user_ids:
defer.returnValue({}) defer.returnValue({})