Add type hints to filtering classes. (#10958)
This commit is contained in:
parent
9e5a429c8b
commit
7e440520c9
|
@ -0,0 +1 @@
|
|||
Add type hints to filtering classes.
|
|
@ -15,7 +15,17 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from typing import List
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Awaitable,
|
||||
Container,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import jsonschema
|
||||
from jsonschema import FormatChecker
|
||||
|
@ -23,7 +33,11 @@ from jsonschema import FormatChecker
|
|||
from synapse.api.constants import EventContentFields
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.types import RoomID, UserID
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import JsonDict, RoomID, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
FILTER_SCHEMA = {
|
||||
"additionalProperties": False,
|
||||
|
@ -120,25 +134,29 @@ USER_FILTER_SCHEMA = {
|
|||
|
||||
|
||||
@FormatChecker.cls_checks("matrix_room_id")
|
||||
def matrix_room_id_validator(room_id_str):
|
||||
def matrix_room_id_validator(room_id_str: str) -> RoomID:
|
||||
return RoomID.from_string(room_id_str)
|
||||
|
||||
|
||||
@FormatChecker.cls_checks("matrix_user_id")
|
||||
def matrix_user_id_validator(user_id_str):
|
||||
def matrix_user_id_validator(user_id_str: str) -> UserID:
|
||||
return UserID.from_string(user_id_str)
|
||||
|
||||
|
||||
class Filtering:
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def get_user_filter(self, user_localpart, filter_id):
|
||||
async def get_user_filter(
|
||||
self, user_localpart: str, filter_id: Union[int, str]
|
||||
) -> "FilterCollection":
|
||||
result = await self.store.get_user_filter(user_localpart, filter_id)
|
||||
return FilterCollection(result)
|
||||
|
||||
def add_user_filter(self, user_localpart, user_filter):
|
||||
def add_user_filter(
|
||||
self, user_localpart: str, user_filter: JsonDict
|
||||
) -> Awaitable[int]:
|
||||
self.check_valid_filter(user_filter)
|
||||
return self.store.add_user_filter(user_localpart, user_filter)
|
||||
|
||||
|
@ -146,13 +164,13 @@ class Filtering:
|
|||
# replace_user_filter at some point? There's no REST API specified for
|
||||
# them however
|
||||
|
||||
def check_valid_filter(self, user_filter_json):
|
||||
def check_valid_filter(self, user_filter_json: JsonDict) -> None:
|
||||
"""Check if the provided filter is valid.
|
||||
|
||||
This inspects all definitions contained within the filter.
|
||||
|
||||
Args:
|
||||
user_filter_json(dict): The filter
|
||||
user_filter_json: The filter
|
||||
Raises:
|
||||
SynapseError: If the filter is not valid.
|
||||
"""
|
||||
|
@ -167,8 +185,12 @@ class Filtering:
|
|||
raise SynapseError(400, str(e))
|
||||
|
||||
|
||||
# Filters work across events, presence EDUs, and account data.
|
||||
FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
|
||||
|
||||
|
||||
class FilterCollection:
|
||||
def __init__(self, filter_json):
|
||||
def __init__(self, filter_json: JsonDict):
|
||||
self._filter_json = filter_json
|
||||
|
||||
room_filter_json = self._filter_json.get("room", {})
|
||||
|
@ -188,25 +210,25 @@ class FilterCollection:
|
|||
self.event_fields = filter_json.get("event_fields", [])
|
||||
self.event_format = filter_json.get("event_format", "client")
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
|
||||
|
||||
def get_filter_json(self):
|
||||
def get_filter_json(self) -> JsonDict:
|
||||
return self._filter_json
|
||||
|
||||
def timeline_limit(self):
|
||||
def timeline_limit(self) -> int:
|
||||
return self._room_timeline_filter.limit()
|
||||
|
||||
def presence_limit(self):
|
||||
def presence_limit(self) -> int:
|
||||
return self._presence_filter.limit()
|
||||
|
||||
def ephemeral_limit(self):
|
||||
def ephemeral_limit(self) -> int:
|
||||
return self._room_ephemeral_filter.limit()
|
||||
|
||||
def lazy_load_members(self):
|
||||
def lazy_load_members(self) -> bool:
|
||||
return self._room_state_filter.lazy_load_members()
|
||||
|
||||
def include_redundant_members(self):
|
||||
def include_redundant_members(self) -> bool:
|
||||
return self._room_state_filter.include_redundant_members()
|
||||
|
||||
def filter_presence(self, events):
|
||||
|
@ -218,29 +240,31 @@ class FilterCollection:
|
|||
def filter_room_state(self, events):
|
||||
return self._room_state_filter.filter(self._room_filter.filter(events))
|
||||
|
||||
def filter_room_timeline(self, events):
|
||||
def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
|
||||
return self._room_timeline_filter.filter(self._room_filter.filter(events))
|
||||
|
||||
def filter_room_ephemeral(self, events):
|
||||
def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
|
||||
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
|
||||
|
||||
def filter_room_account_data(self, events):
|
||||
def filter_room_account_data(
|
||||
self, events: Iterable[FilterEvent]
|
||||
) -> List[FilterEvent]:
|
||||
return self._room_account_data.filter(self._room_filter.filter(events))
|
||||
|
||||
def blocks_all_presence(self):
|
||||
def blocks_all_presence(self) -> bool:
|
||||
return (
|
||||
self._presence_filter.filters_all_types()
|
||||
or self._presence_filter.filters_all_senders()
|
||||
)
|
||||
|
||||
def blocks_all_room_ephemeral(self):
|
||||
def blocks_all_room_ephemeral(self) -> bool:
|
||||
return (
|
||||
self._room_ephemeral_filter.filters_all_types()
|
||||
or self._room_ephemeral_filter.filters_all_senders()
|
||||
or self._room_ephemeral_filter.filters_all_rooms()
|
||||
)
|
||||
|
||||
def blocks_all_room_timeline(self):
|
||||
def blocks_all_room_timeline(self) -> bool:
|
||||
return (
|
||||
self._room_timeline_filter.filters_all_types()
|
||||
or self._room_timeline_filter.filters_all_senders()
|
||||
|
@ -249,7 +273,7 @@ class FilterCollection:
|
|||
|
||||
|
||||
class Filter:
|
||||
def __init__(self, filter_json):
|
||||
def __init__(self, filter_json: JsonDict):
|
||||
self.filter_json = filter_json
|
||||
|
||||
self.types = self.filter_json.get("types", None)
|
||||
|
@ -266,20 +290,20 @@ class Filter:
|
|||
self.labels = self.filter_json.get("org.matrix.labels", None)
|
||||
self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
|
||||
|
||||
def filters_all_types(self):
|
||||
def filters_all_types(self) -> bool:
|
||||
return "*" in self.not_types
|
||||
|
||||
def filters_all_senders(self):
|
||||
def filters_all_senders(self) -> bool:
|
||||
return "*" in self.not_senders
|
||||
|
||||
def filters_all_rooms(self):
|
||||
def filters_all_rooms(self) -> bool:
|
||||
return "*" in self.not_rooms
|
||||
|
||||
def check(self, event):
|
||||
def check(self, event: FilterEvent) -> bool:
|
||||
"""Checks whether the filter matches the given event.
|
||||
|
||||
Returns:
|
||||
bool: True if the event matches
|
||||
True if the event matches
|
||||
"""
|
||||
# We usually get the full "events" as dictionaries coming through,
|
||||
# except for presence which actually gets passed around as its own
|
||||
|
@ -305,18 +329,25 @@ class Filter:
|
|||
room_id = event.get("room_id", None)
|
||||
ev_type = event.get("type", None)
|
||||
|
||||
content = event.get("content", {})
|
||||
content = event.get("content") or {}
|
||||
# check if there is a string url field in the content for filtering purposes
|
||||
contains_url = isinstance(content.get("url"), str)
|
||||
labels = content.get(EventContentFields.LABELS, [])
|
||||
|
||||
return self.check_fields(room_id, sender, ev_type, labels, contains_url)
|
||||
|
||||
def check_fields(self, room_id, sender, event_type, labels, contains_url):
|
||||
def check_fields(
|
||||
self,
|
||||
room_id: Optional[str],
|
||||
sender: Optional[str],
|
||||
event_type: Optional[str],
|
||||
labels: Container[str],
|
||||
contains_url: bool,
|
||||
) -> bool:
|
||||
"""Checks whether the filter matches the given event fields.
|
||||
|
||||
Returns:
|
||||
bool: True if the event fields match
|
||||
True if the event fields match
|
||||
"""
|
||||
literal_keys = {
|
||||
"rooms": lambda v: room_id == v,
|
||||
|
@ -343,14 +374,14 @@ class Filter:
|
|||
|
||||
return True
|
||||
|
||||
def filter_rooms(self, room_ids):
|
||||
def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
|
||||
"""Apply the 'rooms' filter to a given list of rooms.
|
||||
|
||||
Args:
|
||||
room_ids (list): A list of room_ids.
|
||||
room_ids: A list of room_ids.
|
||||
|
||||
Returns:
|
||||
list: A list of room_ids that match the filter
|
||||
A list of room_ids that match the filter
|
||||
"""
|
||||
room_ids = set(room_ids)
|
||||
|
||||
|
@ -363,23 +394,23 @@ class Filter:
|
|||
|
||||
return room_ids
|
||||
|
||||
def filter(self, events):
|
||||
def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
|
||||
return list(filter(self.check, events))
|
||||
|
||||
def limit(self):
|
||||
def limit(self) -> int:
|
||||
return self.filter_json.get("limit", 10)
|
||||
|
||||
def lazy_load_members(self):
|
||||
def lazy_load_members(self) -> bool:
|
||||
return self.filter_json.get("lazy_load_members", False)
|
||||
|
||||
def include_redundant_members(self):
|
||||
def include_redundant_members(self) -> bool:
|
||||
return self.filter_json.get("include_redundant_members", False)
|
||||
|
||||
def with_room_ids(self, room_ids):
|
||||
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
|
||||
"""Returns a new filter with the given room IDs appended.
|
||||
|
||||
Args:
|
||||
room_ids (iterable[unicode]): The room_ids to add
|
||||
room_ids: The room_ids to add
|
||||
|
||||
Returns:
|
||||
filter: A new filter including the given rooms and the old
|
||||
|
@ -390,8 +421,8 @@ class Filter:
|
|||
return newFilter
|
||||
|
||||
|
||||
def _matches_wildcard(actual_value, filter_value):
|
||||
if filter_value.endswith("*"):
|
||||
def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
|
||||
if filter_value.endswith("*") and isinstance(actual_value, str):
|
||||
type_prefix = filter_value[:-1]
|
||||
return actual_value.startswith(type_prefix)
|
||||
else:
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
|
@ -22,7 +24,9 @@ from synapse.util.caches.descriptors import cached
|
|||
|
||||
class FilteringStore(SQLBaseStore):
|
||||
@cached(num_args=2)
|
||||
async def get_user_filter(self, user_localpart, filter_id):
|
||||
async def get_user_filter(
|
||||
self, user_localpart: str, filter_id: Union[int, str]
|
||||
) -> JsonDict:
|
||||
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
||||
# with a coherent error message rather than 500 M_UNKNOWN.
|
||||
try:
|
||||
|
@ -40,7 +44,7 @@ class FilteringStore(SQLBaseStore):
|
|||
|
||||
return db_to_json(def_json)
|
||||
|
||||
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
|
||||
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
|
||||
def_json = encode_canonical_json(user_filter)
|
||||
|
||||
# Need an atomic transaction to SELECT the maximal ID so far then
|
||||
|
|
Loading…
Reference in New Issue