From a5986ac229b97c06be0f50fcb30d27285035ebc3 Mon Sep 17 00:00:00 2001 From: Shay Date: Tue, 8 Oct 2024 06:23:21 -0700 Subject: [PATCH] Improvements to admin redact api (#17792) - better validation on user input - fix an early task completion - when checking membership in rooms, check for rooms user has been banned from as well --- changelog.d/17792.bugfix | 1 + synapse/handlers/admin.py | 4 +- synapse/rest/admin/users.py | 45 ++++++------ synapse/storage/databases/main/roommember.py | 21 ++++++ tests/rest/admin/test_user.py | 77 ++++++++++++++++---- 5 files changed, 107 insertions(+), 41 deletions(-) create mode 100644 changelog.d/17792.bugfix diff --git a/changelog.d/17792.bugfix b/changelog.d/17792.bugfix new file mode 100644 index 0000000000..451b32782e --- /dev/null +++ b/changelog.d/17792.bugfix @@ -0,0 +1 @@ +Improve input validation and room membership checks in admin redaction API. \ No newline at end of file diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 58d89080ff..851fe57a17 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -443,8 +443,8 @@ class AdminHandler: ["m.room.member", "m.room.message"], ) if not event_ids: - # there's nothing to redact - return TaskStatus.COMPLETE, result, None + # nothing to redact in this room + continue events = await self._store.get_events_as_list(event_ids) for event in events: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 81dfb57a95..b146c2754d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import attr -from synapse._pydantic_compat import StrictBool +from synapse._pydantic_compat import StrictBool, StrictInt, StrictStr from synapse.api.constants import Direction, UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( @@ -1421,40 +1421,39 @@ class RedactUser(RestServlet): self._store = hs.get_datastores().main self.admin_handler = hs.get_admin_handler() + class PostBody(RequestBodyModel): + rooms: List[StrictStr] + reason: Optional[StrictStr] + limit: Optional[StrictInt] + async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request) await assert_user_is_admin(self._auth, requester) - body = parse_json_object_from_request(request, allow_empty_body=True) - rooms = body.get("rooms") - if rooms is None: + # parse provided user id to check that it is valid + UserID.from_string(user_id) + + body = parse_and_validate_json_object_from_request(request, self.PostBody) + + limit = body.limit + if limit and limit <= 0: raise SynapseError( - HTTPStatus.BAD_REQUEST, "Must provide a value for rooms." + HTTPStatus.BAD_REQUEST, + "If limit is provided it must be a non-negative integer greater than 0.", ) - reason = body.get("reason") - if reason: - if not isinstance(reason, str): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "If a reason is provided it must be a string.", - ) - - limit = body.get("limit") - if limit: - if not isinstance(limit, int) or limit <= 0: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "If limit is provided it must be a non-negative integer greater than 0.", - ) - + rooms = body.rooms if not rooms: - rooms = await self._store.get_rooms_for_user(user_id) + current_rooms = list(await self._store.get_rooms_for_user(user_id)) + banned_rooms = list( + await self._store.get_rooms_user_currently_banned_from(user_id) + ) + rooms = current_rooms + banned_rooms redact_id = await self.admin_handler.start_redact_events( - user_id, list(rooms), requester.serialize(), reason, limit + user_id, rooms, requester.serialize(), body.reason, limit ) return HTTPStatus.OK, {"redact_id": redact_id} diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 6f15e51339..c77e009d03 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -711,6 +711,27 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): return {row[0] for row in txn} + async def get_rooms_user_currently_banned_from( + self, user_id: str + ) -> FrozenSet[str]: + """Returns a set of room_ids the user is currently banned from. + + If a remote user only returns rooms this server is currently + participating in. + """ + room_ids = await self.db_pool.simple_select_onecol( + table="current_state_events", + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.BAN, + "state_key": user_id, + }, + retcol="room_id", + desc="get_rooms_user_currently_banned_from", + ) + + return frozenset(room_ids) + @cached(max_entries=500000, iterable=True) async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: """Returns a set of room_ids the user is currently joined to. diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ef918efe49..6982c7291a 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -5288,19 +5288,26 @@ class UserRedactionTestCase(unittest.HomeserverTestCase): self.assertEqual(len(matched), len(rm2_originals)) def test_admin_redact_works_if_user_kicked_or_banned(self) -> None: - originals = [] + originals1 = [] + originals2 = [] for rm in [self.rm1, self.rm2, self.rm3]: join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok) - originals.append(join["event_id"]) + if rm in [self.rm1, self.rm3]: + originals1.append(join["event_id"]) + else: + originals2.append(join["event_id"]) for i in range(5): event = {"body": f"hello{i}", "msgtype": "m.text"} res = self.helper.send_event( rm, "m.room.message", event, tok=self.bad_user_tok ) - originals.append(res["event_id"]) + if rm in [self.rm1, self.rm3]: + originals1.append(res["event_id"]) + else: + originals2.append(res["event_id"]) # kick user from rooms 1 and 3 - for r in [self.rm1, self.rm2]: + for r in [self.rm1, self.rm3]: channel = self.make_request( "POST", f"/_matrix/client/r0/rooms/{r}/kick", @@ -5330,32 +5337,70 @@ class UserRedactionTestCase(unittest.HomeserverTestCase): failed_redactions = channel2.json_body.get("failed_redactions") self.assertEqual(failed_redactions, {}) - # ban user - channel3 = self.make_request( + # double check + for rm in [self.rm1, self.rm3]: + filter = json.dumps({"types": [EventTypes.Redaction]}) + channel3 = self.make_request( + "GET", + f"rooms/{rm}/messages?filter={filter}&limit=50", + access_token=self.admin_tok, + ) + self.assertEqual(channel3.code, 200) + + matches = [] + for event in channel3.json_body["chunk"]: + for event_id in originals1: + if ( + event["type"] == "m.room.redaction" + and event["redacts"] == event_id + ): + matches.append((event_id, event)) + # we redacted 6 messages + self.assertEqual(len(matches), 6) + + # ban user from room 2 + channel4 = self.make_request( "POST", f"/_matrix/client/r0/rooms/{self.rm2}/ban", content={"reason": "being a bummer", "user_id": self.bad_user}, access_token=self.admin_tok, ) - self.assertEqual(channel3.code, HTTPStatus.OK, channel3.result) + self.assertEqual(channel4.code, HTTPStatus.OK, channel4.result) - # redact messages in room 2 - channel4 = self.make_request( + # make a request to ban all user's messages + channel5 = self.make_request( "POST", f"/_synapse/admin/v1/user/{self.bad_user}/redact", - content={"rooms": [self.rm2]}, + content={"rooms": []}, access_token=self.admin_tok, ) - self.assertEqual(channel4.code, 200) - id2 = channel1.json_body.get("redact_id") + self.assertEqual(channel5.code, 200) + id2 = channel5.json_body.get("redact_id") # check that there were no failed redactions in room 2 - channel5 = self.make_request( + channel6 = self.make_request( "GET", f"/_synapse/admin/v1/user/redact_status/{id2}", access_token=self.admin_tok, ) - self.assertEqual(channel5.code, 200) - self.assertEqual(channel5.json_body.get("status"), "complete") - failed_redactions = channel5.json_body.get("failed_redactions") + self.assertEqual(channel6.code, 200) + self.assertEqual(channel6.json_body.get("status"), "complete") + failed_redactions = channel6.json_body.get("failed_redactions") self.assertEqual(failed_redactions, {}) + + # double check messages in room 2 were redacted + filter = json.dumps({"types": [EventTypes.Redaction]}) + channel7 = self.make_request( + "GET", + f"rooms/{self.rm2}/messages?filter={filter}&limit=50", + access_token=self.admin_tok, + ) + self.assertEqual(channel7.code, 200) + + matches = [] + for event in channel7.json_body["chunk"]: + for event_id in originals2: + if event["type"] == "m.room.redaction" and event["redacts"] == event_id: + matches.append((event_id, event)) + # we redacted 6 messages + self.assertEqual(len(matches), 6)