From 1f4b960e62678ba92506af6872b30512b08242a4 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Thu, 1 Jun 2023 15:55:15 +0200 Subject: [PATCH] argggggghhhh --- synapse/handlers/room_list.py | 33 +++-- tests/module_api/test_fetch_public_rooms.py | 2 - tests/rest/client/test_public_rooms.py | 148 ++++++++++++++++++++ 3 files changed, 173 insertions(+), 10 deletions(-) create mode 100644 tests/rest/client/test_public_rooms.py diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 6cec313bf4..49a7df973f 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -162,6 +162,7 @@ class RoomListHandler: last_module_index = None if since_token: batch_token = RoomListNextBatch.from_token(since_token) + print(batch_token) forwards = batch_token.direction_is_forward last_joined_members = batch_token.last_joined_members last_room_id = batch_token.last_room_id @@ -192,17 +193,24 @@ class RoomListHandler: room ) - for module_index, fetch_public_rooms in enumerate( - self._module_api_callbacks.fetch_public_rooms_callbacks - ): + nb_modules = len(self._module_api_callbacks.fetch_public_rooms_callbacks) + + module_range = range(0, nb_modules) + # if not forwards: + # module_range = reversed(module_range) + + for module_index in module_range: + fetch_public_rooms = self._module_api_callbacks.fetch_public_rooms_callbacks[module_index] # Ask each module for a list of public rooms given the last_joined_members # value from the since token and the probing limit # last_joined_members needs to be reduce by one if this module has already # given its result for last_joined_members module_last_joined_members = last_joined_members if module_last_joined_members is not None and last_module_index is not None: - if module_index < last_module_index: + if forwards and module_index < last_module_index: module_last_joined_members = module_last_joined_members - 1 + # if not forwards and module_index > last_module_index: + # module_last_joined_members = module_last_joined_members - 1 module_public_rooms = await fetch_public_rooms( network_tuple, @@ -226,19 +234,28 @@ class RoomListHandler: results = [] for num_joined_members in nums_joined_members: - results += num_joined_members_buckets[num_joined_members] + rooms = num_joined_members_buckets[num_joined_members] + # if not forwards: + # rooms.reverse() + results += rooms + + + print([(r.room_id, r.num_joined_members) for r in results]) response: JsonDict = {} num_results = len(results) if limit is not None and probing_limit is not None: more_to_come = num_results >= probing_limit - results = results[:limit] + # Depending on direction we trim either the front or back. + if forwards: + results = results[:limit] + else: + results = results[-limit:] else: more_to_come = False - if not forwards: - results.reverse() + print([(r.room_id, r.num_joined_members) for r in results]) if num_results > 0: final_entry = results[-1] diff --git a/tests/module_api/test_fetch_public_rooms.py b/tests/module_api/test_fetch_public_rooms.py index 0ccf68e4b2..8daf8c5c40 100644 --- a/tests/module_api/test_fetch_public_rooms.py +++ b/tests/module_api/test_fetch_public_rooms.py @@ -150,10 +150,8 @@ class FetchPublicRoomsTestCase(HomeserverTestCase): user = self.register_user("alice", "pass") token = self.login(user, "pass") - user2 = self.register_user("alice2", "pass") token2 = self.login(user2, "pass") - user3 = self.register_user("alice3", "pass") token3 = self.login(user3, "pass") diff --git a/tests/rest/client/test_public_rooms.py b/tests/rest/client/test_public_rooms.py new file mode 100644 index 0000000000..ec5b19c33f --- /dev/null +++ b/tests/rest/client/test_public_rooms.py @@ -0,0 +1,148 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin, login, room +from synapse.server import HomeServer +from synapse.types import PublicRoom, ThirdPartyInstanceID +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + + +class PublicRoomsTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + config["allow_public_rooms_without_auth"] = True + self.hs = self.setup_test_homeserver(config=config) + self.url = "/_matrix/client/r0/publicRooms" + + return self.hs + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self._store = homeserver.get_datastores().main + + user = self.register_user("alice", "pass") + token = self.login(user, "pass") + user2 = self.register_user("alice2", "pass") + token2 = self.login(user2, "pass") + user3 = self.register_user("alice3", "pass") + token3 = self.login(user3, "pass") + + # Create 10 rooms + for _ in range(3): + self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=token, + ) + + for _ in range(3): + room_id = self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=token, + ) + self.helper.join(room_id, user2, tok=token2) + + for _ in range(4): + room_id = self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=token, + ) + self.helper.join(room_id, user2, tok=token2) + self.helper.join(room_id, user3, tok=token3) + + def test_no_limit(self) -> None: + channel = self.make_request("GET", self.url) + chunk = channel.json_body["chunk"] + + self.assertEquals(len(chunk), 10) + + def test_pagination_limit_1(self) -> None: + returned_rooms = set() + + for i in range(10): + next_batch = None if i == 0 else channel.json_body["next_batch"] + since_query_str = f"&since={next_batch}" if next_batch else "" + channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}") + chunk = channel.json_body["chunk"] + self.assertEquals(len(chunk), 1) + print(chunk[0]["room_id"]) + self.assertTrue(chunk[0]["room_id"] not in returned_rooms) + returned_rooms.add(chunk[0]["room_id"]) + + self.assertNotIn("next_batch", channel.json_body) + + returned_rooms = set() + returned_rooms.add(chunk[0]["room_id"]) + + for i in range(9): + print(i) + prev_batch = channel.json_body["prev_batch"] + channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(len(chunk), 1) + print(chunk[0]["room_id"]) + self.assertTrue(chunk[0]["room_id"] not in returned_rooms) + returned_rooms.add(chunk[0]["room_id"]) + + def test_pagination_limit_2(self) -> None: + returned_rooms = set() + + for i in range(5): + next_batch = None if i == 0 else channel.json_body["next_batch"] + since_query_str = f"&since={next_batch}" if next_batch else "" + channel = self.make_request("GET", f"{self.url}?limit=2{since_query_str}") + chunk = channel.json_body["chunk"] + self.assertEquals(len(chunk), 2) + print(chunk[0]["room_id"]) + self.assertTrue(chunk[0]["room_id"] not in returned_rooms) + returned_rooms.add(chunk[0]["room_id"]) + print(chunk[1]["room_id"]) + self.assertTrue(chunk[1]["room_id"] not in returned_rooms) + returned_rooms.add(chunk[1]["room_id"]) + + self.assertNotIn("next_batch", channel.json_body) + + returned_rooms = set() + returned_rooms.add(chunk[0]["room_id"]) + returned_rooms.add(chunk[1]["room_id"]) + + for i in range(4): + print(i) + prev_batch = channel.json_body["prev_batch"] + channel = self.make_request("GET", f"{self.url}?limit=2&since={prev_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(len(chunk), 2) + print(chunk[0]["room_id"]) + self.assertTrue(chunk[0]["room_id"] not in returned_rooms) + returned_rooms.add(chunk[0]["room_id"]) + print(chunk[1]["room_id"]) + self.assertTrue(chunk[1]["room_id"] not in returned_rooms) + returned_rooms.add(chunk[1]["room_id"])