From dcc49cd1aeb312783b8f66113ed0c42c4bb6e2f1 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Tue, 30 May 2023 17:41:43 +0200 Subject: [PATCH] More tests, less bugs --- synapse/handlers/room_list.py | 29 ++- .../callbacks/public_rooms_callbacks.py | 2 +- tests/module_api/test_fetch_public_rooms.py | 201 +++++++++--------- 3 files changed, 113 insertions(+), 119 deletions(-) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 3f51d67821..a0742cbc86 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, Tuple import attr import msgpack @@ -170,12 +170,13 @@ class RoomListHandler: # we request one more than wanted to see if there are more pages to come probing_limit = limit + 1 if limit is not None else None - results = [] + results: List[PublicRoom] = [] - print(f"last_module_index {last_module_index}") - print(f"last_room_id {last_room_id}") + # print(f"{forwards} {last_joined_members} {last_room_id} {last_module_index}") - def insert_into_result(new_room: PublicRoom, module_index: Optional[int]): + def insert_into_result( + new_room: PublicRoom, module_index: Optional[int] + ) -> None: # print(f"insert {new_room.room_id} {module_index}") if new_room.num_joined_members == last_joined_members: if last_module_index is not None and last_room_id is not None: @@ -221,8 +222,6 @@ class RoomListHandler: forwards, ) - print([r.room_id for r in module_public_rooms]) - # We reverse for iteration to keep the order in the final list # since we preprend when inserting module_public_rooms.reverse() @@ -238,7 +237,7 @@ class RoomListHandler: probing_limit, bounds=( last_joined_members, - last_room_id if last_module_index == None else None, + last_room_id if last_module_index is None else None, ), forwards=forwards, ignore_non_federatable=bool(from_remote_server_name), @@ -247,22 +246,20 @@ class RoomListHandler: for r in local_public_rooms: insert_into_result(r, None) - # print("final") - # print([r.room_id for r in results]) - response: JsonDict = {} num_results = len(results) if limit is not None and probing_limit is not None: more_to_come = num_results >= probing_limit - # Depending on direction we trim either the front or back. - if forwards: - results = results[:limit] - else: - results = results[-limit:] + results = results[:limit] else: more_to_come = False + if not forwards: + results.reverse() + + # print("final ", [(r.room_id, r.num_joined_members) for r in results]) + if num_results > 0: final_entry = results[-1] initial_entry = results[0] diff --git a/synapse/module_api/callbacks/public_rooms_callbacks.py b/synapse/module_api/callbacks/public_rooms_callbacks.py index b4257bc60c..b3eeb84606 100644 --- a/synapse/module_api/callbacks/public_rooms_callbacks.py +++ b/synapse/module_api/callbacks/public_rooms_callbacks.py @@ -26,7 +26,7 @@ FETCH_PUBLIC_ROOMS_CALLBACK = Callable[ Optional[ThirdPartyInstanceID], # network_tuple Optional[dict], # search_filter Optional[int], # limit - Optional[Tuple[int, str]], # bounds + Tuple[Optional[int], Optional[str]], # bounds bool, # forwards ], Awaitable[List[PublicRoom]], diff --git a/tests/module_api/test_fetch_public_rooms.py b/tests/module_api/test_fetch_public_rooms.py index f27523d63f..0ccf68e4b2 100644 --- a/tests/module_api/test_fetch_public_rooms.py +++ b/tests/module_api/test_fetch_public_rooms.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import List, Optional, Tuple from twisted.test.proto_helpers import MemoryReactor @@ -45,7 +44,7 @@ class FetchPublicRoomsTestCase(HomeserverTestCase): self._store = homeserver.get_datastores().main self._module_api = homeserver.get_module_api() - async def cb( + async def module1_cb( network_tuple: Optional[ThirdPartyInstanceID], search_filter: Optional[dict], limit: Optional[int], @@ -73,66 +72,48 @@ class FetchPublicRoomsTestCase(HomeserverTestCase): (last_joined_members, last_room_id) = bounds - print(f"cb {forwards} {bounds}") - - result = [room1, room3, room3_2] + if forwards: + result = [room3_2, room3, room1] + else: + result = [room1, room3, room3_2] if last_joined_members is not None: - if forwards: - result = list( - filter( - lambda r: r.num_joined_members <= last_joined_members, - result, - ) - ) - else: - result = list( - filter( - lambda r: r.num_joined_members >= last_joined_members, - result, - ) - ) - - print([r.room_id for r in result]) - - if last_room_id is not None: - new_res = [] - for r in result: - if r.room_id == last_room_id: - break - new_res.append(r) - result = new_res - - if forwards: - result.reverse() + if last_joined_members == 1: + if forwards: + if last_room_id == room1.room_id: + result = [] + else: + result = [room1] + else: + if last_room_id == room1.room_id: + result = [room3, room3_2] + else: + result = [room1, room3, room3_2] + elif last_joined_members == 2: + if forwards: + result = [room1] + else: + result = [room3, room3_2] + elif last_joined_members == 3: + if forwards: + if last_room_id == room3.room_id: + result = [room1] + elif last_room_id == room3_2.room_id: + result = [room3, room1] + else: + if last_room_id == room3.room_id: + result = [room3_2] + elif last_room_id == room3_2.room_id: + result = [] + else: + result = [room3, room3_2] if limit is not None: result = result[:limit] return result - # if forwards: - # if limit == 2: - # if last_joined_members is None: - # return [room3_2, room3] - # elif last_joined_members == 3: - # if last_room_id == room3_2.room_id: - # return [room3, room1] - # if last_room_id == room3.room_id: - # return [room1] - # elif last_joined_members < 3: - # return [room1] - # return [room3_2, room3, room1] - # else: - # if ( - # limit == 2 - # and last_joined_members == 3 - # and last_room_id == room3.room_id - # ): - # return [room3_2] - # return [room1, room3, room3_2] - - async def cb2( + async def module2_cb( network_tuple: Optional[ThirdPartyInstanceID], search_filter: Optional[dict], limit: Optional[int], @@ -146,48 +127,26 @@ class FetchPublicRoomsTestCase(HomeserverTestCase): guest_can_join=False, ) - result = [room3] - (last_joined_members, last_room_id) = bounds - print(f"cb2 {forwards} {bounds}") + result = [room3] if last_joined_members is not None: if forwards: - result = list( - filter( - lambda r: r.num_joined_members <= last_joined_members, - result, - ) - ) + if last_joined_members < 3: + result = [] + elif last_joined_members == 3 and last_room_id == room3.room_id: + result = [] else: - result = list( - filter( - lambda r: r.num_joined_members >= last_joined_members, - result, - ) - ) - - print([r.room_id for r in result]) - - if last_room_id is not None: - new_res = [] - for r in result: - if r.room_id == last_room_id: - break - new_res.append(r) - result = new_res - - if forwards: - result.reverse() - - if limit is not None: - result = result[:limit] + if last_joined_members > 3: + result = [] + elif last_joined_members == 3 and last_room_id == room3.room_id: + result = [] return result - self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb2) - self._module_api.register_public_rooms_callbacks(fetch_public_rooms=cb) + self._module_api.register_public_rooms_callbacks(fetch_public_rooms=module1_cb) + self._module_api.register_public_rooms_callbacks(fetch_public_rooms=module2_cb) user = self.register_user("alice", "pass") token = self.login(user, "pass") @@ -227,11 +186,11 @@ class FetchPublicRoomsTestCase(HomeserverTestCase): self.assertEquals(chunk[4]["num_joined_members"], 2) self.assertEquals(chunk[5]["num_joined_members"], 1) - def test_pagination(self) -> None: + def test_pagination_limit_1(self) -> None: returned_three_members_rooms = set() next_batch = None - for i in range(4): + for _i in range(4): since_query_str = f"&since={next_batch}" if next_batch else "" channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}") chunk = channel.json_body["chunk"] @@ -250,17 +209,55 @@ class FetchPublicRoomsTestCase(HomeserverTestCase): self.assertEquals(chunk[0]["num_joined_members"], 1) prev_batch = channel.json_body["prev_batch"] - # channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}") - # chunk = channel.json_body["chunk"] - # print(chunk) - # self.assertEquals(chunk[0]["num_joined_members"], 2) - # prev_batch = channel.json_body["prev_batch"] + self.assertNotIn("next_batch", channel.json_body) - # returned_three_members_rooms = set() - # for i in range(4): - # channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}") - # chunk = channel.json_body["chunk"] - # self.assertEquals(chunk[0]["num_joined_members"], 3) - # self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms) - # returned_three_members_rooms.add(chunk[0]["room_id"]) - # prev_batch = channel.json_body["prev_batch"] + channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(chunk[0]["num_joined_members"], 2) + + returned_three_members_rooms = set() + for _i in range(4): + prev_batch = channel.json_body["prev_batch"] + channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(chunk[0]["num_joined_members"], 3) + self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms) + returned_three_members_rooms.add(chunk[0]["room_id"]) + + self.assertNotIn("prev_batch", channel.json_body) + + def test_pagination_limit_2(self) -> None: + returned_three_members_rooms = set() + + next_batch = None + for _i in range(2): + since_query_str = f"&since={next_batch}" if next_batch else "" + channel = self.make_request("GET", f"{self.url}?limit=2{since_query_str}") + chunk = channel.json_body["chunk"] + self.assertEquals(chunk[0]["num_joined_members"], 3) + self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms) + returned_three_members_rooms.add(chunk[0]["room_id"]) + self.assertTrue(chunk[1]["room_id"] not in returned_three_members_rooms) + returned_three_members_rooms.add(chunk[1]["room_id"]) + next_batch = channel.json_body["next_batch"] + + channel = self.make_request("GET", f"{self.url}?limit=2&since={next_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(chunk[0]["num_joined_members"], 2) + self.assertEquals(chunk[1]["num_joined_members"], 1) + + self.assertNotIn("next_batch", channel.json_body) + + returned_three_members_rooms = set() + + for _i in range(2): + prev_batch = channel.json_body["prev_batch"] + channel = self.make_request("GET", f"{self.url}?limit=2&since={prev_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(chunk[0]["num_joined_members"], 3) + self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms) + returned_three_members_rooms.add(chunk[0]["room_id"]) + self.assertTrue(chunk[1]["room_id"] not in returned_three_members_rooms) + returned_three_members_rooms.add(chunk[1]["room_id"]) + + self.assertNotIn("prev_batch", channel.json_body)