This commit is contained in:
Mathieu Velten 2023-06-23 15:22:00 +02:00
parent 5047c01d3f
commit 921fa8f9ce
2 changed files with 88 additions and 11 deletions

View File

@ -816,7 +816,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
local_by_user_then_device = {} local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items(): for user_id, messages_by_device in messages_by_user_then_device.items():
inboxes_size = {} inbox_sizes = {}
if size_limit: if size_limit:
sql = """ sql = """
SELECT device_id, COUNT(*) FROM device_inbox SELECT device_id, COUNT(*) FROM device_inbox
@ -825,7 +825,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
""" """
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
for r in txn: for r in txn:
inboxes_size[r[0]] = r[1] inbox_sizes[r[0]] = r[1]
messages_json_for_user = {} messages_json_for_user = {}
devices = list(messages_by_device.keys()) devices = list(messages_by_device.keys())
@ -842,10 +842,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
message_json = json_encoder.encode(messages_by_device["*"]) message_json = json_encoder.encode(messages_by_device["*"])
for device_id in devices: for device_id in devices:
if ( if size_limit is None or inbox_sizes.get(device_id, 0) < size_limit:
size_limit is None
or inboxes_size.get(device_id, 0) <= size_limit
):
# Add the message for all devices for this user on this # Add the message for all devices for this user on this
# server. # server.
messages_json_for_user[device_id] = message_json messages_json_for_user[device_id] = message_json
@ -881,10 +878,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
) )
message_json = json_encoder.encode(msg) message_json = json_encoder.encode(msg)
if ( if size_limit is None or inbox_sizes.get(device_id, 0) < size_limit:
size_limit is None
or inboxes_size.get(device_id, 0) <= size_limit
):
messages_json_for_user[device_id] = message_json messages_json_for_user[device_id] = message_json
if messages_json_for_user: if messages_json_for_user:

View File

@ -23,20 +23,28 @@ from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import NotFoundError, SynapseError from synapse.api.errors import NotFoundError, SynapseError
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.handlers.devicemessage import INBOX_SIZE_LIMIT_FOR_KEY_REQUEST
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict from synapse.types import JsonDict, create_requester
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
import synapse
user1 = "@boris:aaa" user1 = "@boris:aaa"
user2 = "@theresa:bbb" user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase): class DeviceTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.client.login.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice_api = mock.Mock() self.appservice_api = mock.Mock()
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
@ -47,6 +55,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler) assert isinstance(handler, DeviceHandler)
self.handler = handler self.handler = handler
self.msg_handler = hs.get_device_message_handler()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
return hs return hs
@ -398,6 +408,79 @@ class DeviceTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_room_key_request_limit(self) -> None:
store = self.hs.get_datastores().main
myuser = self.register_user("myuser", "pass")
self.login("myuser", "pass", "device")
self.login("myuser", "pass", "device2")
requester = requester = create_requester(myuser)
from_token = self.event_sources.get_current_token()
# for i in range(0, INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2):
# self.get_success(
# self.msg_handler.send_device_message(
# requester,
# "m.room_key",
# {
# myuser2: {
# "device": {
# "algorithm": "m.megolm.v1.aes-sha2",
# "room_id": "!Cuyf34gef24t:localhost",
# "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ",
# "session_key": "AgAAAADxKHa9uFxcXzwYoNueL5Xqi69IkD4sni8LlfJL7qNBEY..."
# }
# }
# },
# )
# )
# to_token = self.event_sources.get_current_token()
# res = self.get_success(self.store.get_messages_for_device(
# myuser2,
# "device",
# from_token.to_device_key,
# to_token.to_device_key,
# INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 5,
# ))
# self.assertEqual(len(res[0]), INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2)
# from_token = to_token
for i in range(0, INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2):
self.get_success(
self.msg_handler.send_device_message(
requester,
"m.room_key_request",
{
myuser: {
"device2": {
"action": "request",
"request_id": f"request_id_{i}",
"requesting_device_id": "device",
}
}
},
)
)
to_token = self.event_sources.get_current_token()
res = self.get_success(
self.store.get_messages_for_device(
myuser,
"device2",
from_token.to_device_key,
to_token.to_device_key,
INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 5,
)
)
self.assertEqual(len(res[0]), INBOX_SIZE_LIMIT_FOR_KEY_REQUEST)
class DehydrationTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: