Add type hints to the push module. (#8901)
This commit is contained in:
parent
a8eceb01e5
commit
5d34f40d49
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to push module.
|
7
mypy.ini
7
mypy.ini
|
@ -56,12 +56,7 @@ files =
|
||||||
synapse/metrics,
|
synapse/metrics,
|
||||||
synapse/module_api,
|
synapse/module_api,
|
||||||
synapse/notifier.py,
|
synapse/notifier.py,
|
||||||
synapse/push/emailpusher.py,
|
synapse/push,
|
||||||
synapse/push/httppusher.py,
|
|
||||||
synapse/push/mailer.py,
|
|
||||||
synapse/push/pusher.py,
|
|
||||||
synapse/push/pusherpool.py,
|
|
||||||
synapse/push/push_rule_evaluator.py,
|
|
||||||
synapse/replication,
|
synapse/replication,
|
||||||
synapse/rest,
|
synapse/rest,
|
||||||
synapse/server.py,
|
synapse/server.py,
|
||||||
|
|
|
@ -31,6 +31,8 @@ class SynapsePlugin(Plugin):
|
||||||
) -> Optional[Callable[[MethodSigContext], CallableType]]:
|
) -> Optional[Callable[[MethodSigContext], CallableType]]:
|
||||||
if fullname.startswith(
|
if fullname.startswith(
|
||||||
"synapse.util.caches.descriptors._CachedFunction.__call__"
|
"synapse.util.caches.descriptors._CachedFunction.__call__"
|
||||||
|
) or fullname.startswith(
|
||||||
|
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
|
||||||
):
|
):
|
||||||
return cached_function_method_signature
|
return cached_function_method_signature
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -14,19 +14,22 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from synapse.events import EventBase
|
||||||
|
from synapse.events.snapshot import EventContext
|
||||||
|
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ActionGenerator:
|
class ActionGenerator:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.store = hs.get_datastore()
|
|
||||||
self.bulk_evaluator = BulkPushRuleEvaluator(hs)
|
self.bulk_evaluator = BulkPushRuleEvaluator(hs)
|
||||||
# really we want to get all user ids and all profile tags too,
|
# really we want to get all user ids and all profile tags too,
|
||||||
# since we want the actions for each profile tag for every user and
|
# since we want the actions for each profile tag for every user and
|
||||||
|
@ -35,6 +38,8 @@ class ActionGenerator:
|
||||||
# event stream, so we just run the rules for a client with no profile
|
# event stream, so we just run the rules for a client with no profile
|
||||||
# tag (ie. we just need all the users).
|
# tag (ie. we just need all the users).
|
||||||
|
|
||||||
async def handle_push_actions_for_event(self, event, context):
|
async def handle_push_actions_for_event(
|
||||||
|
self, event: EventBase, context: EventContext
|
||||||
|
) -> None:
|
||||||
with Measure(self.clock, "action_for_event_by_user"):
|
with Measure(self.clock, "action_for_event_by_user"):
|
||||||
await self.bulk_evaluator.action_for_event_by_user(event, context)
|
await self.bulk_evaluator.action_for_event_by_user(event, context)
|
||||||
|
|
|
@ -15,16 +15,19 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
|
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
|
||||||
|
|
||||||
|
|
||||||
def list_with_base_rules(rawrules, use_new_defaults=False):
|
def list_with_base_rules(
|
||||||
|
rawrules: List[Dict[str, Any]], use_new_defaults: bool = False
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""Combine the list of rules set by the user with the default push rules
|
"""Combine the list of rules set by the user with the default push rules
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
rawrules(list): The rules the user has modified or set.
|
rawrules: The rules the user has modified or set.
|
||||||
use_new_defaults(bool): Whether to use the new experimental default rules when
|
use_new_defaults: Whether to use the new experimental default rules when
|
||||||
appending or prepending default rules.
|
appending or prepending default rules.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -94,7 +97,11 @@ def list_with_base_rules(rawrules, use_new_defaults=False):
|
||||||
return ruleslist
|
return ruleslist
|
||||||
|
|
||||||
|
|
||||||
def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
|
def make_base_append_rules(
|
||||||
|
kind: str,
|
||||||
|
modified_base_rules: Dict[str, Dict[str, Any]],
|
||||||
|
use_new_defaults: bool = False,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
rules = []
|
rules = []
|
||||||
|
|
||||||
if kind == "override":
|
if kind == "override":
|
||||||
|
@ -116,6 +123,7 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
|
||||||
rules = copy.deepcopy(rules)
|
rules = copy.deepcopy(rules)
|
||||||
for r in rules:
|
for r in rules:
|
||||||
# Only modify the actions, keep the conditions the same.
|
# Only modify the actions, keep the conditions the same.
|
||||||
|
assert isinstance(r["rule_id"], str)
|
||||||
modified = modified_base_rules.get(r["rule_id"])
|
modified = modified_base_rules.get(r["rule_id"])
|
||||||
if modified:
|
if modified:
|
||||||
r["actions"] = modified["actions"]
|
r["actions"] = modified["actions"]
|
||||||
|
@ -123,7 +131,11 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
|
||||||
return rules
|
return rules
|
||||||
|
|
||||||
|
|
||||||
def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
|
def make_base_prepend_rules(
|
||||||
|
kind: str,
|
||||||
|
modified_base_rules: Dict[str, Dict[str, Any]],
|
||||||
|
use_new_defaults: bool = False,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
rules = []
|
rules = []
|
||||||
|
|
||||||
if kind == "override":
|
if kind == "override":
|
||||||
|
@ -133,6 +145,7 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
|
||||||
rules = copy.deepcopy(rules)
|
rules = copy.deepcopy(rules)
|
||||||
for r in rules:
|
for r in rules:
|
||||||
# Only modify the actions, keep the conditions the same.
|
# Only modify the actions, keep the conditions the same.
|
||||||
|
assert isinstance(r["rule_id"], str)
|
||||||
modified = modified_base_rules.get(r["rule_id"])
|
modified = modified_base_rules.get(r["rule_id"])
|
||||||
if modified:
|
if modified:
|
||||||
r["actions"] = modified["actions"]
|
r["actions"] = modified["actions"]
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
@ -25,18 +26,18 @@ from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.state import POWER_KEY
|
from synapse.state import POWER_KEY
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches import register_cache
|
from synapse.util.caches import CacheMetric, register_cache
|
||||||
from synapse.util.caches.descriptors import lru_cache
|
from synapse.util.caches.descriptors import lru_cache
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
rules_by_room = {}
|
|
||||||
|
|
||||||
|
|
||||||
push_rules_invalidation_counter = Counter(
|
push_rules_invalidation_counter = Counter(
|
||||||
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
|
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
|
||||||
)
|
)
|
||||||
|
@ -101,7 +102,7 @@ class BulkPushRuleEvaluator:
|
||||||
room at once.
|
room at once.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
@ -113,7 +114,9 @@ class BulkPushRuleEvaluator:
|
||||||
resizable=False,
|
resizable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_rules_for_event(self, event, context):
|
async def _get_rules_for_event(
|
||||||
|
self, event: EventBase, context: EventContext
|
||||||
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""This gets the rules for all users in the room at the time of the event,
|
"""This gets the rules for all users in the room at the time of the event,
|
||||||
as well as the push rules for the invitee if the event is an invite.
|
as well as the push rules for the invitee if the event is an invite.
|
||||||
|
|
||||||
|
@ -140,11 +143,8 @@ class BulkPushRuleEvaluator:
|
||||||
return rules_by_user
|
return rules_by_user
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def _get_rules_for_room(self, room_id):
|
def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
|
||||||
"""Get the current RulesForRoom object for the given room id
|
"""Get the current RulesForRoom object for the given room id
|
||||||
|
|
||||||
Returns:
|
|
||||||
RulesForRoom
|
|
||||||
"""
|
"""
|
||||||
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
|
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
|
||||||
# before any lookup methods get called on it as otherwise there may be
|
# before any lookup methods get called on it as otherwise there may be
|
||||||
|
@ -156,20 +156,21 @@ class BulkPushRuleEvaluator:
|
||||||
self.room_push_rule_cache_metrics,
|
self.room_push_rule_cache_metrics,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_power_levels_and_sender_level(self, event, context):
|
async def _get_power_levels_and_sender_level(
|
||||||
|
self, event: EventBase, context: EventContext
|
||||||
|
) -> Tuple[dict, int]:
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
pl_event_id = prev_state_ids.get(POWER_KEY)
|
pl_event_id = prev_state_ids.get(POWER_KEY)
|
||||||
if pl_event_id:
|
if pl_event_id:
|
||||||
# fastpath: if there's a power level event, that's all we need, and
|
# fastpath: if there's a power level event, that's all we need, and
|
||||||
# not having a power level event is an extreme edge case
|
# not having a power level event is an extreme edge case
|
||||||
pl_event = await self.store.get_event(pl_event_id)
|
auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
|
||||||
auth_events = {POWER_KEY: pl_event}
|
|
||||||
else:
|
else:
|
||||||
auth_events_ids = self.auth.compute_auth_events(
|
auth_events_ids = self.auth.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=False
|
event, prev_state_ids, for_verification=False
|
||||||
)
|
)
|
||||||
auth_events = await self.store.get_events(auth_events_ids)
|
auth_events_dict = await self.store.get_events(auth_events_ids)
|
||||||
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
|
||||||
|
|
||||||
sender_level = get_user_power_level(event.sender, auth_events)
|
sender_level = get_user_power_level(event.sender, auth_events)
|
||||||
|
|
||||||
|
@ -177,7 +178,9 @@ class BulkPushRuleEvaluator:
|
||||||
|
|
||||||
return pl_event.content if pl_event else {}, sender_level
|
return pl_event.content if pl_event else {}, sender_level
|
||||||
|
|
||||||
async def action_for_event_by_user(self, event, context) -> None:
|
async def action_for_event_by_user(
|
||||||
|
self, event: EventBase, context: EventContext
|
||||||
|
) -> None:
|
||||||
"""Given an event and context, evaluate the push rules, check if the message
|
"""Given an event and context, evaluate the push rules, check if the message
|
||||||
should increment the unread count, and insert the results into the
|
should increment the unread count, and insert the results into the
|
||||||
event_push_actions_staging table.
|
event_push_actions_staging table.
|
||||||
|
@ -185,7 +188,7 @@ class BulkPushRuleEvaluator:
|
||||||
count_as_unread = _should_count_as_unread(event, context)
|
count_as_unread = _should_count_as_unread(event, context)
|
||||||
|
|
||||||
rules_by_user = await self._get_rules_for_event(event, context)
|
rules_by_user = await self._get_rules_for_event(event, context)
|
||||||
actions_by_user = {}
|
actions_by_user = {} # type: Dict[str, List[Union[dict, str]]]
|
||||||
|
|
||||||
room_members = await self.store.get_joined_users_from_context(event, context)
|
room_members = await self.store.get_joined_users_from_context(event, context)
|
||||||
|
|
||||||
|
@ -198,7 +201,7 @@ class BulkPushRuleEvaluator:
|
||||||
event, len(room_members), sender_power_level, power_levels
|
event, len(room_members), sender_power_level, power_levels
|
||||||
)
|
)
|
||||||
|
|
||||||
condition_cache = {}
|
condition_cache = {} # type: Dict[str, bool]
|
||||||
|
|
||||||
for uid, rules in rules_by_user.items():
|
for uid, rules in rules_by_user.items():
|
||||||
if event.sender == uid:
|
if event.sender == uid:
|
||||||
|
@ -249,7 +252,13 @@ class BulkPushRuleEvaluator:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _condition_checker(evaluator, conditions, uid, display_name, cache):
|
def _condition_checker(
|
||||||
|
evaluator: PushRuleEvaluatorForEvent,
|
||||||
|
conditions: List[dict],
|
||||||
|
uid: str,
|
||||||
|
display_name: str,
|
||||||
|
cache: Dict[str, bool],
|
||||||
|
) -> bool:
|
||||||
for cond in conditions:
|
for cond in conditions:
|
||||||
_id = cond.get("_id", None)
|
_id = cond.get("_id", None)
|
||||||
if _id:
|
if _id:
|
||||||
|
@ -277,15 +286,19 @@ class RulesForRoom:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
|
self,
|
||||||
|
hs: "HomeServer",
|
||||||
|
room_id: str,
|
||||||
|
rules_for_room_cache: LruCache,
|
||||||
|
room_push_rule_cache_metrics: CacheMetric,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hs (HomeServer)
|
hs: The HomeServer object.
|
||||||
room_id (str)
|
room_id: The room ID.
|
||||||
rules_for_room_cache: The cache object that caches these
|
rules_for_room_cache: The cache object that caches these
|
||||||
RoomsForUser objects.
|
RoomsForUser objects.
|
||||||
room_push_rule_cache_metrics (CacheMetric)
|
room_push_rule_cache_metrics: The metrics object
|
||||||
"""
|
"""
|
||||||
self.room_id = room_id
|
self.room_id = room_id
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
@ -294,8 +307,10 @@ class RulesForRoom:
|
||||||
|
|
||||||
self.linearizer = Linearizer(name="rules_for_room")
|
self.linearizer = Linearizer(name="rules_for_room")
|
||||||
|
|
||||||
self.member_map = {} # event_id -> (user_id, state)
|
# event_id -> (user_id, state)
|
||||||
self.rules_by_user = {} # user_id -> rules
|
self.member_map = {} # type: Dict[str, Tuple[str, str]]
|
||||||
|
# user_id -> rules
|
||||||
|
self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
|
||||||
|
|
||||||
# The last state group we updated the caches for. If the state_group of
|
# The last state group we updated the caches for. If the state_group of
|
||||||
# a new event comes along, we know that we can just return the cached
|
# a new event comes along, we know that we can just return the cached
|
||||||
|
@ -315,7 +330,7 @@ class RulesForRoom:
|
||||||
# calculate push for)
|
# calculate push for)
|
||||||
# These never need to be invalidated as we will never set up push for
|
# These never need to be invalidated as we will never set up push for
|
||||||
# them.
|
# them.
|
||||||
self.uninteresting_user_set = set()
|
self.uninteresting_user_set = set() # type: Set[str]
|
||||||
|
|
||||||
# We need to be clever on the invalidating caches callbacks, as
|
# We need to be clever on the invalidating caches callbacks, as
|
||||||
# otherwise the invalidation callback holds a reference to the object,
|
# otherwise the invalidation callback holds a reference to the object,
|
||||||
|
@ -325,7 +340,9 @@ class RulesForRoom:
|
||||||
# to self around in the callback.
|
# to self around in the callback.
|
||||||
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
|
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
|
||||||
|
|
||||||
async def get_rules(self, event, context):
|
async def get_rules(
|
||||||
|
self, event: EventBase, context: EventContext
|
||||||
|
) -> Dict[str, List[Dict[str, dict]]]:
|
||||||
"""Given an event context return the rules for all users who are
|
"""Given an event context return the rules for all users who are
|
||||||
currently in the room.
|
currently in the room.
|
||||||
"""
|
"""
|
||||||
|
@ -356,6 +373,8 @@ class RulesForRoom:
|
||||||
else:
|
else:
|
||||||
current_state_ids = await context.get_current_state_ids()
|
current_state_ids = await context.get_current_state_ids()
|
||||||
push_rules_delta_state_cache_metric.inc_misses()
|
push_rules_delta_state_cache_metric.inc_misses()
|
||||||
|
# Ensure the state IDs exist.
|
||||||
|
assert current_state_ids is not None
|
||||||
|
|
||||||
push_rules_state_size_counter.inc(len(current_state_ids))
|
push_rules_state_size_counter.inc(len(current_state_ids))
|
||||||
|
|
||||||
|
@ -420,18 +439,23 @@ class RulesForRoom:
|
||||||
return ret_rules_by_user
|
return ret_rules_by_user
|
||||||
|
|
||||||
async def _update_rules_with_member_event_ids(
|
async def _update_rules_with_member_event_ids(
|
||||||
self, ret_rules_by_user, member_event_ids, state_group, event
|
self,
|
||||||
):
|
ret_rules_by_user: Dict[str, list],
|
||||||
|
member_event_ids: Dict[str, str],
|
||||||
|
state_group: Optional[int],
|
||||||
|
event: EventBase,
|
||||||
|
) -> None:
|
||||||
"""Update the partially filled rules_by_user dict by fetching rules for
|
"""Update the partially filled rules_by_user dict by fetching rules for
|
||||||
any newly joined users in the `member_event_ids` list.
|
any newly joined users in the `member_event_ids` list.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
|
ret_rules_by_user: Partially filled dict of push rules. Gets
|
||||||
updated with any new rules.
|
updated with any new rules.
|
||||||
member_event_ids (dict): Dict of user id to event id for membership events
|
member_event_ids: Dict of user id to event id for membership events
|
||||||
that have happened since the last time we filled rules_by_user
|
that have happened since the last time we filled rules_by_user
|
||||||
state_group: The state group we are currently computing push rules
|
state_group: The state group we are currently computing push rules
|
||||||
for. Used when updating the cache.
|
for. Used when updating the cache.
|
||||||
|
event: The event we are currently computing push rules for.
|
||||||
"""
|
"""
|
||||||
sequence = self.sequence
|
sequence = self.sequence
|
||||||
|
|
||||||
|
@ -449,19 +473,19 @@ class RulesForRoom:
|
||||||
if logger.isEnabledFor(logging.DEBUG):
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
logger.debug("Found members %r: %r", self.room_id, members.values())
|
logger.debug("Found members %r: %r", self.room_id, members.values())
|
||||||
|
|
||||||
user_ids = {
|
joined_user_ids = {
|
||||||
user_id
|
user_id
|
||||||
for user_id, membership in members.values()
|
for user_id, membership in members.values()
|
||||||
if membership == Membership.JOIN
|
if membership == Membership.JOIN
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug("Joined: %r", user_ids)
|
logger.debug("Joined: %r", joined_user_ids)
|
||||||
|
|
||||||
# Previously we only considered users with pushers or read receipts in that
|
# Previously we only considered users with pushers or read receipts in that
|
||||||
# room. We can't do this anymore because we use push actions to calculate unread
|
# room. We can't do this anymore because we use push actions to calculate unread
|
||||||
# counts, which don't rely on the user having pushers or sent a read receipt into
|
# counts, which don't rely on the user having pushers or sent a read receipt into
|
||||||
# the room. Therefore we just need to filter for local users here.
|
# the room. Therefore we just need to filter for local users here.
|
||||||
user_ids = list(filter(self.is_mine_id, user_ids))
|
user_ids = list(filter(self.is_mine_id, joined_user_ids))
|
||||||
|
|
||||||
rules_by_user = await self.store.bulk_get_push_rules(
|
rules_by_user = await self.store.bulk_get_push_rules(
|
||||||
user_ids, on_invalidate=self.invalidate_all_cb
|
user_ids, on_invalidate=self.invalidate_all_cb
|
||||||
|
@ -473,7 +497,7 @@ class RulesForRoom:
|
||||||
|
|
||||||
self.update_cache(sequence, members, ret_rules_by_user, state_group)
|
self.update_cache(sequence, members, ret_rules_by_user, state_group)
|
||||||
|
|
||||||
def invalidate_all(self):
|
def invalidate_all(self) -> None:
|
||||||
# Note: Don't hand this function directly to an invalidation callback
|
# Note: Don't hand this function directly to an invalidation callback
|
||||||
# as it keeps a reference to self and will stop this instance from being
|
# as it keeps a reference to self and will stop this instance from being
|
||||||
# GC'd if it gets dropped from the rules_to_user cache. Instead use
|
# GC'd if it gets dropped from the rules_to_user cache. Instead use
|
||||||
|
@ -485,7 +509,7 @@ class RulesForRoom:
|
||||||
self.rules_by_user = {}
|
self.rules_by_user = {}
|
||||||
push_rules_invalidation_counter.inc()
|
push_rules_invalidation_counter.inc()
|
||||||
|
|
||||||
def update_cache(self, sequence, members, rules_by_user, state_group):
|
def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
|
||||||
if sequence == self.sequence:
|
if sequence == self.sequence:
|
||||||
self.member_map.update(members)
|
self.member_map.update(members)
|
||||||
self.rules_by_user = rules_by_user
|
self.rules_by_user = rules_by_user
|
||||||
|
@ -506,7 +530,7 @@ class _Invalidation:
|
||||||
cache = attr.ib(type=LruCache)
|
cache = attr.ib(type=LruCache)
|
||||||
room_id = attr.ib(type=str)
|
room_id = attr.ib(type=str)
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self) -> None:
|
||||||
rules = self.cache.get(self.room_id, None, update_metrics=False)
|
rules = self.cache.get(self.room_id, None, update_metrics=False)
|
||||||
if rules:
|
if rules:
|
||||||
rules.invalidate_all()
|
rules.invalidate_all()
|
||||||
|
|
|
@ -14,24 +14,27 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
|
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
|
||||||
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
|
||||||
def format_push_rules_for_user(user, ruleslist):
|
def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]:
|
||||||
"""Converts a list of rawrules and a enabled map into nested dictionaries
|
"""Converts a list of rawrules and a enabled map into nested dictionaries
|
||||||
to match the Matrix client-server format for push rules"""
|
to match the Matrix client-server format for push rules"""
|
||||||
|
|
||||||
# We're going to be mutating this a lot, so do a deep copy
|
# We're going to be mutating this a lot, so do a deep copy
|
||||||
ruleslist = copy.deepcopy(ruleslist)
|
ruleslist = copy.deepcopy(ruleslist)
|
||||||
|
|
||||||
rules = {"global": {}, "device": {}}
|
rules = {
|
||||||
|
"global": {},
|
||||||
|
"device": {},
|
||||||
|
} # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
|
||||||
|
|
||||||
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
|
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
|
||||||
|
|
||||||
for r in ruleslist:
|
for r in ruleslist:
|
||||||
rulearray = None
|
|
||||||
|
|
||||||
template_name = _priority_class_to_template_name(r["priority_class"])
|
template_name = _priority_class_to_template_name(r["priority_class"])
|
||||||
|
|
||||||
# Remove internal stuff.
|
# Remove internal stuff.
|
||||||
|
@ -57,13 +60,13 @@ def format_push_rules_for_user(user, ruleslist):
|
||||||
return rules
|
return rules
|
||||||
|
|
||||||
|
|
||||||
def _add_empty_priority_class_arrays(d):
|
def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
|
||||||
for pc in PRIORITY_CLASS_MAP.keys():
|
for pc in PRIORITY_CLASS_MAP.keys():
|
||||||
d[pc] = []
|
d[pc] = []
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def _rule_to_template(rule):
|
def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
unscoped_rule_id = None
|
unscoped_rule_id = None
|
||||||
if "rule_id" in rule:
|
if "rule_id" in rule:
|
||||||
unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
|
unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
|
||||||
|
@ -82,6 +85,10 @@ def _rule_to_template(rule):
|
||||||
return None
|
return None
|
||||||
templaterule = {"actions": rule["actions"]}
|
templaterule = {"actions": rule["actions"]}
|
||||||
templaterule["pattern"] = thecond["pattern"]
|
templaterule["pattern"] = thecond["pattern"]
|
||||||
|
else:
|
||||||
|
# This should not be reached unless this function is not kept in sync
|
||||||
|
# with PRIORITY_CLASS_INVERSE_MAP.
|
||||||
|
raise ValueError("Unexpected template_name: %s" % (template_name,))
|
||||||
|
|
||||||
if unscoped_rule_id:
|
if unscoped_rule_id:
|
||||||
templaterule["rule_id"] = unscoped_rule_id
|
templaterule["rule_id"] = unscoped_rule_id
|
||||||
|
@ -90,9 +97,9 @@ def _rule_to_template(rule):
|
||||||
return templaterule
|
return templaterule
|
||||||
|
|
||||||
|
|
||||||
def _rule_id_from_namespaced(in_rule_id):
|
def _rule_id_from_namespaced(in_rule_id: str) -> str:
|
||||||
return in_rule_id.split("/")[-1]
|
return in_rule_id.split("/")[-1]
|
||||||
|
|
||||||
|
|
||||||
def _priority_class_to_template_name(pc):
|
def _priority_class_to_template_name(pc: int) -> str:
|
||||||
return PRIORITY_CLASS_INVERSE_MAP[pc]
|
return PRIORITY_CLASS_INVERSE_MAP[pc]
|
||||||
|
|
|
@ -15,8 +15,14 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from typing import TYPE_CHECKING, Dict, Iterable, Optional
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
from synapse.events import EventBase
|
||||||
|
from synapse.types import StateMap
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.storage.databases.main import DataStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -28,25 +34,29 @@ ALL_ALONE = "Empty Room"
|
||||||
|
|
||||||
|
|
||||||
async def calculate_room_name(
|
async def calculate_room_name(
|
||||||
store,
|
store: "DataStore",
|
||||||
room_state_ids,
|
room_state_ids: StateMap[str],
|
||||||
user_id,
|
user_id: str,
|
||||||
fallback_to_members=True,
|
fallback_to_members: bool = True,
|
||||||
fallback_to_single_member=True,
|
fallback_to_single_member: bool = True,
|
||||||
):
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Works out a user-facing name for the given room as per Matrix
|
Works out a user-facing name for the given room as per Matrix
|
||||||
spec recommendations.
|
spec recommendations.
|
||||||
Does not yet support internationalisation.
|
Does not yet support internationalisation.
|
||||||
Args:
|
Args:
|
||||||
room_state: Dictionary of the room's state
|
store: The data store to query.
|
||||||
|
room_state_ids: Dictionary of the room's state IDs.
|
||||||
user_id: The ID of the user to whom the room name is being presented
|
user_id: The ID of the user to whom the room name is being presented
|
||||||
fallback_to_members: If False, return None instead of generating a name
|
fallback_to_members: If False, return None instead of generating a name
|
||||||
based on the room's members if the room has no
|
based on the room's members if the room has no
|
||||||
title or aliases.
|
title or aliases.
|
||||||
|
fallback_to_single_member: If False, return None instead of generating a
|
||||||
|
name based on the user who invited this user to the room if the room
|
||||||
|
has no title or aliases.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(string or None) A human readable name for the room.
|
A human readable name for the room, if possible.
|
||||||
"""
|
"""
|
||||||
# does it have a name?
|
# does it have a name?
|
||||||
if (EventTypes.Name, "") in room_state_ids:
|
if (EventTypes.Name, "") in room_state_ids:
|
||||||
|
@ -97,7 +107,7 @@ async def calculate_room_name(
|
||||||
name_from_member_event(inviter_member_event),
|
name_from_member_event(inviter_member_event),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return
|
return None
|
||||||
else:
|
else:
|
||||||
return "Room Invite"
|
return "Room Invite"
|
||||||
|
|
||||||
|
@ -150,19 +160,19 @@ async def calculate_room_name(
|
||||||
else:
|
else:
|
||||||
return ALL_ALONE
|
return ALL_ALONE
|
||||||
elif len(other_members) == 1 and not fallback_to_single_member:
|
elif len(other_members) == 1 and not fallback_to_single_member:
|
||||||
return
|
return None
|
||||||
else:
|
|
||||||
return descriptor_from_member_events(other_members)
|
return descriptor_from_member_events(other_members)
|
||||||
|
|
||||||
|
|
||||||
def descriptor_from_member_events(member_events):
|
def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
|
||||||
"""Get a description of the room based on the member events.
|
"""Get a description of the room based on the member events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
member_events (Iterable[FrozenEvent])
|
member_events: The events of a room.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str
|
The room description
|
||||||
"""
|
"""
|
||||||
|
|
||||||
member_events = list(member_events)
|
member_events = list(member_events)
|
||||||
|
@ -183,7 +193,7 @@ def descriptor_from_member_events(member_events):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def name_from_member_event(member_event):
|
def name_from_member_event(member_event: EventBase) -> str:
|
||||||
if (
|
if (
|
||||||
member_event.content
|
member_event.content
|
||||||
and "displayname" in member_event.content
|
and "displayname" in member_event.content
|
||||||
|
@ -193,12 +203,12 @@ def name_from_member_event(member_event):
|
||||||
return member_event.state_key
|
return member_event.state_key
|
||||||
|
|
||||||
|
|
||||||
def _state_as_two_level_dict(state):
|
def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
|
||||||
ret = {}
|
ret = {} # type: Dict[str, Dict[str, str]]
|
||||||
for k, v in state.items():
|
for k, v in state.items():
|
||||||
ret.setdefault(k[0], {})[k[1]] = v
|
ret.setdefault(k[0], {})[k[1]] = v
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def _looks_like_an_alias(string):
|
def _looks_like_an_alias(string: str) -> bool:
|
||||||
return ALIAS_RE.match(string) is not None
|
return ALIAS_RE.match(string) is not None
|
||||||
|
|
|
@ -30,22 +30,30 @@ IS_GLOB = re.compile(r"[\?\*\[\]]")
|
||||||
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
|
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
|
||||||
|
|
||||||
|
|
||||||
def _room_member_count(ev, condition, room_member_count):
|
def _room_member_count(
|
||||||
|
ev: EventBase, condition: Dict[str, Any], room_member_count: int
|
||||||
|
) -> bool:
|
||||||
return _test_ineq_condition(condition, room_member_count)
|
return _test_ineq_condition(condition, room_member_count)
|
||||||
|
|
||||||
|
|
||||||
def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
|
def _sender_notification_permission(
|
||||||
|
ev: EventBase,
|
||||||
|
condition: Dict[str, Any],
|
||||||
|
sender_power_level: int,
|
||||||
|
power_levels: Dict[str, Union[int, Dict[str, int]]],
|
||||||
|
) -> bool:
|
||||||
notif_level_key = condition.get("key")
|
notif_level_key = condition.get("key")
|
||||||
if notif_level_key is None:
|
if notif_level_key is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
notif_levels = power_levels.get("notifications", {})
|
notif_levels = power_levels.get("notifications", {})
|
||||||
|
assert isinstance(notif_levels, dict)
|
||||||
room_notif_level = notif_levels.get(notif_level_key, 50)
|
room_notif_level = notif_levels.get(notif_level_key, 50)
|
||||||
|
|
||||||
return sender_power_level >= room_notif_level
|
return sender_power_level >= room_notif_level
|
||||||
|
|
||||||
|
|
||||||
def _test_ineq_condition(condition, number):
|
def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
|
||||||
if "is" not in condition:
|
if "is" not in condition:
|
||||||
return False
|
return False
|
||||||
m = INEQUALITY_EXPR.match(condition["is"])
|
m = INEQUALITY_EXPR.match(condition["is"])
|
||||||
|
@ -110,7 +118,7 @@ class PushRuleEvaluatorForEvent:
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
room_member_count: int,
|
room_member_count: int,
|
||||||
sender_power_level: int,
|
sender_power_level: int,
|
||||||
power_levels: dict,
|
power_levels: Dict[str, Union[int, Dict[str, int]]],
|
||||||
):
|
):
|
||||||
self._event = event
|
self._event = event
|
||||||
self._room_member_count = room_member_count
|
self._room_member_count = room_member_count
|
||||||
|
@ -120,7 +128,9 @@ class PushRuleEvaluatorForEvent:
|
||||||
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
||||||
self._value_cache = _flatten_dict(event)
|
self._value_cache = _flatten_dict(event)
|
||||||
|
|
||||||
def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
|
def matches(
|
||||||
|
self, condition: Dict[str, Any], user_id: str, display_name: str
|
||||||
|
) -> bool:
|
||||||
if condition["kind"] == "event_match":
|
if condition["kind"] == "event_match":
|
||||||
return self._event_match(condition, user_id)
|
return self._event_match(condition, user_id)
|
||||||
elif condition["kind"] == "contains_display_name":
|
elif condition["kind"] == "contains_display_name":
|
||||||
|
@ -261,7 +271,13 @@ def _re_word_boundary(r: str) -> str:
|
||||||
return r"(^|\W)%s(\W|$)" % (r,)
|
return r"(^|\W)%s(\W|$)" % (r,)
|
||||||
|
|
||||||
|
|
||||||
def _flatten_dict(d, prefix=[], result=None):
|
def _flatten_dict(
|
||||||
|
d: Union[EventBase, dict],
|
||||||
|
prefix: Optional[List[str]] = None,
|
||||||
|
result: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
if prefix is None:
|
||||||
|
prefix = []
|
||||||
if result is None:
|
if result is None:
|
||||||
result = {}
|
result = {}
|
||||||
for key, value in d.items():
|
for key, value in d.items():
|
||||||
|
|
Loading…
Reference in New Issue