Implement an `on_new_event` callback (#11126)
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
This commit is contained in:
parent
7004f43da1
commit
c7a5e49664
|
@ -0,0 +1 @@
|
||||||
|
Add an `on_new_event` third-party rules callback to allow Synapse modules to act after an event has been sent into a room.
|
|
@ -119,6 +119,27 @@ callback returns `True`, Synapse falls through to the next one. The value of the
|
||||||
callback that does not return `True` will be used. If this happens, Synapse will not call
|
callback that does not return `True` will be used. If this happens, Synapse will not call
|
||||||
any of the subsequent implementations of this callback.
|
any of the subsequent implementations of this callback.
|
||||||
|
|
||||||
|
### `on_new_event`
|
||||||
|
|
||||||
|
_First introduced in Synapse v1.47.0_
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def on_new_event(
|
||||||
|
event: "synapse.events.EventBase",
|
||||||
|
state_events: "synapse.types.StateMap",
|
||||||
|
) -> None:
|
||||||
|
```
|
||||||
|
|
||||||
|
Called after sending an event into a room. The module is passed the event, as well
|
||||||
|
as the state of the room _after_ the event. This means that if the event is a state event,
|
||||||
|
it will be included in this state.
|
||||||
|
|
||||||
|
Note that this callback is called when the event has already been processed and stored
|
||||||
|
into the room, which means this callback cannot be used to deny persisting the event. To
|
||||||
|
deny an incoming event, see [`check_event_for_spam`](spam_checker_callbacks.md#check_event_for_spam) instead.
|
||||||
|
|
||||||
|
If multiple modules implement this callback, Synapse runs them all in order.
|
||||||
|
|
||||||
## Example
|
## Example
|
||||||
|
|
||||||
The example below is a module that implements the third-party rules callback
|
The example below is a module that implements the third-party rules callback
|
||||||
|
|
|
@ -36,6 +36,7 @@ CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
|
||||||
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
|
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
|
||||||
[str, StateMap[EventBase], str], Awaitable[bool]
|
[str, StateMap[EventBase], str], Awaitable[bool]
|
||||||
]
|
]
|
||||||
|
ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable]
|
||||||
|
|
||||||
|
|
||||||
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
|
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
|
||||||
|
@ -152,6 +153,7 @@ class ThirdPartyEventRules:
|
||||||
self._check_visibility_can_be_modified_callbacks: List[
|
self._check_visibility_can_be_modified_callbacks: List[
|
||||||
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
|
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
|
||||||
] = []
|
] = []
|
||||||
|
self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = []
|
||||||
|
|
||||||
def register_third_party_rules_callbacks(
|
def register_third_party_rules_callbacks(
|
||||||
self,
|
self,
|
||||||
|
@ -163,6 +165,7 @@ class ThirdPartyEventRules:
|
||||||
check_visibility_can_be_modified: Optional[
|
check_visibility_can_be_modified: Optional[
|
||||||
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
|
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
|
||||||
] = None,
|
] = None,
|
||||||
|
on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register callbacks from modules for each hook."""
|
"""Register callbacks from modules for each hook."""
|
||||||
if check_event_allowed is not None:
|
if check_event_allowed is not None:
|
||||||
|
@ -181,6 +184,9 @@ class ThirdPartyEventRules:
|
||||||
check_visibility_can_be_modified,
|
check_visibility_can_be_modified,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if on_new_event is not None:
|
||||||
|
self._on_new_event_callbacks.append(on_new_event)
|
||||||
|
|
||||||
async def check_event_allowed(
|
async def check_event_allowed(
|
||||||
self, event: EventBase, context: EventContext
|
self, event: EventBase, context: EventContext
|
||||||
) -> Tuple[bool, Optional[dict]]:
|
) -> Tuple[bool, Optional[dict]]:
|
||||||
|
@ -321,6 +327,31 @@ class ThirdPartyEventRules:
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def on_new_event(self, event_id: str) -> None:
|
||||||
|
"""Let modules act on events after they've been sent (e.g. auto-accepting
|
||||||
|
invites, etc.)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id: The ID of the event.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ModuleFailureError if a callback raised any exception.
|
||||||
|
"""
|
||||||
|
# Bail out early without hitting the store if we don't have any callbacks
|
||||||
|
if len(self._on_new_event_callbacks) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
event = await self.store.get_event(event_id)
|
||||||
|
state_events = await self._get_state_map_for_room(event.room_id)
|
||||||
|
|
||||||
|
for callback in self._on_new_event_callbacks:
|
||||||
|
try:
|
||||||
|
await callback(event, state_events)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to run module API callback %s: %s", callback, e
|
||||||
|
)
|
||||||
|
|
||||||
async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
|
async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
|
||||||
"""Given a room ID, return the state events of that room.
|
"""Given a room ID, return the state events of that room.
|
||||||
|
|
||||||
|
|
|
@ -1916,7 +1916,7 @@ class FederationEventHandler:
|
||||||
event_pos = PersistedEventPosition(
|
event_pos = PersistedEventPosition(
|
||||||
self._instance_name, event.internal_metadata.stream_ordering
|
self._instance_name, event.internal_metadata.stream_ordering
|
||||||
)
|
)
|
||||||
self._notifier.on_new_room_event(
|
await self._notifier.on_new_room_event(
|
||||||
event, event_pos, max_stream_token, extra_users=extra_users
|
event, event_pos, max_stream_token, extra_users=extra_users
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1537,13 +1537,16 @@ class EventCreationHandler:
|
||||||
# If there's an expiry timestamp on the event, schedule its expiry.
|
# If there's an expiry timestamp on the event, schedule its expiry.
|
||||||
self._message_handler.maybe_schedule_expiry(event)
|
self._message_handler.maybe_schedule_expiry(event)
|
||||||
|
|
||||||
def _notify() -> None:
|
async def _notify() -> None:
|
||||||
try:
|
try:
|
||||||
self.notifier.on_new_room_event(
|
await self.notifier.on_new_room_event(
|
||||||
event, event_pos, max_stream_token, extra_users=extra_users
|
event, event_pos, max_stream_token, extra_users=extra_users
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error notifying about new room event")
|
logger.exception(
|
||||||
|
"Error notifying about new room event %s",
|
||||||
|
event.event_id,
|
||||||
|
)
|
||||||
|
|
||||||
run_in_background(_notify)
|
run_in_background(_notify)
|
||||||
|
|
||||||
|
|
|
@ -220,6 +220,8 @@ class Notifier:
|
||||||
# down.
|
# down.
|
||||||
self.remote_server_up_callbacks: List[Callable[[str], None]] = []
|
self.remote_server_up_callbacks: List[Callable[[str], None]] = []
|
||||||
|
|
||||||
|
self._third_party_rules = hs.get_third_party_event_rules()
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.appservice_handler = hs.get_application_service_handler()
|
self.appservice_handler = hs.get_application_service_handler()
|
||||||
self._pusher_pool = hs.get_pusherpool()
|
self._pusher_pool = hs.get_pusherpool()
|
||||||
|
@ -267,7 +269,7 @@ class Notifier:
|
||||||
"""
|
"""
|
||||||
self.replication_callbacks.append(cb)
|
self.replication_callbacks.append(cb)
|
||||||
|
|
||||||
def on_new_room_event(
|
async def on_new_room_event(
|
||||||
self,
|
self,
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
event_pos: PersistedEventPosition,
|
event_pos: PersistedEventPosition,
|
||||||
|
@ -275,9 +277,10 @@ class Notifier:
|
||||||
extra_users: Optional[Collection[UserID]] = None,
|
extra_users: Optional[Collection[UserID]] = None,
|
||||||
):
|
):
|
||||||
"""Unwraps event and calls `on_new_room_event_args`."""
|
"""Unwraps event and calls `on_new_room_event_args`."""
|
||||||
self.on_new_room_event_args(
|
await self.on_new_room_event_args(
|
||||||
event_pos=event_pos,
|
event_pos=event_pos,
|
||||||
room_id=event.room_id,
|
room_id=event.room_id,
|
||||||
|
event_id=event.event_id,
|
||||||
event_type=event.type,
|
event_type=event.type,
|
||||||
state_key=event.get("state_key"),
|
state_key=event.get("state_key"),
|
||||||
membership=event.content.get("membership"),
|
membership=event.content.get("membership"),
|
||||||
|
@ -285,9 +288,10 @@ class Notifier:
|
||||||
extra_users=extra_users or [],
|
extra_users=extra_users or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_new_room_event_args(
|
async def on_new_room_event_args(
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
|
event_id: str,
|
||||||
event_type: str,
|
event_type: str,
|
||||||
state_key: Optional[str],
|
state_key: Optional[str],
|
||||||
membership: Optional[str],
|
membership: Optional[str],
|
||||||
|
@ -302,7 +306,10 @@ class Notifier:
|
||||||
listening to the room, and any listeners for the users in the
|
listening to the room, and any listeners for the users in the
|
||||||
`extra_users` param.
|
`extra_users` param.
|
||||||
|
|
||||||
The events can be peristed out of order. The notifier will wait
|
This also notifies modules listening on new events via the
|
||||||
|
`on_new_event` callback.
|
||||||
|
|
||||||
|
The events can be persisted out of order. The notifier will wait
|
||||||
until all previous events have been persisted before notifying
|
until all previous events have been persisted before notifying
|
||||||
the client streams.
|
the client streams.
|
||||||
"""
|
"""
|
||||||
|
@ -318,6 +325,8 @@ class Notifier:
|
||||||
)
|
)
|
||||||
self._notify_pending_new_room_events(max_room_stream_token)
|
self._notify_pending_new_room_events(max_room_stream_token)
|
||||||
|
|
||||||
|
await self._third_party_rules.on_new_event(event_id)
|
||||||
|
|
||||||
self.notify_replication()
|
self.notify_replication()
|
||||||
|
|
||||||
def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
|
def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
|
||||||
|
|
|
@ -207,11 +207,12 @@ class ReplicationDataHandler:
|
||||||
|
|
||||||
max_token = self.store.get_room_max_token()
|
max_token = self.store.get_room_max_token()
|
||||||
event_pos = PersistedEventPosition(instance_name, token)
|
event_pos = PersistedEventPosition(instance_name, token)
|
||||||
self.notifier.on_new_room_event_args(
|
await self.notifier.on_new_room_event_args(
|
||||||
event_pos=event_pos,
|
event_pos=event_pos,
|
||||||
max_room_stream_token=max_token,
|
max_room_stream_token=max_token,
|
||||||
extra_users=extra_users,
|
extra_users=extra_users,
|
||||||
room_id=row.data.room_id,
|
room_id=row.data.room_id,
|
||||||
|
event_id=row.data.event_id,
|
||||||
event_type=row.data.type,
|
event_type=row.data.type,
|
||||||
state_key=row.data.state_key,
|
state_key=row.data.state_key,
|
||||||
membership=row.data.membership,
|
membership=row.data.membership,
|
||||||
|
|
|
@ -15,7 +15,7 @@ import threading
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
|
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
|
||||||
|
@ -25,6 +25,7 @@ from synapse.types import JsonDict, Requester, StateMap
|
||||||
from synapse.util.frozenutils import unfreeze
|
from synapse.util.frozenutils import unfreeze
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.test_utils import make_awaitable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
|
@ -74,7 +75,7 @@ class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
|
class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
admin.register_servlets,
|
admin.register_servlets,
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
|
@ -86,11 +87,29 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
load_legacy_third_party_event_rules(hs)
|
load_legacy_third_party_event_rules(hs)
|
||||||
|
|
||||||
|
# We're not going to be properly signing events as our remote homeserver is fake,
|
||||||
|
# therefore disable event signature checks.
|
||||||
|
# Note that these checks are not relevant to this test case.
|
||||||
|
|
||||||
|
# Have this homeserver auto-approve all event signature checking.
|
||||||
|
async def approve_all_signature_checking(_, pdu):
|
||||||
|
return pdu
|
||||||
|
|
||||||
|
hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking
|
||||||
|
|
||||||
|
# Have this homeserver skip event auth checks. This is necessary due to
|
||||||
|
# event auth checks ensuring that events were signed by the sender's homeserver.
|
||||||
|
async def _check_event_auth(origin, event, context, *args, **kwargs):
|
||||||
|
return context
|
||||||
|
|
||||||
|
hs.get_federation_event_handler()._check_event_auth = _check_event_auth
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor, clock, homeserver):
|
||||||
# Create a user and room to play with during the tests
|
# Create some users and a room to play with during the tests
|
||||||
self.user_id = self.register_user("kermit", "monkey")
|
self.user_id = self.register_user("kermit", "monkey")
|
||||||
|
self.invitee = self.register_user("invitee", "hackme")
|
||||||
self.tok = self.login("kermit", "monkey")
|
self.tok = self.login("kermit", "monkey")
|
||||||
|
|
||||||
# Some tests might prevent room creation on purpose.
|
# Some tests might prevent room creation on purpose.
|
||||||
|
@ -424,6 +443,74 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertEqual(channel.json_body["i"], i)
|
self.assertEqual(channel.json_body["i"], i)
|
||||||
|
|
||||||
|
def test_on_new_event(self):
|
||||||
|
"""Test that the on_new_event callback is called on new events"""
|
||||||
|
on_new_event = Mock(make_awaitable(None))
|
||||||
|
self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
|
||||||
|
on_new_event
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send a message event to the room and check that the callback is called.
|
||||||
|
self.helper.send(room_id=self.room_id, tok=self.tok)
|
||||||
|
self.assertEqual(on_new_event.call_count, 1)
|
||||||
|
|
||||||
|
# Check that the callback is also called on membership updates.
|
||||||
|
self.helper.invite(
|
||||||
|
room=self.room_id,
|
||||||
|
src=self.user_id,
|
||||||
|
targ=self.invitee,
|
||||||
|
tok=self.tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(on_new_event.call_count, 2)
|
||||||
|
|
||||||
|
args, _ = on_new_event.call_args
|
||||||
|
|
||||||
|
self.assertEqual(args[0].membership, Membership.INVITE)
|
||||||
|
self.assertEqual(args[0].state_key, self.invitee)
|
||||||
|
|
||||||
|
# Check that the invitee's membership is correct in the state that's passed down
|
||||||
|
# to the callback.
|
||||||
|
self.assertEqual(
|
||||||
|
args[1][(EventTypes.Member, self.invitee)].membership,
|
||||||
|
Membership.INVITE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send an event over federation and check that the callback is also called.
|
||||||
|
self._send_event_over_federation()
|
||||||
|
self.assertEqual(on_new_event.call_count, 3)
|
||||||
|
|
||||||
|
def _send_event_over_federation(self) -> None:
|
||||||
|
"""Send a dummy event over federation and check that the request succeeds."""
|
||||||
|
body = {
|
||||||
|
"origin": self.hs.config.server.server_name,
|
||||||
|
"origin_server_ts": self.clock.time_msec(),
|
||||||
|
"pdus": [
|
||||||
|
{
|
||||||
|
"sender": self.user_id,
|
||||||
|
"type": EventTypes.Message,
|
||||||
|
"state_key": "",
|
||||||
|
"content": {"body": "hello world", "msgtype": "m.text"},
|
||||||
|
"room_id": self.room_id,
|
||||||
|
"depth": 0,
|
||||||
|
"origin_server_ts": self.clock.time_msec(),
|
||||||
|
"prev_events": [],
|
||||||
|
"auth_events": [],
|
||||||
|
"signatures": {},
|
||||||
|
"unsigned": {},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
method="PUT",
|
||||||
|
path="/_matrix/federation/v1/send/1",
|
||||||
|
content=body,
|
||||||
|
federation_auth_origin=self.hs.config.server.server_name.encode("utf8"),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
||||||
def _update_power_levels(self, event_default: int = 0):
|
def _update_power_levels(self, event_default: int = 0):
|
||||||
"""Updates the room's power levels.
|
"""Updates the room's power levels.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue