Additional tests for third-party event rules (#8468)

* Optimise and test state fetching for 3p event rules

Getting all the events at once is much more efficient than getting them
individually

* Test that 3p event rules can modify events
This commit is contained in:
Richard van der Hoff 2020-10-06 16:31:31 +01:00 committed by GitHub
parent 9c0b168cff
commit a024461130
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 79 additions and 18 deletions

1
changelog.d/8468.misc Normal file
View File

@ -0,0 +1 @@
Additional testing for `ThirdPartyEventRules`.

View File

@ -61,12 +61,14 @@ class ThirdPartyEventRules:
prev_state_ids = await context.get_prev_state_ids()
# Retrieve the state events from the database.
state_events = {}
for key, event_id in prev_state_ids.items():
state_events[key] = await self.store.get_event(event_id, allow_none=True)
events = await self.store.get_events(prev_state_ids.values())
state_events = {(ev.type, ev.state_key): ev for ev in events.values()}
ret = await self.third_party_rules.check_event_allowed(event, state_events)
return ret
# The module can modify the event slightly if it wants, but caution should be
# exercised, and it's likely to go very wrong if applied to events received over
# federation.
return await self.third_party_rules.check_event_allowed(event, state_events)
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool

View File

@ -12,33 +12,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from mock import Mock
from synapse.events import EventBase
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.types import Requester
from synapse.types import Requester, StateMap
from tests import unittest
thread_local = threading.local()
class ThirdPartyRulesTestModule:
def __init__(self, config, *args, **kwargs):
pass
def __init__(self, config, module_api):
# keep a record of the "current" rules module, so that the test can patch
# it if desired.
thread_local.rules_module = self
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
):
return True
async def check_event_allowed(self, event, context):
if event.type == "foo.bar.forbidden":
return False
else:
return True
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
return True
@staticmethod
def parse_config(config):
return config
def current_rules_module() -> ThirdPartyRulesTestModule:
return thread_local.rules_module
class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
@ -46,15 +56,13 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def make_homeserver(self, reactor, clock):
config = self.default_config()
def default_config(self):
config = super().default_config()
config["third_party_event_rules"] = {
"module": __name__ + ".ThirdPartyRulesTestModule",
"config": {},
}
self.hs = self.setup_test_homeserver(config=config)
return self.hs
return config
def prepare(self, reactor, clock, homeserver):
# Create a user and room to play with during the tests
@ -67,6 +75,14 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
can be sent.
"""
# patch the rules module with a Mock which will return False for some event
# types
async def check(ev, state):
return ev.type != "foo.bar.forbidden"
callback = Mock(spec=[], side_effect=check)
current_rules_module().check_event_allowed = callback
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
@ -76,6 +92,16 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
callback.assert_called_once()
# there should be various state events in the state arg: do some basic checks
state_arg = callback.call_args[0][1]
for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
self.assertIn(k, state_arg)
ev = state_arg[k]
self.assertEqual(ev.type, k[0])
self.assertEqual(ev.state_key, k[1])
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
@ -84,3 +110,35 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
def test_modify_event(self):
"""Tests that the module can successfully tweak an event before it is persisted.
"""
# first patch the event checker so that it will modify the event
async def check(ev: EventBase, state):
ev.content = {"x": "y"}
return True
current_rules_module().check_event_allowed = check
# now send the event
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
{"x": "x"},
access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.result["code"], b"200", channel.result)
event_id = channel.json_body["event_id"]
# ... and check that it got modified
request, channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.result["code"], b"200", channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")