From 7e440520c9b370ce008c6a65c5dd87a360a6457c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 1 Oct 2021 07:02:32 -0400 Subject: [PATCH] Add type hints to filtering classes. (#10958) --- changelog.d/10958.misc | 1 + synapse/api/filtering.py | 117 +++++++++++++------- synapse/storage/databases/main/filtering.py | 8 +- 3 files changed, 81 insertions(+), 45 deletions(-) create mode 100644 changelog.d/10958.misc diff --git a/changelog.d/10958.misc b/changelog.d/10958.misc new file mode 100644 index 0000000000..409ecc35cb --- /dev/null +++ b/changelog.d/10958.misc @@ -0,0 +1 @@ +Add type hints to filtering classes. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index ad1ff6a9df..20e91a115d 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -15,7 +15,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from typing import List +from typing import ( + TYPE_CHECKING, + Awaitable, + Container, + Iterable, + List, + Optional, + Set, + TypeVar, + Union, +) import jsonschema from jsonschema import FormatChecker @@ -23,7 +33,11 @@ from jsonschema import FormatChecker from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState -from synapse.types import RoomID, UserID +from synapse.events import EventBase +from synapse.types import JsonDict, RoomID, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer FILTER_SCHEMA = { "additionalProperties": False, @@ -120,25 +134,29 @@ USER_FILTER_SCHEMA = { @FormatChecker.cls_checks("matrix_room_id") -def matrix_room_id_validator(room_id_str): +def matrix_room_id_validator(room_id_str: str) -> RoomID: return RoomID.from_string(room_id_str) @FormatChecker.cls_checks("matrix_user_id") -def matrix_user_id_validator(user_id_str): +def matrix_user_id_validator(user_id_str: str) -> UserID: return UserID.from_string(user_id_str) class Filtering: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() - async def get_user_filter(self, user_localpart, filter_id): + async def get_user_filter( + self, user_localpart: str, filter_id: Union[int, str] + ) -> "FilterCollection": result = await self.store.get_user_filter(user_localpart, filter_id) return FilterCollection(result) - def add_user_filter(self, user_localpart, user_filter): + def add_user_filter( + self, user_localpart: str, user_filter: JsonDict + ) -> Awaitable[int]: self.check_valid_filter(user_filter) return self.store.add_user_filter(user_localpart, user_filter) @@ -146,13 +164,13 @@ class Filtering: # replace_user_filter at some point? There's no REST API specified for # them however - def check_valid_filter(self, user_filter_json): + def check_valid_filter(self, user_filter_json: JsonDict) -> None: """Check if the provided filter is valid. This inspects all definitions contained within the filter. Args: - user_filter_json(dict): The filter + user_filter_json: The filter Raises: SynapseError: If the filter is not valid. """ @@ -167,8 +185,12 @@ class Filtering: raise SynapseError(400, str(e)) +# Filters work across events, presence EDUs, and account data. +FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict) + + class FilterCollection: - def __init__(self, filter_json): + def __init__(self, filter_json: JsonDict): self._filter_json = filter_json room_filter_json = self._filter_json.get("room", {}) @@ -188,25 +210,25 @@ class FilterCollection: self.event_fields = filter_json.get("event_fields", []) self.event_format = filter_json.get("event_format", "client") - def __repr__(self): + def __repr__(self) -> str: return "" % (json.dumps(self._filter_json),) - def get_filter_json(self): + def get_filter_json(self) -> JsonDict: return self._filter_json - def timeline_limit(self): + def timeline_limit(self) -> int: return self._room_timeline_filter.limit() - def presence_limit(self): + def presence_limit(self) -> int: return self._presence_filter.limit() - def ephemeral_limit(self): + def ephemeral_limit(self) -> int: return self._room_ephemeral_filter.limit() - def lazy_load_members(self): + def lazy_load_members(self) -> bool: return self._room_state_filter.lazy_load_members() - def include_redundant_members(self): + def include_redundant_members(self) -> bool: return self._room_state_filter.include_redundant_members() def filter_presence(self, events): @@ -218,29 +240,31 @@ class FilterCollection: def filter_room_state(self, events): return self._room_state_filter.filter(self._room_filter.filter(events)) - def filter_room_timeline(self, events): + def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: return self._room_timeline_filter.filter(self._room_filter.filter(events)) - def filter_room_ephemeral(self, events): + def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: return self._room_ephemeral_filter.filter(self._room_filter.filter(events)) - def filter_room_account_data(self, events): + def filter_room_account_data( + self, events: Iterable[FilterEvent] + ) -> List[FilterEvent]: return self._room_account_data.filter(self._room_filter.filter(events)) - def blocks_all_presence(self): + def blocks_all_presence(self) -> bool: return ( self._presence_filter.filters_all_types() or self._presence_filter.filters_all_senders() ) - def blocks_all_room_ephemeral(self): + def blocks_all_room_ephemeral(self) -> bool: return ( self._room_ephemeral_filter.filters_all_types() or self._room_ephemeral_filter.filters_all_senders() or self._room_ephemeral_filter.filters_all_rooms() ) - def blocks_all_room_timeline(self): + def blocks_all_room_timeline(self) -> bool: return ( self._room_timeline_filter.filters_all_types() or self._room_timeline_filter.filters_all_senders() @@ -249,7 +273,7 @@ class FilterCollection: class Filter: - def __init__(self, filter_json): + def __init__(self, filter_json: JsonDict): self.filter_json = filter_json self.types = self.filter_json.get("types", None) @@ -266,20 +290,20 @@ class Filter: self.labels = self.filter_json.get("org.matrix.labels", None) self.not_labels = self.filter_json.get("org.matrix.not_labels", []) - def filters_all_types(self): + def filters_all_types(self) -> bool: return "*" in self.not_types - def filters_all_senders(self): + def filters_all_senders(self) -> bool: return "*" in self.not_senders - def filters_all_rooms(self): + def filters_all_rooms(self) -> bool: return "*" in self.not_rooms - def check(self, event): + def check(self, event: FilterEvent) -> bool: """Checks whether the filter matches the given event. Returns: - bool: True if the event matches + True if the event matches """ # We usually get the full "events" as dictionaries coming through, # except for presence which actually gets passed around as its own @@ -305,18 +329,25 @@ class Filter: room_id = event.get("room_id", None) ev_type = event.get("type", None) - content = event.get("content", {}) + content = event.get("content") or {} # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), str) labels = content.get(EventContentFields.LABELS, []) return self.check_fields(room_id, sender, ev_type, labels, contains_url) - def check_fields(self, room_id, sender, event_type, labels, contains_url): + def check_fields( + self, + room_id: Optional[str], + sender: Optional[str], + event_type: Optional[str], + labels: Container[str], + contains_url: bool, + ) -> bool: """Checks whether the filter matches the given event fields. Returns: - bool: True if the event fields match + True if the event fields match """ literal_keys = { "rooms": lambda v: room_id == v, @@ -343,14 +374,14 @@ class Filter: return True - def filter_rooms(self, room_ids): + def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]: """Apply the 'rooms' filter to a given list of rooms. Args: - room_ids (list): A list of room_ids. + room_ids: A list of room_ids. Returns: - list: A list of room_ids that match the filter + A list of room_ids that match the filter """ room_ids = set(room_ids) @@ -363,23 +394,23 @@ class Filter: return room_ids - def filter(self, events): + def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: return list(filter(self.check, events)) - def limit(self): + def limit(self) -> int: return self.filter_json.get("limit", 10) - def lazy_load_members(self): + def lazy_load_members(self) -> bool: return self.filter_json.get("lazy_load_members", False) - def include_redundant_members(self): + def include_redundant_members(self) -> bool: return self.filter_json.get("include_redundant_members", False) - def with_room_ids(self, room_ids): + def with_room_ids(self, room_ids: Iterable[str]) -> "Filter": """Returns a new filter with the given room IDs appended. Args: - room_ids (iterable[unicode]): The room_ids to add + room_ids: The room_ids to add Returns: filter: A new filter including the given rooms and the old @@ -390,8 +421,8 @@ class Filter: return newFilter -def _matches_wildcard(actual_value, filter_value): - if filter_value.endswith("*"): +def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool: + if filter_value.endswith("*") and isinstance(actual_value, str): type_prefix = filter_value[:-1] return actual_value.startswith(type_prefix) else: diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index bb244a03c0..434986fa64 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError @@ -22,7 +24,9 @@ from synapse.util.caches.descriptors import cached class FilteringStore(SQLBaseStore): @cached(num_args=2) - async def get_user_filter(self, user_localpart, filter_id): + async def get_user_filter( + self, user_localpart: str, filter_id: Union[int, str] + ) -> JsonDict: # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. try: @@ -40,7 +44,7 @@ class FilteringStore(SQLBaseStore): return db_to_json(def_json) - async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str: + async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int: def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then