Improvements to admin redact api (#17792)
- better validation on user input - fix an early task completion - when checking membership in rooms, check for rooms user has been banned from as well
This commit is contained in:
parent
006251a5d0
commit
a5986ac229
|
@ -0,0 +1 @@
|
|||
Improve input validation and room membership checks in admin redaction API.
|
|
@ -443,8 +443,8 @@ class AdminHandler:
|
|||
["m.room.member", "m.room.message"],
|
||||
)
|
||||
if not event_ids:
|
||||
# there's nothing to redact
|
||||
return TaskStatus.COMPLETE, result, None
|
||||
# nothing to redact in this room
|
||||
continue
|
||||
|
||||
events = await self._store.get_events_as_list(event_ids)
|
||||
for event in events:
|
||||
|
|
|
@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
|||
|
||||
import attr
|
||||
|
||||
from synapse._pydantic_compat import StrictBool
|
||||
from synapse._pydantic_compat import StrictBool, StrictInt, StrictStr
|
||||
from synapse.api.constants import Direction, UserTypes
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.servlet import (
|
||||
|
@ -1421,40 +1421,39 @@ class RedactUser(RestServlet):
|
|||
self._store = hs.get_datastores().main
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
rooms: List[StrictStr]
|
||||
reason: Optional[StrictStr]
|
||||
limit: Optional[StrictInt]
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self._auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self._auth, requester)
|
||||
|
||||
body = parse_json_object_from_request(request, allow_empty_body=True)
|
||||
rooms = body.get("rooms")
|
||||
if rooms is None:
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Must provide a value for rooms."
|
||||
)
|
||||
# parse provided user id to check that it is valid
|
||||
UserID.from_string(user_id)
|
||||
|
||||
reason = body.get("reason")
|
||||
if reason:
|
||||
if not isinstance(reason, str):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"If a reason is provided it must be a string.",
|
||||
)
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
|
||||
limit = body.get("limit")
|
||||
if limit:
|
||||
if not isinstance(limit, int) or limit <= 0:
|
||||
limit = body.limit
|
||||
if limit and limit <= 0:
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"If limit is provided it must be a non-negative integer greater than 0.",
|
||||
)
|
||||
|
||||
rooms = body.rooms
|
||||
if not rooms:
|
||||
rooms = await self._store.get_rooms_for_user(user_id)
|
||||
current_rooms = list(await self._store.get_rooms_for_user(user_id))
|
||||
banned_rooms = list(
|
||||
await self._store.get_rooms_user_currently_banned_from(user_id)
|
||||
)
|
||||
rooms = current_rooms + banned_rooms
|
||||
|
||||
redact_id = await self.admin_handler.start_redact_events(
|
||||
user_id, list(rooms), requester.serialize(), reason, limit
|
||||
user_id, rooms, requester.serialize(), body.reason, limit
|
||||
)
|
||||
|
||||
return HTTPStatus.OK, {"redact_id": redact_id}
|
||||
|
|
|
@ -711,6 +711,27 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
|||
|
||||
return {row[0] for row in txn}
|
||||
|
||||
async def get_rooms_user_currently_banned_from(
|
||||
self, user_id: str
|
||||
) -> FrozenSet[str]:
|
||||
"""Returns a set of room_ids the user is currently banned from.
|
||||
|
||||
If a remote user only returns rooms this server is currently
|
||||
participating in.
|
||||
"""
|
||||
room_ids = await self.db_pool.simple_select_onecol(
|
||||
table="current_state_events",
|
||||
keyvalues={
|
||||
"type": EventTypes.Member,
|
||||
"membership": Membership.BAN,
|
||||
"state_key": user_id,
|
||||
},
|
||||
retcol="room_id",
|
||||
desc="get_rooms_user_currently_banned_from",
|
||||
)
|
||||
|
||||
return frozenset(room_ids)
|
||||
|
||||
@cached(max_entries=500000, iterable=True)
|
||||
async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]:
|
||||
"""Returns a set of room_ids the user is currently joined to.
|
||||
|
|
|
@ -5288,19 +5288,26 @@ class UserRedactionTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(len(matched), len(rm2_originals))
|
||||
|
||||
def test_admin_redact_works_if_user_kicked_or_banned(self) -> None:
|
||||
originals = []
|
||||
originals1 = []
|
||||
originals2 = []
|
||||
for rm in [self.rm1, self.rm2, self.rm3]:
|
||||
join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
|
||||
originals.append(join["event_id"])
|
||||
if rm in [self.rm1, self.rm3]:
|
||||
originals1.append(join["event_id"])
|
||||
else:
|
||||
originals2.append(join["event_id"])
|
||||
for i in range(5):
|
||||
event = {"body": f"hello{i}", "msgtype": "m.text"}
|
||||
res = self.helper.send_event(
|
||||
rm, "m.room.message", event, tok=self.bad_user_tok
|
||||
)
|
||||
originals.append(res["event_id"])
|
||||
if rm in [self.rm1, self.rm3]:
|
||||
originals1.append(res["event_id"])
|
||||
else:
|
||||
originals2.append(res["event_id"])
|
||||
|
||||
# kick user from rooms 1 and 3
|
||||
for r in [self.rm1, self.rm2]:
|
||||
for r in [self.rm1, self.rm3]:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"/_matrix/client/r0/rooms/{r}/kick",
|
||||
|
@ -5330,32 +5337,70 @@ class UserRedactionTestCase(unittest.HomeserverTestCase):
|
|||
failed_redactions = channel2.json_body.get("failed_redactions")
|
||||
self.assertEqual(failed_redactions, {})
|
||||
|
||||
# ban user
|
||||
# double check
|
||||
for rm in [self.rm1, self.rm3]:
|
||||
filter = json.dumps({"types": [EventTypes.Redaction]})
|
||||
channel3 = self.make_request(
|
||||
"GET",
|
||||
f"rooms/{rm}/messages?filter={filter}&limit=50",
|
||||
access_token=self.admin_tok,
|
||||
)
|
||||
self.assertEqual(channel3.code, 200)
|
||||
|
||||
matches = []
|
||||
for event in channel3.json_body["chunk"]:
|
||||
for event_id in originals1:
|
||||
if (
|
||||
event["type"] == "m.room.redaction"
|
||||
and event["redacts"] == event_id
|
||||
):
|
||||
matches.append((event_id, event))
|
||||
# we redacted 6 messages
|
||||
self.assertEqual(len(matches), 6)
|
||||
|
||||
# ban user from room 2
|
||||
channel4 = self.make_request(
|
||||
"POST",
|
||||
f"/_matrix/client/r0/rooms/{self.rm2}/ban",
|
||||
content={"reason": "being a bummer", "user_id": self.bad_user},
|
||||
access_token=self.admin_tok,
|
||||
)
|
||||
self.assertEqual(channel3.code, HTTPStatus.OK, channel3.result)
|
||||
self.assertEqual(channel4.code, HTTPStatus.OK, channel4.result)
|
||||
|
||||
# redact messages in room 2
|
||||
channel4 = self.make_request(
|
||||
# make a request to ban all user's messages
|
||||
channel5 = self.make_request(
|
||||
"POST",
|
||||
f"/_synapse/admin/v1/user/{self.bad_user}/redact",
|
||||
content={"rooms": [self.rm2]},
|
||||
content={"rooms": []},
|
||||
access_token=self.admin_tok,
|
||||
)
|
||||
self.assertEqual(channel4.code, 200)
|
||||
id2 = channel1.json_body.get("redact_id")
|
||||
self.assertEqual(channel5.code, 200)
|
||||
id2 = channel5.json_body.get("redact_id")
|
||||
|
||||
# check that there were no failed redactions in room 2
|
||||
channel5 = self.make_request(
|
||||
channel6 = self.make_request(
|
||||
"GET",
|
||||
f"/_synapse/admin/v1/user/redact_status/{id2}",
|
||||
access_token=self.admin_tok,
|
||||
)
|
||||
self.assertEqual(channel5.code, 200)
|
||||
self.assertEqual(channel5.json_body.get("status"), "complete")
|
||||
failed_redactions = channel5.json_body.get("failed_redactions")
|
||||
self.assertEqual(channel6.code, 200)
|
||||
self.assertEqual(channel6.json_body.get("status"), "complete")
|
||||
failed_redactions = channel6.json_body.get("failed_redactions")
|
||||
self.assertEqual(failed_redactions, {})
|
||||
|
||||
# double check messages in room 2 were redacted
|
||||
filter = json.dumps({"types": [EventTypes.Redaction]})
|
||||
channel7 = self.make_request(
|
||||
"GET",
|
||||
f"rooms/{self.rm2}/messages?filter={filter}&limit=50",
|
||||
access_token=self.admin_tok,
|
||||
)
|
||||
self.assertEqual(channel7.code, 200)
|
||||
|
||||
matches = []
|
||||
for event in channel7.json_body["chunk"]:
|
||||
for event_id in originals2:
|
||||
if event["type"] == "m.room.redaction" and event["redacts"] == event_id:
|
||||
matches.append((event_id, event))
|
||||
# we redacted 6 messages
|
||||
self.assertEqual(len(matches), 6)
|
||||
|
|
Loading…
Reference in New Issue