Recursively fetch the thread for receipts & notifications. (#13824)
Consider an event to be part of a thread if you can follow a chain of relations up to a thread root. Part of MSC3773 & MSC3771.
This commit is contained in:
parent
3e74ad20db
commit
2b6d41ebd6
|
@ -0,0 +1 @@
|
|||
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
|
|
@ -286,8 +286,13 @@ class BulkPushRuleEvaluator:
|
|||
relation.parent_id,
|
||||
itertools.chain(*(r.rules() for r in rules_by_user.values())),
|
||||
)
|
||||
# Recursively attempt to find the thread this event relates to.
|
||||
if relation.rel_type == RelationTypes.THREAD:
|
||||
thread_id = relation.parent_id
|
||||
else:
|
||||
# Since the event has not yet been persisted we check whether
|
||||
# the parent is part of a thread.
|
||||
thread_id = await self.store.get_thread_id(relation.parent_id) or "main"
|
||||
|
||||
evaluator = PushRuleEvaluator(
|
||||
_flatten_dict(event),
|
||||
|
|
|
@ -16,7 +16,7 @@ import logging
|
|||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import ReceiptTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
|
@ -43,6 +43,7 @@ class ReceiptRestServlet(RestServlet):
|
|||
self.receipts_handler = hs.get_receipts_handler()
|
||||
self.read_marker_handler = hs.get_read_marker_handler()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self._main_store = hs.get_datastores().main
|
||||
|
||||
self._known_receipt_types = {
|
||||
ReceiptTypes.READ,
|
||||
|
@ -71,7 +72,24 @@ class ReceiptRestServlet(RestServlet):
|
|||
thread_id = body.get("thread_id")
|
||||
if not thread_id or not isinstance(thread_id, str):
|
||||
raise SynapseError(
|
||||
400, "thread_id field must be a non-empty string"
|
||||
400,
|
||||
"thread_id field must be a non-empty string",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
if receipt_type == ReceiptTypes.FULLY_READ:
|
||||
raise SynapseError(
|
||||
400,
|
||||
f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
# Ensure the event ID roughly correlates to the thread ID.
|
||||
if thread_id != await self._main_store.get_thread_id(event_id):
|
||||
raise SynapseError(
|
||||
400,
|
||||
f"event_id {event_id} is not related to thread {thread_id}",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
await self.presence_handler.bump_presence_active_time(requester.user)
|
||||
|
|
|
@ -832,6 +832,42 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
"get_event_relations", _get_event_relations
|
||||
)
|
||||
|
||||
@cached()
|
||||
async def get_thread_id(self, event_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get the thread ID for an event. This considers multi-level relations,
|
||||
e.g. an annotation to an event which is part of a thread.
|
||||
|
||||
Args:
|
||||
event_id: The event ID to fetch the thread ID for.
|
||||
|
||||
Returns:
|
||||
The event ID of the root event in the thread, if this event is part
|
||||
of a thread. None, otherwise.
|
||||
"""
|
||||
# Since event relations form a tree, we should only ever find 0 or 1
|
||||
# results from the below query.
|
||||
sql = """
|
||||
WITH RECURSIVE related_events AS (
|
||||
SELECT event_id, relates_to_id, relation_type
|
||||
FROM event_relations
|
||||
WHERE event_id = ?
|
||||
UNION SELECT e.event_id, e.relates_to_id, e.relation_type
|
||||
FROM event_relations e
|
||||
INNER JOIN related_events r ON r.relates_to_id = e.event_id
|
||||
) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
|
||||
"""
|
||||
|
||||
def _get_thread_id(txn: LoggingTransaction) -> Optional[str]:
|
||||
txn.execute(sql, (event_id,))
|
||||
# TODO Should we ensure there's only a single result here?
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
return row[0]
|
||||
return None
|
||||
|
||||
return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
|
||||
|
||||
|
||||
class RelationsStore(RelationsWorkerStore):
|
||||
pass
|
||||
|
|
|
@ -588,6 +588,106 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
|
|||
_rotate()
|
||||
_assert_counts(0, 0, 0, 0)
|
||||
|
||||
def test_recursive_thread(self) -> None:
|
||||
"""
|
||||
Events related to events in a thread should still be considered part of
|
||||
that thread.
|
||||
"""
|
||||
|
||||
# Create a user to receive notifications and send receipts.
|
||||
user_id = self.register_user("user1235", "pass")
|
||||
token = self.login("user1235", "pass")
|
||||
|
||||
# And another users to send events.
|
||||
other_id = self.register_user("other", "pass")
|
||||
other_token = self.login("other", "pass")
|
||||
|
||||
# Create a room and put both users in it.
|
||||
room_id = self.helper.create_room_as(user_id, tok=token)
|
||||
self.helper.join(room_id, other_id, tok=other_token)
|
||||
|
||||
# Update the user's push rules to care about reaction events.
|
||||
self.get_success(
|
||||
self.store.add_push_rule(
|
||||
user_id,
|
||||
"related_events",
|
||||
priority_class=5,
|
||||
conditions=[
|
||||
{"kind": "event_match", "key": "type", "pattern": "m.reaction"}
|
||||
],
|
||||
actions=["notify"],
|
||||
)
|
||||
)
|
||||
|
||||
def _create_event(type: str, content: JsonDict) -> str:
|
||||
result = self.helper.send_event(
|
||||
room_id, type=type, content=content, tok=other_token
|
||||
)
|
||||
return result["event_id"]
|
||||
|
||||
def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
|
||||
counts = self.get_success(
|
||||
self.store.db_pool.runInteraction(
|
||||
"get-unread-counts",
|
||||
self.store._get_unread_counts_by_receipt_txn,
|
||||
room_id,
|
||||
user_id,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
counts.main_timeline,
|
||||
NotifCounts(
|
||||
notify_count=noitf_count, unread_count=0, highlight_count=0
|
||||
),
|
||||
)
|
||||
if thread_notif_count:
|
||||
self.assertEqual(
|
||||
counts.threads,
|
||||
{
|
||||
thread_id: NotifCounts(
|
||||
notify_count=thread_notif_count,
|
||||
unread_count=0,
|
||||
highlight_count=0,
|
||||
),
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.assertEqual(counts.threads, {})
|
||||
|
||||
# Create a root event.
|
||||
thread_id = _create_event(
|
||||
"m.room.message", {"msgtype": "m.text", "body": "msg"}
|
||||
)
|
||||
_assert_counts(1, 0)
|
||||
|
||||
# Reply, creating a thread.
|
||||
reply_id = _create_event(
|
||||
"m.room.message",
|
||||
{
|
||||
"msgtype": "m.text",
|
||||
"body": "msg",
|
||||
"m.relates_to": {
|
||||
"rel_type": "m.thread",
|
||||
"event_id": thread_id,
|
||||
},
|
||||
},
|
||||
)
|
||||
_assert_counts(1, 1)
|
||||
|
||||
# Create an event related to a thread event, this should still appear in
|
||||
# the thread.
|
||||
_create_event(
|
||||
type="m.reaction",
|
||||
content={
|
||||
"m.relates_to": {
|
||||
"rel_type": "m.annotation",
|
||||
"event_id": reply_id,
|
||||
"key": "A",
|
||||
}
|
||||
},
|
||||
)
|
||||
_assert_counts(1, 2)
|
||||
|
||||
def test_find_first_stream_ordering_after_ts(self) -> None:
|
||||
def add_event(so: int, ts: int) -> None:
|
||||
self.get_success(
|
||||
|
|
Loading…
Reference in New Issue