Port rest/admin/__init__.py to async/await

This commit is contained in:
Erik Johnston 2019-10-11 12:05:27 +01:00
parent de3a176426
commit fca3a541e7
1 changed files with 55 additions and 72 deletions

View File

@ -23,8 +23,6 @@ import re
from six import text_type from six import text_type
from six.moves import http_client from six.moves import http_client
from twisted.internet import defer
import synapse import synapse
from synapse.api.constants import Membership, UserTypes from synapse.api.constants import Membership, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
@ -59,15 +57,14 @@ class UsersRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
yield assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user") raise SynapseError(400, "Can only users a local user")
ret = yield self.handlers.admin_handler.get_users() ret = await self.handlers.admin_handler.get_users()
return 200, ret return 200, ret
@ -122,8 +119,7 @@ class UserRegisterServlet(RestServlet):
self.nonces[nonce] = int(self.reactor.seconds()) self.nonces[nonce] = int(self.reactor.seconds())
return 200, {"nonce": nonce} return 200, {"nonce": nonce}
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
self._clear_old_nonces() self._clear_old_nonces()
if not self.hs.config.registration_shared_secret: if not self.hs.config.registration_shared_secret:
@ -204,14 +200,14 @@ class UserRegisterServlet(RestServlet):
register = RegisterRestServlet(self.hs) register = RegisterRestServlet(self.hs)
user_id = yield register.registration_handler.register_user( user_id = await register.registration_handler.register_user(
localpart=body["username"].lower(), localpart=body["username"].lower(),
password=body["password"], password=body["password"],
admin=bool(admin), admin=bool(admin),
user_type=user_type, user_type=user_type,
) )
result = yield register._create_registration_details(user_id, body) result = await register._create_registration_details(user_id, body)
return 200, result return 200, result
@ -223,19 +219,18 @@ class WhoisRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
auth_user = requester.user auth_user = requester.user
if target_user != auth_user: if target_user != auth_user:
yield assert_user_is_admin(self.auth, auth_user) await assert_user_is_admin(self.auth, auth_user)
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only whois a local user") raise SynapseError(400, "Can only whois a local user")
ret = yield self.handlers.admin_handler.get_whois(target_user) ret = await self.handlers.admin_handler.get_whois(target_user)
return 200, ret return 200, ret
@ -255,9 +250,8 @@ class PurgeHistoryRestServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request, room_id, event_id):
def on_POST(self, request, room_id, event_id): await assert_requester_is_admin(self.auth, request)
yield assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True) body = parse_json_object_from_request(request, allow_empty_body=True)
@ -270,12 +264,12 @@ class PurgeHistoryRestServlet(RestServlet):
event_id = body.get("purge_up_to_event_id") event_id = body.get("purge_up_to_event_id")
if event_id is not None: if event_id is not None:
event = yield self.store.get_event(event_id) event = await self.store.get_event(event_id)
if event.room_id != room_id: if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.") raise SynapseError(400, "Event is for wrong room.")
token = yield self.store.get_topological_token_for_event(event_id) token = await self.store.get_topological_token_for_event(event_id)
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
elif "purge_up_to_ts" in body: elif "purge_up_to_ts" in body:
@ -285,12 +279,10 @@ class PurgeHistoryRestServlet(RestServlet):
400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON
) )
stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts)) stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
r = ( r = await self.store.get_room_event_after_stream_ordering(
yield self.store.get_room_event_after_stream_ordering( room_id, stream_ordering
room_id, stream_ordering
)
) )
if not r: if not r:
logger.warn( logger.warn(
@ -318,7 +310,7 @@ class PurgeHistoryRestServlet(RestServlet):
errcode=Codes.BAD_JSON, errcode=Codes.BAD_JSON,
) )
purge_id = yield self.pagination_handler.start_purge_history( purge_id = await self.pagination_handler.start_purge_history(
room_id, token, delete_local_events=delete_local_events room_id, token, delete_local_events=delete_local_events
) )
@ -339,9 +331,8 @@ class PurgeHistoryStatusRestServlet(RestServlet):
self.pagination_handler = hs.get_pagination_handler() self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, purge_id):
def on_GET(self, request, purge_id): await assert_requester_is_admin(self.auth, request)
yield assert_requester_is_admin(self.auth, request)
purge_status = self.pagination_handler.get_purge_status(purge_id) purge_status = self.pagination_handler.get_purge_status(purge_id)
if purge_status is None: if purge_status is None:
@ -357,9 +348,8 @@ class DeactivateAccountRestServlet(RestServlet):
self._deactivate_account_handler = hs.get_deactivate_account_handler() self._deactivate_account_handler = hs.get_deactivate_account_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request, target_user_id):
def on_POST(self, request, target_user_id): await assert_requester_is_admin(self.auth, request)
yield assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True) body = parse_json_object_from_request(request, allow_empty_body=True)
erase = body.get("erase", False) erase = body.get("erase", False)
if not isinstance(erase, bool): if not isinstance(erase, bool):
@ -371,7 +361,7 @@ class DeactivateAccountRestServlet(RestServlet):
UserID.from_string(target_user_id) UserID.from_string(target_user_id)
result = yield self._deactivate_account_handler.deactivate_account( result = await self._deactivate_account_handler.deactivate_account(
target_user_id, erase target_user_id, erase
) )
if result: if result:
@ -405,10 +395,9 @@ class ShutdownRoomRestServlet(RestServlet):
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request, room_id):
def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user)
yield assert_user_is_admin(self.auth, requester.user)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert_params_in_dict(content, ["new_room_user_id"]) assert_params_in_dict(content, ["new_room_user_id"])
@ -419,7 +408,7 @@ class ShutdownRoomRestServlet(RestServlet):
message = content.get("message", self.DEFAULT_MESSAGE) message = content.get("message", self.DEFAULT_MESSAGE)
room_name = content.get("room_name", "Content Violation Notification") room_name = content.get("room_name", "Content Violation Notification")
info = yield self._room_creation_handler.create_room( info = await self._room_creation_handler.create_room(
room_creator_requester, room_creator_requester,
config={ config={
"preset": "public_chat", "preset": "public_chat",
@ -438,9 +427,9 @@ class ShutdownRoomRestServlet(RestServlet):
# This will work even if the room is already blocked, but that is # This will work even if the room is already blocked, but that is
# desirable in case the first attempt at blocking the room failed below. # desirable in case the first attempt at blocking the room failed below.
yield self.store.block_room(room_id, requester_user_id) await self.store.block_room(room_id, requester_user_id)
users = yield self.state.get_current_users_in_room(room_id) users = await self.state.get_current_users_in_room(room_id)
kicked_users = [] kicked_users = []
failed_to_kick_users = [] failed_to_kick_users = []
for user_id in users: for user_id in users:
@ -451,7 +440,7 @@ class ShutdownRoomRestServlet(RestServlet):
try: try:
target_requester = create_requester(user_id) target_requester = create_requester(user_id)
yield self.room_member_handler.update_membership( await self.room_member_handler.update_membership(
requester=target_requester, requester=target_requester,
target=target_requester.user, target=target_requester.user,
room_id=room_id, room_id=room_id,
@ -461,9 +450,9 @@ class ShutdownRoomRestServlet(RestServlet):
require_consent=False, require_consent=False,
) )
yield self.room_member_handler.forget(target_requester.user, room_id) await self.room_member_handler.forget(target_requester.user, room_id)
yield self.room_member_handler.update_membership( await self.room_member_handler.update_membership(
requester=target_requester, requester=target_requester,
target=target_requester.user, target=target_requester.user,
room_id=new_room_id, room_id=new_room_id,
@ -480,7 +469,7 @@ class ShutdownRoomRestServlet(RestServlet):
) )
failed_to_kick_users.append(user_id) failed_to_kick_users.append(user_id)
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
room_creator_requester, room_creator_requester,
{ {
"type": "m.room.message", "type": "m.room.message",
@ -491,9 +480,9 @@ class ShutdownRoomRestServlet(RestServlet):
ratelimit=False, ratelimit=False,
) )
aliases_for_room = yield self.store.get_aliases_for_room(room_id) aliases_for_room = await self.store.get_aliases_for_room(room_id)
yield self.store.update_aliases_for_room( await self.store.update_aliases_for_room(
room_id, new_room_id, requester_user_id room_id, new_room_id, requester_user_id
) )
@ -532,13 +521,12 @@ class ResetPasswordRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._set_password_handler = hs.get_set_password_handler() self._set_password_handler = hs.get_set_password_handler()
@defer.inlineCallbacks async def on_POST(self, request, target_user_id):
def on_POST(self, request, target_user_id):
"""Post request to allow an administrator reset password for a user. """Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse. This needs user to have administrator access in Synapse.
""" """
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
yield assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
UserID.from_string(target_user_id) UserID.from_string(target_user_id)
@ -546,7 +534,7 @@ class ResetPasswordRestServlet(RestServlet):
assert_params_in_dict(params, ["new_password"]) assert_params_in_dict(params, ["new_password"])
new_password = params["new_password"] new_password = params["new_password"]
yield self._set_password_handler.set_password( await self._set_password_handler.set_password(
target_user_id, new_password, requester target_user_id, new_password, requester
) )
return 200, {} return 200, {}
@ -572,12 +560,11 @@ class GetUsersPaginatedRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
@defer.inlineCallbacks async def on_GET(self, request, target_user_id):
def on_GET(self, request, target_user_id):
"""Get request to get specific number of users from Synapse. """Get request to get specific number of users from Synapse.
This needs user to have administrator access in Synapse. This needs user to have administrator access in Synapse.
""" """
yield assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(target_user_id) target_user = UserID.from_string(target_user_id)
@ -590,11 +577,10 @@ class GetUsersPaginatedRestServlet(RestServlet):
logger.info("limit: %s, start: %s", limit, start) logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
return 200, ret return 200, ret
@defer.inlineCallbacks async def on_POST(self, request, target_user_id):
def on_POST(self, request, target_user_id):
"""Post request to get specific number of users from Synapse.. """Post request to get specific number of users from Synapse..
This needs user to have administrator access in Synapse. This needs user to have administrator access in Synapse.
Example: Example:
@ -608,7 +594,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
Returns: Returns:
200 OK with json object {list[dict[str, Any]], count} or empty object. 200 OK with json object {list[dict[str, Any]], count} or empty object.
""" """
yield assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
UserID.from_string(target_user_id) UserID.from_string(target_user_id)
order = "name" # order by name in user table order = "name" # order by name in user table
@ -618,7 +604,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
start = params["start"] start = params["start"]
logger.info("limit: %s, start: %s", limit, start) logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
return 200, ret return 200, ret
@ -641,13 +627,12 @@ class SearchUsersRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
@defer.inlineCallbacks async def on_GET(self, request, target_user_id):
def on_GET(self, request, target_user_id):
"""Get request to search user table for specific users according to """Get request to search user table for specific users according to
search term. search term.
This needs user to have a administrator access in Synapse. This needs user to have a administrator access in Synapse.
""" """
yield assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(target_user_id) target_user = UserID.from_string(target_user_id)
@ -661,7 +646,7 @@ class SearchUsersRestServlet(RestServlet):
term = parse_string(request, "term", required=True) term = parse_string(request, "term", required=True)
logger.info("term: %s ", term) logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users(term) ret = await self.handlers.admin_handler.search_users(term)
return 200, ret return 200, ret
@ -676,15 +661,14 @@ class DeleteGroupAdminRestServlet(RestServlet):
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request, group_id):
def on_POST(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user)
yield assert_user_is_admin(self.auth, requester.user)
if not self.is_mine_id(group_id): if not self.is_mine_id(group_id):
raise SynapseError(400, "Can only delete local groups") raise SynapseError(400, "Can only delete local groups")
yield self.group_server.delete_group(group_id, requester.user.to_string()) await self.group_server.delete_group(group_id, requester.user.to_string())
return 200, {} return 200, {}
@ -700,16 +684,15 @@ class AccountValidityRenewServlet(RestServlet):
self.account_activity_handler = hs.get_account_validity_handler() self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): await assert_requester_is_admin(self.auth, request)
yield assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
if "user_id" not in body: if "user_id" not in body:
raise SynapseError(400, "Missing property 'user_id' in the request body") raise SynapseError(400, "Missing property 'user_id' in the request body")
expiration_ts = yield self.account_activity_handler.renew_account_for_user( expiration_ts = await self.account_activity_handler.renew_account_for_user(
body["user_id"], body["user_id"],
body.get("expiration_ts"), body.get("expiration_ts"),
not body.get("enable_renewal_emails", True), not body.get("enable_renewal_emails", True),