Improvements to admin redact api (#17792)
- better validation on user input - fix an early task completion - when checking membership in rooms, check for rooms user has been banned from as well
This commit is contained in:
parent
006251a5d0
commit
a5986ac229
|
@ -0,0 +1 @@
|
||||||
|
Improve input validation and room membership checks in admin redaction API.
|
|
@ -443,8 +443,8 @@ class AdminHandler:
|
||||||
["m.room.member", "m.room.message"],
|
["m.room.member", "m.room.message"],
|
||||||
)
|
)
|
||||||
if not event_ids:
|
if not event_ids:
|
||||||
# there's nothing to redact
|
# nothing to redact in this room
|
||||||
return TaskStatus.COMPLETE, result, None
|
continue
|
||||||
|
|
||||||
events = await self._store.get_events_as_list(event_ids)
|
events = await self._store.get_events_as_list(event_ids)
|
||||||
for event in events:
|
for event in events:
|
||||||
|
|
|
@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from synapse._pydantic_compat import StrictBool
|
from synapse._pydantic_compat import StrictBool, StrictInt, StrictStr
|
||||||
from synapse.api.constants import Direction, UserTypes
|
from synapse.api.constants import Direction, UserTypes
|
||||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
|
@ -1421,40 +1421,39 @@ class RedactUser(RestServlet):
|
||||||
self._store = hs.get_datastores().main
|
self._store = hs.get_datastores().main
|
||||||
self.admin_handler = hs.get_admin_handler()
|
self.admin_handler = hs.get_admin_handler()
|
||||||
|
|
||||||
|
class PostBody(RequestBodyModel):
|
||||||
|
rooms: List[StrictStr]
|
||||||
|
reason: Optional[StrictStr]
|
||||||
|
limit: Optional[StrictInt]
|
||||||
|
|
||||||
async def on_POST(
|
async def on_POST(
|
||||||
self, request: SynapseRequest, user_id: str
|
self, request: SynapseRequest, user_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self._auth.get_user_by_req(request)
|
requester = await self._auth.get_user_by_req(request)
|
||||||
await assert_user_is_admin(self._auth, requester)
|
await assert_user_is_admin(self._auth, requester)
|
||||||
|
|
||||||
body = parse_json_object_from_request(request, allow_empty_body=True)
|
# parse provided user id to check that it is valid
|
||||||
rooms = body.get("rooms")
|
UserID.from_string(user_id)
|
||||||
if rooms is None:
|
|
||||||
raise SynapseError(
|
|
||||||
HTTPStatus.BAD_REQUEST, "Must provide a value for rooms."
|
|
||||||
)
|
|
||||||
|
|
||||||
reason = body.get("reason")
|
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||||
if reason:
|
|
||||||
if not isinstance(reason, str):
|
|
||||||
raise SynapseError(
|
|
||||||
HTTPStatus.BAD_REQUEST,
|
|
||||||
"If a reason is provided it must be a string.",
|
|
||||||
)
|
|
||||||
|
|
||||||
limit = body.get("limit")
|
limit = body.limit
|
||||||
if limit:
|
if limit and limit <= 0:
|
||||||
if not isinstance(limit, int) or limit <= 0:
|
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"If limit is provided it must be a non-negative integer greater than 0.",
|
"If limit is provided it must be a non-negative integer greater than 0.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
rooms = body.rooms
|
||||||
if not rooms:
|
if not rooms:
|
||||||
rooms = await self._store.get_rooms_for_user(user_id)
|
current_rooms = list(await self._store.get_rooms_for_user(user_id))
|
||||||
|
banned_rooms = list(
|
||||||
|
await self._store.get_rooms_user_currently_banned_from(user_id)
|
||||||
|
)
|
||||||
|
rooms = current_rooms + banned_rooms
|
||||||
|
|
||||||
redact_id = await self.admin_handler.start_redact_events(
|
redact_id = await self.admin_handler.start_redact_events(
|
||||||
user_id, list(rooms), requester.serialize(), reason, limit
|
user_id, rooms, requester.serialize(), body.reason, limit
|
||||||
)
|
)
|
||||||
|
|
||||||
return HTTPStatus.OK, {"redact_id": redact_id}
|
return HTTPStatus.OK, {"redact_id": redact_id}
|
||||||
|
|
|
@ -711,6 +711,27 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
return {row[0] for row in txn}
|
return {row[0] for row in txn}
|
||||||
|
|
||||||
|
async def get_rooms_user_currently_banned_from(
|
||||||
|
self, user_id: str
|
||||||
|
) -> FrozenSet[str]:
|
||||||
|
"""Returns a set of room_ids the user is currently banned from.
|
||||||
|
|
||||||
|
If a remote user only returns rooms this server is currently
|
||||||
|
participating in.
|
||||||
|
"""
|
||||||
|
room_ids = await self.db_pool.simple_select_onecol(
|
||||||
|
table="current_state_events",
|
||||||
|
keyvalues={
|
||||||
|
"type": EventTypes.Member,
|
||||||
|
"membership": Membership.BAN,
|
||||||
|
"state_key": user_id,
|
||||||
|
},
|
||||||
|
retcol="room_id",
|
||||||
|
desc="get_rooms_user_currently_banned_from",
|
||||||
|
)
|
||||||
|
|
||||||
|
return frozenset(room_ids)
|
||||||
|
|
||||||
@cached(max_entries=500000, iterable=True)
|
@cached(max_entries=500000, iterable=True)
|
||||||
async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]:
|
async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]:
|
||||||
"""Returns a set of room_ids the user is currently joined to.
|
"""Returns a set of room_ids the user is currently joined to.
|
||||||
|
|
|
@ -5288,19 +5288,26 @@ class UserRedactionTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(matched), len(rm2_originals))
|
self.assertEqual(len(matched), len(rm2_originals))
|
||||||
|
|
||||||
def test_admin_redact_works_if_user_kicked_or_banned(self) -> None:
|
def test_admin_redact_works_if_user_kicked_or_banned(self) -> None:
|
||||||
originals = []
|
originals1 = []
|
||||||
|
originals2 = []
|
||||||
for rm in [self.rm1, self.rm2, self.rm3]:
|
for rm in [self.rm1, self.rm2, self.rm3]:
|
||||||
join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
|
join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
|
||||||
originals.append(join["event_id"])
|
if rm in [self.rm1, self.rm3]:
|
||||||
|
originals1.append(join["event_id"])
|
||||||
|
else:
|
||||||
|
originals2.append(join["event_id"])
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
event = {"body": f"hello{i}", "msgtype": "m.text"}
|
event = {"body": f"hello{i}", "msgtype": "m.text"}
|
||||||
res = self.helper.send_event(
|
res = self.helper.send_event(
|
||||||
rm, "m.room.message", event, tok=self.bad_user_tok
|
rm, "m.room.message", event, tok=self.bad_user_tok
|
||||||
)
|
)
|
||||||
originals.append(res["event_id"])
|
if rm in [self.rm1, self.rm3]:
|
||||||
|
originals1.append(res["event_id"])
|
||||||
|
else:
|
||||||
|
originals2.append(res["event_id"])
|
||||||
|
|
||||||
# kick user from rooms 1 and 3
|
# kick user from rooms 1 and 3
|
||||||
for r in [self.rm1, self.rm2]:
|
for r in [self.rm1, self.rm3]:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
f"/_matrix/client/r0/rooms/{r}/kick",
|
f"/_matrix/client/r0/rooms/{r}/kick",
|
||||||
|
@ -5330,32 +5337,70 @@ class UserRedactionTestCase(unittest.HomeserverTestCase):
|
||||||
failed_redactions = channel2.json_body.get("failed_redactions")
|
failed_redactions = channel2.json_body.get("failed_redactions")
|
||||||
self.assertEqual(failed_redactions, {})
|
self.assertEqual(failed_redactions, {})
|
||||||
|
|
||||||
# ban user
|
# double check
|
||||||
|
for rm in [self.rm1, self.rm3]:
|
||||||
|
filter = json.dumps({"types": [EventTypes.Redaction]})
|
||||||
channel3 = self.make_request(
|
channel3 = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"rooms/{rm}/messages?filter={filter}&limit=50",
|
||||||
|
access_token=self.admin_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel3.code, 200)
|
||||||
|
|
||||||
|
matches = []
|
||||||
|
for event in channel3.json_body["chunk"]:
|
||||||
|
for event_id in originals1:
|
||||||
|
if (
|
||||||
|
event["type"] == "m.room.redaction"
|
||||||
|
and event["redacts"] == event_id
|
||||||
|
):
|
||||||
|
matches.append((event_id, event))
|
||||||
|
# we redacted 6 messages
|
||||||
|
self.assertEqual(len(matches), 6)
|
||||||
|
|
||||||
|
# ban user from room 2
|
||||||
|
channel4 = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
f"/_matrix/client/r0/rooms/{self.rm2}/ban",
|
f"/_matrix/client/r0/rooms/{self.rm2}/ban",
|
||||||
content={"reason": "being a bummer", "user_id": self.bad_user},
|
content={"reason": "being a bummer", "user_id": self.bad_user},
|
||||||
access_token=self.admin_tok,
|
access_token=self.admin_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel3.code, HTTPStatus.OK, channel3.result)
|
self.assertEqual(channel4.code, HTTPStatus.OK, channel4.result)
|
||||||
|
|
||||||
# redact messages in room 2
|
# make a request to ban all user's messages
|
||||||
channel4 = self.make_request(
|
channel5 = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
f"/_synapse/admin/v1/user/{self.bad_user}/redact",
|
f"/_synapse/admin/v1/user/{self.bad_user}/redact",
|
||||||
content={"rooms": [self.rm2]},
|
content={"rooms": []},
|
||||||
access_token=self.admin_tok,
|
access_token=self.admin_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel4.code, 200)
|
self.assertEqual(channel5.code, 200)
|
||||||
id2 = channel1.json_body.get("redact_id")
|
id2 = channel5.json_body.get("redact_id")
|
||||||
|
|
||||||
# check that there were no failed redactions in room 2
|
# check that there were no failed redactions in room 2
|
||||||
channel5 = self.make_request(
|
channel6 = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/_synapse/admin/v1/user/redact_status/{id2}",
|
f"/_synapse/admin/v1/user/redact_status/{id2}",
|
||||||
access_token=self.admin_tok,
|
access_token=self.admin_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel5.code, 200)
|
self.assertEqual(channel6.code, 200)
|
||||||
self.assertEqual(channel5.json_body.get("status"), "complete")
|
self.assertEqual(channel6.json_body.get("status"), "complete")
|
||||||
failed_redactions = channel5.json_body.get("failed_redactions")
|
failed_redactions = channel6.json_body.get("failed_redactions")
|
||||||
self.assertEqual(failed_redactions, {})
|
self.assertEqual(failed_redactions, {})
|
||||||
|
|
||||||
|
# double check messages in room 2 were redacted
|
||||||
|
filter = json.dumps({"types": [EventTypes.Redaction]})
|
||||||
|
channel7 = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"rooms/{self.rm2}/messages?filter={filter}&limit=50",
|
||||||
|
access_token=self.admin_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel7.code, 200)
|
||||||
|
|
||||||
|
matches = []
|
||||||
|
for event in channel7.json_body["chunk"]:
|
||||||
|
for event_id in originals2:
|
||||||
|
if event["type"] == "m.room.redaction" and event["redacts"] == event_id:
|
||||||
|
matches.append((event_id, event))
|
||||||
|
# we redacted 6 messages
|
||||||
|
self.assertEqual(len(matches), 6)
|
||||||
|
|
Loading…
Reference in New Issue