Support filtering by relations per MSC3440 (#11236)
Adds experimental support for `relation_types` and `relation_senders` fields for filters.
This commit is contained in:
parent
4b3e30c276
commit
a19d01c3d9
|
@ -0,0 +1 @@
|
||||||
|
Support filtering by relation senders & types per [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).
|
|
@ -1,7 +1,7 @@
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
# Copyright 2017 Vector Creations Ltd
|
# Copyright 2017 Vector Creations Ltd
|
||||||
# Copyright 2018-2019 New Vector Ltd
|
# Copyright 2018-2019 New Vector Ltd
|
||||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -86,6 +86,9 @@ ROOM_EVENT_FILTER_SCHEMA = {
|
||||||
# cf https://github.com/matrix-org/matrix-doc/pull/2326
|
# cf https://github.com/matrix-org/matrix-doc/pull/2326
|
||||||
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
|
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
|
||||||
"org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
|
"org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
|
||||||
|
# MSC3440, filtering by event relations.
|
||||||
|
"io.element.relation_senders": {"type": "array", "items": {"type": "string"}},
|
||||||
|
"io.element.relation_types": {"type": "array", "items": {"type": "string"}},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,14 +149,16 @@ def matrix_user_id_validator(user_id_str: str) -> UserID:
|
||||||
|
|
||||||
class Filtering:
|
class Filtering:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__()
|
self._hs = hs
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
|
||||||
|
|
||||||
async def get_user_filter(
|
async def get_user_filter(
|
||||||
self, user_localpart: str, filter_id: Union[int, str]
|
self, user_localpart: str, filter_id: Union[int, str]
|
||||||
) -> "FilterCollection":
|
) -> "FilterCollection":
|
||||||
result = await self.store.get_user_filter(user_localpart, filter_id)
|
result = await self.store.get_user_filter(user_localpart, filter_id)
|
||||||
return FilterCollection(result)
|
return FilterCollection(self._hs, result)
|
||||||
|
|
||||||
def add_user_filter(
|
def add_user_filter(
|
||||||
self, user_localpart: str, user_filter: JsonDict
|
self, user_localpart: str, user_filter: JsonDict
|
||||||
|
@ -191,21 +196,22 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
|
||||||
|
|
||||||
|
|
||||||
class FilterCollection:
|
class FilterCollection:
|
||||||
def __init__(self, filter_json: JsonDict):
|
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
|
||||||
self._filter_json = filter_json
|
self._filter_json = filter_json
|
||||||
|
|
||||||
room_filter_json = self._filter_json.get("room", {})
|
room_filter_json = self._filter_json.get("room", {})
|
||||||
|
|
||||||
self._room_filter = Filter(
|
self._room_filter = Filter(
|
||||||
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")}
|
hs,
|
||||||
|
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")},
|
||||||
)
|
)
|
||||||
|
|
||||||
self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
|
self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {}))
|
||||||
self._room_state_filter = Filter(room_filter_json.get("state", {}))
|
self._room_state_filter = Filter(hs, room_filter_json.get("state", {}))
|
||||||
self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
|
self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {}))
|
||||||
self._room_account_data = Filter(room_filter_json.get("account_data", {}))
|
self._room_account_data = Filter(hs, room_filter_json.get("account_data", {}))
|
||||||
self._presence_filter = Filter(filter_json.get("presence", {}))
|
self._presence_filter = Filter(hs, filter_json.get("presence", {}))
|
||||||
self._account_data = Filter(filter_json.get("account_data", {}))
|
self._account_data = Filter(hs, filter_json.get("account_data", {}))
|
||||||
|
|
||||||
self.include_leave = filter_json.get("room", {}).get("include_leave", False)
|
self.include_leave = filter_json.get("room", {}).get("include_leave", False)
|
||||||
self.event_fields = filter_json.get("event_fields", [])
|
self.event_fields = filter_json.get("event_fields", [])
|
||||||
|
@ -232,25 +238,37 @@ class FilterCollection:
|
||||||
def include_redundant_members(self) -> bool:
|
def include_redundant_members(self) -> bool:
|
||||||
return self._room_state_filter.include_redundant_members
|
return self._room_state_filter.include_redundant_members
|
||||||
|
|
||||||
def filter_presence(
|
async def filter_presence(
|
||||||
self, events: Iterable[UserPresenceState]
|
self, events: Iterable[UserPresenceState]
|
||||||
) -> List[UserPresenceState]:
|
) -> List[UserPresenceState]:
|
||||||
return self._presence_filter.filter(events)
|
return await self._presence_filter.filter(events)
|
||||||
|
|
||||||
def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
|
async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
|
||||||
return self._account_data.filter(events)
|
return await self._account_data.filter(events)
|
||||||
|
|
||||||
def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
|
async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
|
||||||
return self._room_state_filter.filter(self._room_filter.filter(events))
|
return await self._room_state_filter.filter(
|
||||||
|
await self._room_filter.filter(events)
|
||||||
|
)
|
||||||
|
|
||||||
def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]:
|
async def filter_room_timeline(
|
||||||
return self._room_timeline_filter.filter(self._room_filter.filter(events))
|
self, events: Iterable[EventBase]
|
||||||
|
) -> List[EventBase]:
|
||||||
|
return await self._room_timeline_filter.filter(
|
||||||
|
await self._room_filter.filter(events)
|
||||||
|
)
|
||||||
|
|
||||||
def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
|
async def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
|
||||||
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
|
return await self._room_ephemeral_filter.filter(
|
||||||
|
await self._room_filter.filter(events)
|
||||||
|
)
|
||||||
|
|
||||||
def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
|
async def filter_room_account_data(
|
||||||
return self._room_account_data.filter(self._room_filter.filter(events))
|
self, events: Iterable[JsonDict]
|
||||||
|
) -> List[JsonDict]:
|
||||||
|
return await self._room_account_data.filter(
|
||||||
|
await self._room_filter.filter(events)
|
||||||
|
)
|
||||||
|
|
||||||
def blocks_all_presence(self) -> bool:
|
def blocks_all_presence(self) -> bool:
|
||||||
return (
|
return (
|
||||||
|
@ -274,7 +292,9 @@ class FilterCollection:
|
||||||
|
|
||||||
|
|
||||||
class Filter:
|
class Filter:
|
||||||
def __init__(self, filter_json: JsonDict):
|
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
|
||||||
|
self._hs = hs
|
||||||
|
self._store = hs.get_datastore()
|
||||||
self.filter_json = filter_json
|
self.filter_json = filter_json
|
||||||
|
|
||||||
self.limit = filter_json.get("limit", 10)
|
self.limit = filter_json.get("limit", 10)
|
||||||
|
@ -297,6 +317,20 @@ class Filter:
|
||||||
self.labels = filter_json.get("org.matrix.labels", None)
|
self.labels = filter_json.get("org.matrix.labels", None)
|
||||||
self.not_labels = filter_json.get("org.matrix.not_labels", [])
|
self.not_labels = filter_json.get("org.matrix.not_labels", [])
|
||||||
|
|
||||||
|
# Ideally these would be rejected at the endpoint if they were provided
|
||||||
|
# and not supported, but that would involve modifying the JSON schema
|
||||||
|
# based on the homeserver configuration.
|
||||||
|
if hs.config.experimental.msc3440_enabled:
|
||||||
|
self.relation_senders = self.filter_json.get(
|
||||||
|
"io.element.relation_senders", None
|
||||||
|
)
|
||||||
|
self.relation_types = self.filter_json.get(
|
||||||
|
"io.element.relation_types", None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.relation_senders = None
|
||||||
|
self.relation_types = None
|
||||||
|
|
||||||
def filters_all_types(self) -> bool:
|
def filters_all_types(self) -> bool:
|
||||||
return "*" in self.not_types
|
return "*" in self.not_types
|
||||||
|
|
||||||
|
@ -306,7 +340,7 @@ class Filter:
|
||||||
def filters_all_rooms(self) -> bool:
|
def filters_all_rooms(self) -> bool:
|
||||||
return "*" in self.not_rooms
|
return "*" in self.not_rooms
|
||||||
|
|
||||||
def check(self, event: FilterEvent) -> bool:
|
def _check(self, event: FilterEvent) -> bool:
|
||||||
"""Checks whether the filter matches the given event.
|
"""Checks whether the filter matches the given event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -420,8 +454,30 @@ class Filter:
|
||||||
|
|
||||||
return room_ids
|
return room_ids
|
||||||
|
|
||||||
def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
|
async def _check_event_relations(
|
||||||
return list(filter(self.check, events))
|
self, events: Iterable[FilterEvent]
|
||||||
|
) -> List[FilterEvent]:
|
||||||
|
# The event IDs to check, mypy doesn't understand the ifinstance check.
|
||||||
|
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
|
||||||
|
event_ids_to_keep = set(
|
||||||
|
await self._store.events_have_relations(
|
||||||
|
event_ids, self.relation_senders, self.relation_types
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
event
|
||||||
|
for event in events
|
||||||
|
if not isinstance(event, EventBase) or event.event_id in event_ids_to_keep
|
||||||
|
]
|
||||||
|
|
||||||
|
async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
|
||||||
|
result = [event for event in events if self._check(event)]
|
||||||
|
|
||||||
|
if self.relation_senders or self.relation_types:
|
||||||
|
return await self._check_event_relations(result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
|
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
|
||||||
"""Returns a new filter with the given room IDs appended.
|
"""Returns a new filter with the given room IDs appended.
|
||||||
|
@ -433,7 +489,7 @@ class Filter:
|
||||||
filter: A new filter including the given rooms and the old
|
filter: A new filter including the given rooms and the old
|
||||||
filter's rooms.
|
filter's rooms.
|
||||||
"""
|
"""
|
||||||
newFilter = Filter(self.filter_json)
|
newFilter = Filter(self._hs, self.filter_json)
|
||||||
newFilter.rooms += room_ids
|
newFilter.rooms += room_ids
|
||||||
return newFilter
|
return newFilter
|
||||||
|
|
||||||
|
@ -444,6 +500,3 @@ def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
|
||||||
return actual_value.startswith(type_prefix)
|
return actual_value.startswith(type_prefix)
|
||||||
else:
|
else:
|
||||||
return actual_value == filter_value
|
return actual_value == filter_value
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_FILTER_COLLECTION = FilterCollection({})
|
|
||||||
|
|
|
@ -424,7 +424,7 @@ class PaginationHandler:
|
||||||
|
|
||||||
if events:
|
if events:
|
||||||
if event_filter:
|
if event_filter:
|
||||||
events = event_filter.filter(events)
|
events = await event_filter.filter(events)
|
||||||
|
|
||||||
events = await filter_events_for_client(
|
events = await filter_events_for_client(
|
||||||
self.storage, user_id, events, is_peeking=(member_event_id is None)
|
self.storage, user_id, events, is_peeking=(member_event_id is None)
|
||||||
|
|
|
@ -1158,8 +1158,10 @@ class RoomContextHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
if event_filter:
|
if event_filter:
|
||||||
results["events_before"] = event_filter.filter(results["events_before"])
|
results["events_before"] = await event_filter.filter(
|
||||||
results["events_after"] = event_filter.filter(results["events_after"])
|
results["events_before"]
|
||||||
|
)
|
||||||
|
results["events_after"] = await event_filter.filter(results["events_after"])
|
||||||
|
|
||||||
results["events_before"] = await filter_evts(results["events_before"])
|
results["events_before"] = await filter_evts(results["events_before"])
|
||||||
results["events_after"] = await filter_evts(results["events_after"])
|
results["events_after"] = await filter_evts(results["events_after"])
|
||||||
|
@ -1195,7 +1197,7 @@ class RoomContextHandler:
|
||||||
|
|
||||||
state_events = list(state[last_event_id].values())
|
state_events = list(state[last_event_id].values())
|
||||||
if event_filter:
|
if event_filter:
|
||||||
state_events = event_filter.filter(state_events)
|
state_events = await event_filter.filter(state_events)
|
||||||
|
|
||||||
results["state"] = await filter_evts(state_events)
|
results["state"] = await filter_evts(state_events)
|
||||||
|
|
||||||
|
|
|
@ -180,7 +180,7 @@ class SearchHandler:
|
||||||
% (set(group_keys) - {"room_id", "sender"},),
|
% (set(group_keys) - {"room_id", "sender"},),
|
||||||
)
|
)
|
||||||
|
|
||||||
search_filter = Filter(filter_dict)
|
search_filter = Filter(self.hs, filter_dict)
|
||||||
|
|
||||||
# TODO: Search through left rooms too
|
# TODO: Search through left rooms too
|
||||||
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
|
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
|
||||||
|
@ -242,7 +242,7 @@ class SearchHandler:
|
||||||
|
|
||||||
rank_map.update({r["event"].event_id: r["rank"] for r in results})
|
rank_map.update({r["event"].event_id: r["rank"] for r in results})
|
||||||
|
|
||||||
filtered_events = search_filter.filter([r["event"] for r in results])
|
filtered_events = await search_filter.filter([r["event"] for r in results])
|
||||||
|
|
||||||
events = await filter_events_for_client(
|
events = await filter_events_for_client(
|
||||||
self.storage, user.to_string(), filtered_events
|
self.storage, user.to_string(), filtered_events
|
||||||
|
@ -292,7 +292,9 @@ class SearchHandler:
|
||||||
|
|
||||||
rank_map.update({r["event"].event_id: r["rank"] for r in results})
|
rank_map.update({r["event"].event_id: r["rank"] for r in results})
|
||||||
|
|
||||||
filtered_events = search_filter.filter([r["event"] for r in results])
|
filtered_events = await search_filter.filter(
|
||||||
|
[r["event"] for r in results]
|
||||||
|
)
|
||||||
|
|
||||||
events = await filter_events_for_client(
|
events = await filter_events_for_client(
|
||||||
self.storage, user.to_string(), filtered_events
|
self.storage, user.to_string(), filtered_events
|
||||||
|
|
|
@ -510,7 +510,7 @@ class SyncHandler:
|
||||||
log_kv({"limited": limited})
|
log_kv({"limited": limited})
|
||||||
|
|
||||||
if potential_recents:
|
if potential_recents:
|
||||||
recents = sync_config.filter_collection.filter_room_timeline(
|
recents = await sync_config.filter_collection.filter_room_timeline(
|
||||||
potential_recents
|
potential_recents
|
||||||
)
|
)
|
||||||
log_kv({"recents_after_sync_filtering": len(recents)})
|
log_kv({"recents_after_sync_filtering": len(recents)})
|
||||||
|
@ -575,8 +575,8 @@ class SyncHandler:
|
||||||
|
|
||||||
log_kv({"loaded_recents": len(events)})
|
log_kv({"loaded_recents": len(events)})
|
||||||
|
|
||||||
loaded_recents = sync_config.filter_collection.filter_room_timeline(
|
loaded_recents = (
|
||||||
events
|
await sync_config.filter_collection.filter_room_timeline(events)
|
||||||
)
|
)
|
||||||
|
|
||||||
log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)})
|
log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)})
|
||||||
|
@ -1015,7 +1015,7 @@ class SyncHandler:
|
||||||
|
|
||||||
return {
|
return {
|
||||||
(e.type, e.state_key): e
|
(e.type, e.state_key): e
|
||||||
for e in sync_config.filter_collection.filter_room_state(
|
for e in await sync_config.filter_collection.filter_room_state(
|
||||||
list(state.values())
|
list(state.values())
|
||||||
)
|
)
|
||||||
if e.type != EventTypes.Aliases # until MSC2261 or alternative solution
|
if e.type != EventTypes.Aliases # until MSC2261 or alternative solution
|
||||||
|
@ -1383,7 +1383,7 @@ class SyncHandler:
|
||||||
sync_config.user
|
sync_config.user
|
||||||
)
|
)
|
||||||
|
|
||||||
account_data_for_user = sync_config.filter_collection.filter_account_data(
|
account_data_for_user = await sync_config.filter_collection.filter_account_data(
|
||||||
[
|
[
|
||||||
{"type": account_data_type, "content": content}
|
{"type": account_data_type, "content": content}
|
||||||
for account_data_type, content in account_data.items()
|
for account_data_type, content in account_data.items()
|
||||||
|
@ -1448,7 +1448,7 @@ class SyncHandler:
|
||||||
# Deduplicate the presence entries so that there's at most one per user
|
# Deduplicate the presence entries so that there's at most one per user
|
||||||
presence = list({p.user_id: p for p in presence}.values())
|
presence = list({p.user_id: p for p in presence}.values())
|
||||||
|
|
||||||
presence = sync_config.filter_collection.filter_presence(presence)
|
presence = await sync_config.filter_collection.filter_presence(presence)
|
||||||
|
|
||||||
sync_result_builder.presence = presence
|
sync_result_builder.presence = presence
|
||||||
|
|
||||||
|
@ -2021,12 +2021,14 @@ class SyncHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
account_data_events = (
|
account_data_events = (
|
||||||
sync_config.filter_collection.filter_room_account_data(
|
await sync_config.filter_collection.filter_room_account_data(
|
||||||
account_data_events
|
account_data_events
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral)
|
ephemeral = await sync_config.filter_collection.filter_room_ephemeral(
|
||||||
|
ephemeral
|
||||||
|
)
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
always_include
|
always_include
|
||||||
|
|
|
@ -583,6 +583,7 @@ class RoomEventContextServlet(RestServlet):
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self._hs = hs
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.room_context_handler = hs.get_room_context_handler()
|
self.room_context_handler = hs.get_room_context_handler()
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
|
@ -600,7 +601,9 @@ class RoomEventContextServlet(RestServlet):
|
||||||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
filter_str = parse_string(request, "filter", encoding="utf-8")
|
||||||
if filter_str:
|
if filter_str:
|
||||||
filter_json = urlparse.unquote(filter_str)
|
filter_json = urlparse.unquote(filter_str)
|
||||||
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
|
event_filter: Optional[Filter] = Filter(
|
||||||
|
self._hs, json_decoder.decode(filter_json)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
event_filter = None
|
event_filter = None
|
||||||
|
|
||||||
|
|
|
@ -550,6 +550,7 @@ class RoomMessageListRestServlet(RestServlet):
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self._hs = hs
|
||||||
self.pagination_handler = hs.get_pagination_handler()
|
self.pagination_handler = hs.get_pagination_handler()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
@ -567,7 +568,9 @@ class RoomMessageListRestServlet(RestServlet):
|
||||||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
filter_str = parse_string(request, "filter", encoding="utf-8")
|
||||||
if filter_str:
|
if filter_str:
|
||||||
filter_json = urlparse.unquote(filter_str)
|
filter_json = urlparse.unquote(filter_str)
|
||||||
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
|
event_filter: Optional[Filter] = Filter(
|
||||||
|
self._hs, json_decoder.decode(filter_json)
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
event_filter
|
event_filter
|
||||||
and event_filter.filter_json.get("event_format", "client")
|
and event_filter.filter_json.get("event_format", "client")
|
||||||
|
@ -672,6 +675,7 @@ class RoomEventContextServlet(RestServlet):
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self._hs = hs
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.room_context_handler = hs.get_room_context_handler()
|
self.room_context_handler = hs.get_room_context_handler()
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
|
@ -688,7 +692,9 @@ class RoomEventContextServlet(RestServlet):
|
||||||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
filter_str = parse_string(request, "filter", encoding="utf-8")
|
||||||
if filter_str:
|
if filter_str:
|
||||||
filter_json = urlparse.unquote(filter_str)
|
filter_json = urlparse.unquote(filter_str)
|
||||||
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
|
event_filter: Optional[Filter] = Filter(
|
||||||
|
self._hs, json_decoder.decode(filter_json)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
event_filter = None
|
event_filter = None
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ from typing import (
|
||||||
|
|
||||||
from synapse.api.constants import Membership, PresenceState
|
from synapse.api.constants import Membership, PresenceState
|
||||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||||
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
|
from synapse.api.filtering import FilterCollection
|
||||||
from synapse.api.presence import UserPresenceState
|
from synapse.api.presence import UserPresenceState
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.utils import (
|
from synapse.events.utils import (
|
||||||
|
@ -150,7 +150,7 @@ class SyncRestServlet(RestServlet):
|
||||||
request_key = (user, timeout, since, filter_id, full_state, device_id)
|
request_key = (user, timeout, since, filter_id, full_state, device_id)
|
||||||
|
|
||||||
if filter_id is None:
|
if filter_id is None:
|
||||||
filter_collection = DEFAULT_FILTER_COLLECTION
|
filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION
|
||||||
elif filter_id.startswith("{"):
|
elif filter_id.startswith("{"):
|
||||||
try:
|
try:
|
||||||
filter_object = json_decoder.decode(filter_id)
|
filter_object = json_decoder.decode(filter_id)
|
||||||
|
@ -160,7 +160,7 @@ class SyncRestServlet(RestServlet):
|
||||||
except Exception:
|
except Exception:
|
||||||
raise SynapseError(400, "Invalid filter JSON")
|
raise SynapseError(400, "Invalid filter JSON")
|
||||||
self.filtering.check_valid_filter(filter_object)
|
self.filtering.check_valid_filter(filter_object)
|
||||||
filter_collection = FilterCollection(filter_object)
|
filter_collection = FilterCollection(self.hs, filter_object)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
filter_collection = await self.filtering.get_user_filter(
|
filter_collection = await self.filtering.get_user_filter(
|
||||||
|
|
|
@ -20,7 +20,7 @@ import attr
|
||||||
from synapse.api.constants import RelationTypes
|
from synapse.api.constants import RelationTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import LoggingTransaction
|
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
|
||||||
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
||||||
from synapse.storage.relations import (
|
from synapse.storage.relations import (
|
||||||
AggregationPaginationToken,
|
AggregationPaginationToken,
|
||||||
|
@ -334,6 +334,62 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return count, latest_event
|
return count, latest_event
|
||||||
|
|
||||||
|
async def events_have_relations(
|
||||||
|
self,
|
||||||
|
parent_ids: List[str],
|
||||||
|
relation_senders: Optional[List[str]],
|
||||||
|
relation_types: Optional[List[str]],
|
||||||
|
) -> List[str]:
|
||||||
|
"""Check which events have a relationship from the given senders of the
|
||||||
|
given types.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parent_ids: The events being annotated
|
||||||
|
relation_senders: The relation senders to check.
|
||||||
|
relation_types: The relation types to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the event has at least one relationship from one of the given senders of the given type.
|
||||||
|
"""
|
||||||
|
# If no restrictions are given then the event has the required relations.
|
||||||
|
if not relation_senders and not relation_types:
|
||||||
|
return parent_ids
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT relates_to_id FROM event_relations
|
||||||
|
INNER JOIN events USING (event_id)
|
||||||
|
WHERE
|
||||||
|
%s;
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_if_event_has_relations(txn) -> List[str]:
|
||||||
|
clauses: List[str] = []
|
||||||
|
clause, args = make_in_list_sql_clause(
|
||||||
|
txn.database_engine, "relates_to_id", parent_ids
|
||||||
|
)
|
||||||
|
clauses.append(clause)
|
||||||
|
|
||||||
|
if relation_senders:
|
||||||
|
clause, temp_args = make_in_list_sql_clause(
|
||||||
|
txn.database_engine, "sender", relation_senders
|
||||||
|
)
|
||||||
|
clauses.append(clause)
|
||||||
|
args.extend(temp_args)
|
||||||
|
if relation_types:
|
||||||
|
clause, temp_args = make_in_list_sql_clause(
|
||||||
|
txn.database_engine, "relation_type", relation_types
|
||||||
|
)
|
||||||
|
clauses.append(clause)
|
||||||
|
args.extend(temp_args)
|
||||||
|
|
||||||
|
txn.execute(sql % " AND ".join(clauses), args)
|
||||||
|
|
||||||
|
return [row[0] for row in txn]
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_if_event_has_relations", _get_if_event_has_relations
|
||||||
|
)
|
||||||
|
|
||||||
async def has_user_annotated_event(
|
async def has_user_annotated_event(
|
||||||
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
|
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
|
|
@ -272,31 +272,37 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
|
||||||
args = []
|
args = []
|
||||||
|
|
||||||
if event_filter.types:
|
if event_filter.types:
|
||||||
clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types))
|
clauses.append(
|
||||||
|
"(%s)" % " OR ".join("event.type = ?" for _ in event_filter.types)
|
||||||
|
)
|
||||||
args.extend(event_filter.types)
|
args.extend(event_filter.types)
|
||||||
|
|
||||||
for typ in event_filter.not_types:
|
for typ in event_filter.not_types:
|
||||||
clauses.append("type != ?")
|
clauses.append("event.type != ?")
|
||||||
args.append(typ)
|
args.append(typ)
|
||||||
|
|
||||||
if event_filter.senders:
|
if event_filter.senders:
|
||||||
clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders))
|
clauses.append(
|
||||||
|
"(%s)" % " OR ".join("event.sender = ?" for _ in event_filter.senders)
|
||||||
|
)
|
||||||
args.extend(event_filter.senders)
|
args.extend(event_filter.senders)
|
||||||
|
|
||||||
for sender in event_filter.not_senders:
|
for sender in event_filter.not_senders:
|
||||||
clauses.append("sender != ?")
|
clauses.append("event.sender != ?")
|
||||||
args.append(sender)
|
args.append(sender)
|
||||||
|
|
||||||
if event_filter.rooms:
|
if event_filter.rooms:
|
||||||
clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms))
|
clauses.append(
|
||||||
|
"(%s)" % " OR ".join("event.room_id = ?" for _ in event_filter.rooms)
|
||||||
|
)
|
||||||
args.extend(event_filter.rooms)
|
args.extend(event_filter.rooms)
|
||||||
|
|
||||||
for room_id in event_filter.not_rooms:
|
for room_id in event_filter.not_rooms:
|
||||||
clauses.append("room_id != ?")
|
clauses.append("event.room_id != ?")
|
||||||
args.append(room_id)
|
args.append(room_id)
|
||||||
|
|
||||||
if event_filter.contains_url:
|
if event_filter.contains_url:
|
||||||
clauses.append("contains_url = ?")
|
clauses.append("event.contains_url = ?")
|
||||||
args.append(event_filter.contains_url)
|
args.append(event_filter.contains_url)
|
||||||
|
|
||||||
# We're only applying the "labels" filter on the database query, because applying the
|
# We're only applying the "labels" filter on the database query, because applying the
|
||||||
|
@ -307,6 +313,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
|
||||||
clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
|
clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
|
||||||
args.extend(event_filter.labels)
|
args.extend(event_filter.labels)
|
||||||
|
|
||||||
|
# Filter on relation_senders / relation types from the joined tables.
|
||||||
|
if event_filter.relation_senders:
|
||||||
|
clauses.append(
|
||||||
|
"(%s)"
|
||||||
|
% " OR ".join(
|
||||||
|
"related_event.sender = ?" for _ in event_filter.relation_senders
|
||||||
|
)
|
||||||
|
)
|
||||||
|
args.extend(event_filter.relation_senders)
|
||||||
|
|
||||||
|
if event_filter.relation_types:
|
||||||
|
clauses.append(
|
||||||
|
"(%s)"
|
||||||
|
% " OR ".join("relation_type = ?" for _ in event_filter.relation_types)
|
||||||
|
)
|
||||||
|
args.extend(event_filter.relation_types)
|
||||||
|
|
||||||
return " AND ".join(clauses), args
|
return " AND ".join(clauses), args
|
||||||
|
|
||||||
|
|
||||||
|
@ -1116,7 +1139,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
bounds = generate_pagination_where_clause(
|
bounds = generate_pagination_where_clause(
|
||||||
direction=direction,
|
direction=direction,
|
||||||
column_names=("topological_ordering", "stream_ordering"),
|
column_names=("event.topological_ordering", "event.stream_ordering"),
|
||||||
from_token=from_bound,
|
from_token=from_bound,
|
||||||
to_token=to_bound,
|
to_token=to_bound,
|
||||||
engine=self.database_engine,
|
engine=self.database_engine,
|
||||||
|
@ -1133,32 +1156,51 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
select_keywords = "SELECT"
|
select_keywords = "SELECT"
|
||||||
join_clause = ""
|
join_clause = ""
|
||||||
|
# Using DISTINCT in this SELECT query is quite expensive, because it
|
||||||
|
# requires the engine to sort on the entire (not limited) result set,
|
||||||
|
# i.e. the entire events table. Only use it in scenarios that could result
|
||||||
|
# in the same event ID occurring multiple times in the results.
|
||||||
|
needs_distinct = False
|
||||||
if event_filter and event_filter.labels:
|
if event_filter and event_filter.labels:
|
||||||
# If we're not filtering on a label, then joining on event_labels will
|
# If we're not filtering on a label, then joining on event_labels will
|
||||||
# return as many row for a single event as the number of labels it has. To
|
# return as many row for a single event as the number of labels it has. To
|
||||||
# avoid this, only join if we're filtering on at least one label.
|
# avoid this, only join if we're filtering on at least one label.
|
||||||
join_clause = """
|
join_clause += """
|
||||||
LEFT JOIN event_labels
|
LEFT JOIN event_labels
|
||||||
USING (event_id, room_id, topological_ordering)
|
USING (event_id, room_id, topological_ordering)
|
||||||
"""
|
"""
|
||||||
if len(event_filter.labels) > 1:
|
if len(event_filter.labels) > 1:
|
||||||
# Using DISTINCT in this SELECT query is quite expensive, because it
|
# Multiple labels could cause the same event to appear multiple times.
|
||||||
# requires the engine to sort on the entire (not limited) result set,
|
needs_distinct = True
|
||||||
# i.e. the entire events table. We only need to use it when we're
|
|
||||||
# filtering on more than two labels, because that's the only scenario
|
# If there is a filter on relation_senders and relation_types join to the
|
||||||
# in which we can possibly to get multiple times the same event ID in
|
# relations table.
|
||||||
# the results.
|
if event_filter and (
|
||||||
select_keywords += "DISTINCT"
|
event_filter.relation_senders or event_filter.relation_types
|
||||||
|
):
|
||||||
|
# Filtering by relations could cause the same event to appear multiple
|
||||||
|
# times (since there's no limit on the number of relations to an event).
|
||||||
|
needs_distinct = True
|
||||||
|
join_clause += """
|
||||||
|
LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id)
|
||||||
|
"""
|
||||||
|
if event_filter.relation_senders:
|
||||||
|
join_clause += """
|
||||||
|
LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if needs_distinct:
|
||||||
|
select_keywords += " DISTINCT"
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
%(select_keywords)s
|
%(select_keywords)s
|
||||||
event_id, instance_name,
|
event.event_id, event.instance_name,
|
||||||
topological_ordering, stream_ordering
|
event.topological_ordering, event.stream_ordering
|
||||||
FROM events
|
FROM events AS event
|
||||||
%(join_clause)s
|
%(join_clause)s
|
||||||
WHERE outlier = ? AND room_id = ? AND %(bounds)s
|
WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s
|
||||||
ORDER BY topological_ordering %(order)s,
|
ORDER BY event.topological_ordering %(order)s,
|
||||||
stream_ordering %(order)s LIMIT ?
|
event.stream_ordering %(order)s LIMIT ?
|
||||||
""" % {
|
""" % {
|
||||||
"select_keywords": select_keywords,
|
"select_keywords": select_keywords,
|
||||||
"join_clause": join_clause,
|
"join_clause": join_clause,
|
||||||
|
|
|
@ -15,6 +15,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import jsonschema
|
import jsonschema
|
||||||
|
|
||||||
from synapse.api.constants import EventContentFields
|
from synapse.api.constants import EventContentFields
|
||||||
|
@ -51,9 +53,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
{"presence": {"senders": ["@bar;pik.test.com"]}},
|
{"presence": {"senders": ["@bar;pik.test.com"]}},
|
||||||
]
|
]
|
||||||
for filter in invalid_filters:
|
for filter in invalid_filters:
|
||||||
with self.assertRaises(SynapseError) as check_filter_error:
|
with self.assertRaises(SynapseError):
|
||||||
self.filtering.check_valid_filter(filter)
|
self.filtering.check_valid_filter(filter)
|
||||||
self.assertIsInstance(check_filter_error.exception, SynapseError)
|
|
||||||
|
|
||||||
def test_valid_filters(self):
|
def test_valid_filters(self):
|
||||||
valid_filters = [
|
valid_filters = [
|
||||||
|
@ -119,12 +120,12 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
|
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
|
||||||
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
||||||
|
|
||||||
self.assertTrue(Filter(definition).check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_types_works_with_wildcards(self):
|
def test_definition_types_works_with_wildcards(self):
|
||||||
definition = {"types": ["m.*", "org.matrix.foo.bar"]}
|
definition = {"types": ["m.*", "org.matrix.foo.bar"]}
|
||||||
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
||||||
self.assertTrue(Filter(definition).check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_types_works_with_unknowns(self):
|
def test_definition_types_works_with_unknowns(self):
|
||||||
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
|
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
|
||||||
|
@ -133,24 +134,24 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
type="now.for.something.completely.different",
|
type="now.for.something.completely.different",
|
||||||
room_id="!foo:bar",
|
room_id="!foo:bar",
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_types_works_with_literals(self):
|
def test_definition_not_types_works_with_literals(self):
|
||||||
definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
|
definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
|
||||||
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_types_works_with_wildcards(self):
|
def test_definition_not_types_works_with_wildcards(self):
|
||||||
definition = {"not_types": ["m.room.message", "org.matrix.*"]}
|
definition = {"not_types": ["m.room.message", "org.matrix.*"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
|
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_types_works_with_unknowns(self):
|
def test_definition_not_types_works_with_unknowns(self):
|
||||||
definition = {"not_types": ["m.*", "org.*"]}
|
definition = {"not_types": ["m.*", "org.*"]}
|
||||||
event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
|
event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
|
||||||
self.assertTrue(Filter(definition).check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_types_takes_priority_over_types(self):
|
def test_definition_not_types_takes_priority_over_types(self):
|
||||||
definition = {
|
definition = {
|
||||||
|
@ -158,35 +159,35 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
"types": ["m.room.message", "m.room.topic"],
|
"types": ["m.room.message", "m.room.topic"],
|
||||||
}
|
}
|
||||||
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
|
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_senders_works_with_literals(self):
|
def test_definition_senders_works_with_literals(self):
|
||||||
definition = {"senders": ["@flibble:wibble"]}
|
definition = {"senders": ["@flibble:wibble"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
|
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||||
)
|
)
|
||||||
self.assertTrue(Filter(definition).check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_senders_works_with_unknowns(self):
|
def test_definition_senders_works_with_unknowns(self):
|
||||||
definition = {"senders": ["@flibble:wibble"]}
|
definition = {"senders": ["@flibble:wibble"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
|
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_senders_works_with_literals(self):
|
def test_definition_not_senders_works_with_literals(self):
|
||||||
definition = {"not_senders": ["@flibble:wibble"]}
|
definition = {"not_senders": ["@flibble:wibble"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
|
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_senders_works_with_unknowns(self):
|
def test_definition_not_senders_works_with_unknowns(self):
|
||||||
definition = {"not_senders": ["@flibble:wibble"]}
|
definition = {"not_senders": ["@flibble:wibble"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
|
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||||
)
|
)
|
||||||
self.assertTrue(Filter(definition).check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_senders_takes_priority_over_senders(self):
|
def test_definition_not_senders_takes_priority_over_senders(self):
|
||||||
definition = {
|
definition = {
|
||||||
|
@ -196,14 +197,14 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar"
|
sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar"
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_rooms_works_with_literals(self):
|
def test_definition_rooms_works_with_literals(self):
|
||||||
definition = {"rooms": ["!secretbase:unknown"]}
|
definition = {"rooms": ["!secretbase:unknown"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
|
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
|
||||||
)
|
)
|
||||||
self.assertTrue(Filter(definition).check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_rooms_works_with_unknowns(self):
|
def test_definition_rooms_works_with_unknowns(self):
|
||||||
definition = {"rooms": ["!secretbase:unknown"]}
|
definition = {"rooms": ["!secretbase:unknown"]}
|
||||||
|
@ -212,7 +213,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
type="m.room.message",
|
type="m.room.message",
|
||||||
room_id="!anothersecretbase:unknown",
|
room_id="!anothersecretbase:unknown",
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_rooms_works_with_literals(self):
|
def test_definition_not_rooms_works_with_literals(self):
|
||||||
definition = {"not_rooms": ["!anothersecretbase:unknown"]}
|
definition = {"not_rooms": ["!anothersecretbase:unknown"]}
|
||||||
|
@ -221,7 +222,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
type="m.room.message",
|
type="m.room.message",
|
||||||
room_id="!anothersecretbase:unknown",
|
room_id="!anothersecretbase:unknown",
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_rooms_works_with_unknowns(self):
|
def test_definition_not_rooms_works_with_unknowns(self):
|
||||||
definition = {"not_rooms": ["!secretbase:unknown"]}
|
definition = {"not_rooms": ["!secretbase:unknown"]}
|
||||||
|
@ -230,7 +231,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
type="m.room.message",
|
type="m.room.message",
|
||||||
room_id="!anothersecretbase:unknown",
|
room_id="!anothersecretbase:unknown",
|
||||||
)
|
)
|
||||||
self.assertTrue(Filter(definition).check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_rooms_takes_priority_over_rooms(self):
|
def test_definition_not_rooms_takes_priority_over_rooms(self):
|
||||||
definition = {
|
definition = {
|
||||||
|
@ -240,7 +241,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
|
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_combined_event(self):
|
def test_definition_combined_event(self):
|
||||||
definition = {
|
definition = {
|
||||||
|
@ -256,7 +257,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
type="m.room.message", # yup
|
type="m.room.message", # yup
|
||||||
room_id="!stage:unknown", # yup
|
room_id="!stage:unknown", # yup
|
||||||
)
|
)
|
||||||
self.assertTrue(Filter(definition).check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_combined_event_bad_sender(self):
|
def test_definition_combined_event_bad_sender(self):
|
||||||
definition = {
|
definition = {
|
||||||
|
@ -272,7 +273,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
type="m.room.message", # yup
|
type="m.room.message", # yup
|
||||||
room_id="!stage:unknown", # yup
|
room_id="!stage:unknown", # yup
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_combined_event_bad_room(self):
|
def test_definition_combined_event_bad_room(self):
|
||||||
definition = {
|
definition = {
|
||||||
|
@ -288,7 +289,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
type="m.room.message", # yup
|
type="m.room.message", # yup
|
||||||
room_id="!piggyshouse:muppets", # nope
|
room_id="!piggyshouse:muppets", # nope
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_combined_event_bad_type(self):
|
def test_definition_combined_event_bad_type(self):
|
||||||
definition = {
|
definition = {
|
||||||
|
@ -304,7 +305,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
type="muppets.misspiggy.kisses", # nope
|
type="muppets.misspiggy.kisses", # nope
|
||||||
room_id="!stage:unknown", # yup
|
room_id="!stage:unknown", # yup
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_filter_labels(self):
|
def test_filter_labels(self):
|
||||||
definition = {"org.matrix.labels": ["#fun"]}
|
definition = {"org.matrix.labels": ["#fun"]}
|
||||||
|
@ -315,7 +316,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
content={EventContentFields.LABELS: ["#fun"]},
|
content={EventContentFields.LABELS: ["#fun"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(Filter(definition).check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
|
@ -324,7 +325,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
content={EventContentFields.LABELS: ["#notfun"]},
|
content={EventContentFields.LABELS: ["#notfun"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_filter_not_labels(self):
|
def test_filter_not_labels(self):
|
||||||
definition = {"org.matrix.not_labels": ["#fun"]}
|
definition = {"org.matrix.not_labels": ["#fun"]}
|
||||||
|
@ -335,7 +336,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
content={EventContentFields.LABELS: ["#fun"]},
|
content={EventContentFields.LABELS: ["#fun"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertFalse(Filter(definition).check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
|
@ -344,7 +345,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
content={EventContentFields.LABELS: ["#notfun"]},
|
content={EventContentFields.LABELS: ["#notfun"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(Filter(definition).check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_filter_presence_match(self):
|
def test_filter_presence_match(self):
|
||||||
user_filter_json = {"presence": {"types": ["m.*"]}}
|
user_filter_json = {"presence": {"types": ["m.*"]}}
|
||||||
|
@ -362,7 +363,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
results = user_filter.filter_presence(events=events)
|
results = self.get_success(user_filter.filter_presence(events=events))
|
||||||
self.assertEquals(events, results)
|
self.assertEquals(events, results)
|
||||||
|
|
||||||
def test_filter_presence_no_match(self):
|
def test_filter_presence_no_match(self):
|
||||||
|
@ -386,7 +387,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
results = user_filter.filter_presence(events=events)
|
results = self.get_success(user_filter.filter_presence(events=events))
|
||||||
self.assertEquals([], results)
|
self.assertEquals([], results)
|
||||||
|
|
||||||
def test_filter_room_state_match(self):
|
def test_filter_room_state_match(self):
|
||||||
|
@ -405,7 +406,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
results = user_filter.filter_room_state(events=events)
|
results = self.get_success(user_filter.filter_room_state(events=events))
|
||||||
self.assertEquals(events, results)
|
self.assertEquals(events, results)
|
||||||
|
|
||||||
def test_filter_room_state_no_match(self):
|
def test_filter_room_state_no_match(self):
|
||||||
|
@ -426,7 +427,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
results = user_filter.filter_room_state(events)
|
results = self.get_success(user_filter.filter_room_state(events))
|
||||||
self.assertEquals([], results)
|
self.assertEquals([], results)
|
||||||
|
|
||||||
def test_filter_rooms(self):
|
def test_filter_rooms(self):
|
||||||
|
@ -441,10 +442,52 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
"!not_included:example.com", # Disallowed because not in rooms.
|
"!not_included:example.com", # Disallowed because not in rooms.
|
||||||
]
|
]
|
||||||
|
|
||||||
filtered_room_ids = list(Filter(definition).filter_rooms(room_ids))
|
filtered_room_ids = list(Filter(self.hs, definition).filter_rooms(room_ids))
|
||||||
|
|
||||||
self.assertEquals(filtered_room_ids, ["!allowed:example.com"])
|
self.assertEquals(filtered_room_ids, ["!allowed:example.com"])
|
||||||
|
|
||||||
|
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
|
||||||
|
def test_filter_relations(self):
|
||||||
|
events = [
|
||||||
|
# An event without a relation.
|
||||||
|
MockEvent(
|
||||||
|
event_id="$no_relation",
|
||||||
|
sender="@foo:bar",
|
||||||
|
type="org.matrix.custom.event",
|
||||||
|
room_id="!foo:bar",
|
||||||
|
),
|
||||||
|
# An event with a relation.
|
||||||
|
MockEvent(
|
||||||
|
event_id="$with_relation",
|
||||||
|
sender="@foo:bar",
|
||||||
|
type="org.matrix.custom.event",
|
||||||
|
room_id="!foo:bar",
|
||||||
|
),
|
||||||
|
# Non-EventBase objects get passed through.
|
||||||
|
{},
|
||||||
|
]
|
||||||
|
|
||||||
|
# For the following tests we patch the datastore method (intead of injecting
|
||||||
|
# events). This is a bit cheeky, but tests the logic of _check_event_relations.
|
||||||
|
|
||||||
|
# Filter for a particular sender.
|
||||||
|
definition = {
|
||||||
|
"io.element.relation_senders": ["@foo:bar"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def events_have_relations(*args, **kwargs):
|
||||||
|
return ["$with_relation"]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
self.datastore, "events_have_relations", new=events_have_relations
|
||||||
|
):
|
||||||
|
filtered_events = list(
|
||||||
|
self.get_success(
|
||||||
|
Filter(self.hs, definition)._check_event_relations(events)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEquals(filtered_events, events[1:])
|
||||||
|
|
||||||
def test_add_filter(self):
|
def test_add_filter(self):
|
||||||
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||||
|
|
||||||
|
|
|
@ -13,10 +13,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, JoinRules
|
from synapse.api.constants import EventTypes, JoinRules
|
||||||
from synapse.api.errors import Codes, ResourceLimitError
|
from synapse.api.errors import Codes, ResourceLimitError
|
||||||
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
|
from synapse.api.filtering import Filtering
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
from synapse.handlers.sync import SyncConfig
|
from synapse.handlers.sync import SyncConfig
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
|
@ -197,7 +198,7 @@ def generate_sync_config(
|
||||||
_request_key += 1
|
_request_key += 1
|
||||||
return SyncConfig(
|
return SyncConfig(
|
||||||
user=UserID.from_string(user_id),
|
user=UserID.from_string(user_id),
|
||||||
filter_collection=DEFAULT_FILTER_COLLECTION,
|
filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION,
|
||||||
is_guest=False,
|
is_guest=False,
|
||||||
request_key=("request_key", _request_key),
|
request_key=("request_key", _request_key),
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
|
|
|
@ -25,7 +25,12 @@ from urllib import parse as urlparse
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
from synapse.api.constants import (
|
||||||
|
EventContentFields,
|
||||||
|
EventTypes,
|
||||||
|
Membership,
|
||||||
|
RelationTypes,
|
||||||
|
)
|
||||||
from synapse.api.errors import Codes, HttpResponseException
|
from synapse.api.errors import Codes, HttpResponseException
|
||||||
from synapse.handlers.pagination import PurgeStatus
|
from synapse.handlers.pagination import PurgeStatus
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
|
@ -2157,6 +2162,153 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
||||||
return event_id
|
return event_id
|
||||||
|
|
||||||
|
|
||||||
|
class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||||
|
room.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def default_config(self):
|
||||||
|
config = super().default_config()
|
||||||
|
config["experimental_features"] = {"msc3440_enabled": True}
|
||||||
|
return config
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, homeserver):
|
||||||
|
self.user_id = self.register_user("test", "test")
|
||||||
|
self.tok = self.login("test", "test")
|
||||||
|
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||||
|
|
||||||
|
self.second_user_id = self.register_user("second", "test")
|
||||||
|
self.second_tok = self.login("second", "test")
|
||||||
|
self.helper.join(
|
||||||
|
room=self.room_id, user=self.second_user_id, tok=self.second_tok
|
||||||
|
)
|
||||||
|
|
||||||
|
self.third_user_id = self.register_user("third", "test")
|
||||||
|
self.third_tok = self.login("third", "test")
|
||||||
|
self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
|
||||||
|
|
||||||
|
# An initial event with a relation from second user.
|
||||||
|
res = self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={"msgtype": "m.text", "body": "Message 1"},
|
||||||
|
tok=self.tok,
|
||||||
|
)
|
||||||
|
self.event_id_1 = res["event_id"]
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type="m.reaction",
|
||||||
|
content={
|
||||||
|
"m.relates_to": {
|
||||||
|
"rel_type": RelationTypes.ANNOTATION,
|
||||||
|
"event_id": self.event_id_1,
|
||||||
|
"key": "👍",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
tok=self.second_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Another event with a relation from third user.
|
||||||
|
res = self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={"msgtype": "m.text", "body": "Message 2"},
|
||||||
|
tok=self.tok,
|
||||||
|
)
|
||||||
|
self.event_id_2 = res["event_id"]
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type="m.reaction",
|
||||||
|
content={
|
||||||
|
"m.relates_to": {
|
||||||
|
"rel_type": RelationTypes.REFERENCE,
|
||||||
|
"event_id": self.event_id_2,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
tok=self.third_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
# An event with no relations.
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={"msgtype": "m.text", "body": "No relations"},
|
||||||
|
tok=self.tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _filter_messages(self, filter: JsonDict) -> List[JsonDict]:
|
||||||
|
"""Make a request to /messages with a filter, returns the chunk of events."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
|
||||||
|
access_token=self.tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
||||||
|
return channel.json_body["chunk"]
|
||||||
|
|
||||||
|
def test_filter_relation_senders(self):
|
||||||
|
# Messages which second user reacted to.
|
||||||
|
filter = {"io.element.relation_senders": [self.second_user_id]}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 1, chunk)
|
||||||
|
self.assertEqual(chunk[0]["event_id"], self.event_id_1)
|
||||||
|
|
||||||
|
# Messages which third user reacted to.
|
||||||
|
filter = {"io.element.relation_senders": [self.third_user_id]}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 1, chunk)
|
||||||
|
self.assertEqual(chunk[0]["event_id"], self.event_id_2)
|
||||||
|
|
||||||
|
# Messages which either user reacted to.
|
||||||
|
filter = {
|
||||||
|
"io.element.relation_senders": [self.second_user_id, self.third_user_id]
|
||||||
|
}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 2, chunk)
|
||||||
|
self.assertCountEqual(
|
||||||
|
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filter_relation_type(self):
|
||||||
|
# Messages which have annotations.
|
||||||
|
filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 1, chunk)
|
||||||
|
self.assertEqual(chunk[0]["event_id"], self.event_id_1)
|
||||||
|
|
||||||
|
# Messages which have references.
|
||||||
|
filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 1, chunk)
|
||||||
|
self.assertEqual(chunk[0]["event_id"], self.event_id_2)
|
||||||
|
|
||||||
|
# Messages which have either annotations or references.
|
||||||
|
filter = {
|
||||||
|
"io.element.relation_types": [
|
||||||
|
RelationTypes.ANNOTATION,
|
||||||
|
RelationTypes.REFERENCE,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 2, chunk)
|
||||||
|
self.assertCountEqual(
|
||||||
|
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filter_relation_senders_and_type(self):
|
||||||
|
# Messages which second user reacted to.
|
||||||
|
filter = {
|
||||||
|
"io.element.relation_senders": [self.second_user_id],
|
||||||
|
"io.element.relation_types": [RelationTypes.ANNOTATION],
|
||||||
|
}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 1, chunk)
|
||||||
|
self.assertEqual(chunk[0]["event_id"], self.event_id_1)
|
||||||
|
|
||||||
|
|
||||||
class ContextTestCase(unittest.HomeserverTestCase):
|
class ContextTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [
|
servlets = [
|
||||||
|
|
|
@ -0,0 +1,207 @@
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from synapse.api.constants import EventTypes, RelationTypes
|
||||||
|
from synapse.api.filtering import Filter
|
||||||
|
from synapse.events import EventBase
|
||||||
|
from synapse.rest import admin
|
||||||
|
from synapse.rest.client import login, room
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class PaginationTestCase(HomeserverTestCase):
|
||||||
|
"""
|
||||||
|
Test the pre-filtering done in the pagination code.
|
||||||
|
|
||||||
|
This is similar to some of the tests in tests.rest.client.test_rooms but here
|
||||||
|
we ensure that the filtering done in the database is applied successfully.
|
||||||
|
"""
|
||||||
|
|
||||||
|
servlets = [
|
||||||
|
admin.register_servlets_for_client_rest_resource,
|
||||||
|
room.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def default_config(self):
|
||||||
|
config = super().default_config()
|
||||||
|
config["experimental_features"] = {"msc3440_enabled": True}
|
||||||
|
return config
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, homeserver):
|
||||||
|
self.user_id = self.register_user("test", "test")
|
||||||
|
self.tok = self.login("test", "test")
|
||||||
|
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||||
|
|
||||||
|
self.second_user_id = self.register_user("second", "test")
|
||||||
|
self.second_tok = self.login("second", "test")
|
||||||
|
self.helper.join(
|
||||||
|
room=self.room_id, user=self.second_user_id, tok=self.second_tok
|
||||||
|
)
|
||||||
|
|
||||||
|
self.third_user_id = self.register_user("third", "test")
|
||||||
|
self.third_tok = self.login("third", "test")
|
||||||
|
self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
|
||||||
|
|
||||||
|
# An initial event with a relation from second user.
|
||||||
|
res = self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={"msgtype": "m.text", "body": "Message 1"},
|
||||||
|
tok=self.tok,
|
||||||
|
)
|
||||||
|
self.event_id_1 = res["event_id"]
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type="m.reaction",
|
||||||
|
content={
|
||||||
|
"m.relates_to": {
|
||||||
|
"rel_type": RelationTypes.ANNOTATION,
|
||||||
|
"event_id": self.event_id_1,
|
||||||
|
"key": "👍",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
tok=self.second_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Another event with a relation from third user.
|
||||||
|
res = self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={"msgtype": "m.text", "body": "Message 2"},
|
||||||
|
tok=self.tok,
|
||||||
|
)
|
||||||
|
self.event_id_2 = res["event_id"]
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type="m.reaction",
|
||||||
|
content={
|
||||||
|
"m.relates_to": {
|
||||||
|
"rel_type": RelationTypes.REFERENCE,
|
||||||
|
"event_id": self.event_id_2,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
tok=self.third_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
# An event with no relations.
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
content={"msgtype": "m.text", "body": "No relations"},
|
||||||
|
tok=self.tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _filter_messages(self, filter: JsonDict) -> List[EventBase]:
|
||||||
|
"""Make a request to /messages with a filter, returns the chunk of events."""
|
||||||
|
|
||||||
|
from_token = self.get_success(
|
||||||
|
self.hs.get_event_sources().get_current_token_for_pagination()
|
||||||
|
)
|
||||||
|
|
||||||
|
events, next_key = self.get_success(
|
||||||
|
self.hs.get_datastore().paginate_room_events(
|
||||||
|
room_id=self.room_id,
|
||||||
|
from_key=from_token.room_key,
|
||||||
|
to_key=None,
|
||||||
|
direction="b",
|
||||||
|
limit=10,
|
||||||
|
event_filter=Filter(self.hs, filter),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return events
|
||||||
|
|
||||||
|
def test_filter_relation_senders(self):
|
||||||
|
# Messages which second user reacted to.
|
||||||
|
filter = {"io.element.relation_senders": [self.second_user_id]}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 1, chunk)
|
||||||
|
self.assertEqual(chunk[0].event_id, self.event_id_1)
|
||||||
|
|
||||||
|
# Messages which third user reacted to.
|
||||||
|
filter = {"io.element.relation_senders": [self.third_user_id]}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 1, chunk)
|
||||||
|
self.assertEqual(chunk[0].event_id, self.event_id_2)
|
||||||
|
|
||||||
|
# Messages which either user reacted to.
|
||||||
|
filter = {
|
||||||
|
"io.element.relation_senders": [self.second_user_id, self.third_user_id]
|
||||||
|
}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 2, chunk)
|
||||||
|
self.assertCountEqual(
|
||||||
|
[c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filter_relation_type(self):
|
||||||
|
# Messages which have annotations.
|
||||||
|
filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 1, chunk)
|
||||||
|
self.assertEqual(chunk[0].event_id, self.event_id_1)
|
||||||
|
|
||||||
|
# Messages which have references.
|
||||||
|
filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 1, chunk)
|
||||||
|
self.assertEqual(chunk[0].event_id, self.event_id_2)
|
||||||
|
|
||||||
|
# Messages which have either annotations or references.
|
||||||
|
filter = {
|
||||||
|
"io.element.relation_types": [
|
||||||
|
RelationTypes.ANNOTATION,
|
||||||
|
RelationTypes.REFERENCE,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 2, chunk)
|
||||||
|
self.assertCountEqual(
|
||||||
|
[c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_filter_relation_senders_and_type(self):
|
||||||
|
# Messages which second user reacted to.
|
||||||
|
filter = {
|
||||||
|
"io.element.relation_senders": [self.second_user_id],
|
||||||
|
"io.element.relation_types": [RelationTypes.ANNOTATION],
|
||||||
|
}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 1, chunk)
|
||||||
|
self.assertEqual(chunk[0].event_id, self.event_id_1)
|
||||||
|
|
||||||
|
def test_duplicate_relation(self):
|
||||||
|
"""An event should only be returned once if there are multiple relations to it."""
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type="m.reaction",
|
||||||
|
content={
|
||||||
|
"m.relates_to": {
|
||||||
|
"rel_type": RelationTypes.ANNOTATION,
|
||||||
|
"event_id": self.event_id_1,
|
||||||
|
"key": "A",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
tok=self.second_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
filter = {"io.element.relation_senders": [self.second_user_id]}
|
||||||
|
chunk = self._filter_messages(filter)
|
||||||
|
self.assertEqual(len(chunk), 1, chunk)
|
||||||
|
self.assertEqual(chunk[0].event_id, self.event_id_1)
|
Loading…
Reference in New Issue