From fca3a541e7e5845d61c519be7223a035374ed698 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 11 Oct 2019 12:05:27 +0100 Subject: [PATCH] Port rest/admin/__init__.py to async/await --- synapse/rest/admin/__init__.py | 127 ++++++++++++++------------------- 1 file changed, 55 insertions(+), 72 deletions(-) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 81b6bd8816..f7b9483008 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -23,8 +23,6 @@ import re from six import text_type from six.moves import http_client -from twisted.internet import defer - import synapse from synapse.api.constants import Membership, UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -59,15 +57,14 @@ class UsersRestServlet(RestServlet): self.auth = hs.get_auth() self.handlers = hs.get_handlers() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) - yield assert_requester_is_admin(self.auth, request) + await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine(target_user): raise SynapseError(400, "Can only users a local user") - ret = yield self.handlers.admin_handler.get_users() + ret = await self.handlers.admin_handler.get_users() return 200, ret @@ -122,8 +119,7 @@ class UserRegisterServlet(RestServlet): self.nonces[nonce] = int(self.reactor.seconds()) return 200, {"nonce": nonce} - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): self._clear_old_nonces() if not self.hs.config.registration_shared_secret: @@ -204,14 +200,14 @@ class UserRegisterServlet(RestServlet): register = RegisterRestServlet(self.hs) - user_id = yield register.registration_handler.register_user( + user_id = await register.registration_handler.register_user( localpart=body["username"].lower(), password=body["password"], admin=bool(admin), user_type=user_type, ) - result = yield register._create_registration_details(user_id, body) + result = await register._create_registration_details(user_id, body) return 200, result @@ -223,19 +219,18 @@ class WhoisRestServlet(RestServlet): self.auth = hs.get_auth() self.handlers = hs.get_handlers() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) auth_user = requester.user if target_user != auth_user: - yield assert_user_is_admin(self.auth, auth_user) + await assert_user_is_admin(self.auth, auth_user) if not self.hs.is_mine(target_user): raise SynapseError(400, "Can only whois a local user") - ret = yield self.handlers.admin_handler.get_whois(target_user) + ret = await self.handlers.admin_handler.get_whois(target_user) return 200, ret @@ -255,9 +250,8 @@ class PurgeHistoryRestServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id): - yield assert_requester_is_admin(self.auth, request) + async def on_POST(self, request, room_id, event_id): + await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request, allow_empty_body=True) @@ -270,12 +264,12 @@ class PurgeHistoryRestServlet(RestServlet): event_id = body.get("purge_up_to_event_id") if event_id is not None: - event = yield self.store.get_event(event_id) + event = await self.store.get_event(event_id) if event.room_id != room_id: raise SynapseError(400, "Event is for wrong room.") - token = yield self.store.get_topological_token_for_event(event_id) + token = await self.store.get_topological_token_for_event(event_id) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: @@ -285,12 +279,10 @@ class PurgeHistoryRestServlet(RestServlet): 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON ) - stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts)) + stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) - r = ( - yield self.store.get_room_event_after_stream_ordering( - room_id, stream_ordering - ) + r = await self.store.get_room_event_after_stream_ordering( + room_id, stream_ordering ) if not r: logger.warn( @@ -318,7 +310,7 @@ class PurgeHistoryRestServlet(RestServlet): errcode=Codes.BAD_JSON, ) - purge_id = yield self.pagination_handler.start_purge_history( + purge_id = await self.pagination_handler.start_purge_history( room_id, token, delete_local_events=delete_local_events ) @@ -339,9 +331,8 @@ class PurgeHistoryStatusRestServlet(RestServlet): self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, purge_id): - yield assert_requester_is_admin(self.auth, request) + async def on_GET(self, request, purge_id): + await assert_requester_is_admin(self.auth, request) purge_status = self.pagination_handler.get_purge_status(purge_id) if purge_status is None: @@ -357,9 +348,8 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request, target_user_id): - yield assert_requester_is_admin(self.auth, request) + async def on_POST(self, request, target_user_id): + await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request, allow_empty_body=True) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -371,7 +361,7 @@ class DeactivateAccountRestServlet(RestServlet): UserID.from_string(target_user_id) - result = yield self._deactivate_account_handler.deactivate_account( + result = await self._deactivate_account_handler.deactivate_account( target_user_id, erase ) if result: @@ -405,10 +395,9 @@ class ShutdownRoomRestServlet(RestServlet): self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) - yield assert_user_is_admin(self.auth, requester.user) + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) content = parse_json_object_from_request(request) assert_params_in_dict(content, ["new_room_user_id"]) @@ -419,7 +408,7 @@ class ShutdownRoomRestServlet(RestServlet): message = content.get("message", self.DEFAULT_MESSAGE) room_name = content.get("room_name", "Content Violation Notification") - info = yield self._room_creation_handler.create_room( + info = await self._room_creation_handler.create_room( room_creator_requester, config={ "preset": "public_chat", @@ -438,9 +427,9 @@ class ShutdownRoomRestServlet(RestServlet): # This will work even if the room is already blocked, but that is # desirable in case the first attempt at blocking the room failed below. - yield self.store.block_room(room_id, requester_user_id) + await self.store.block_room(room_id, requester_user_id) - users = yield self.state.get_current_users_in_room(room_id) + users = await self.state.get_current_users_in_room(room_id) kicked_users = [] failed_to_kick_users = [] for user_id in users: @@ -451,7 +440,7 @@ class ShutdownRoomRestServlet(RestServlet): try: target_requester = create_requester(user_id) - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, room_id=room_id, @@ -461,9 +450,9 @@ class ShutdownRoomRestServlet(RestServlet): require_consent=False, ) - yield self.room_member_handler.forget(target_requester.user, room_id) + await self.room_member_handler.forget(target_requester.user, room_id) - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, room_id=new_room_id, @@ -480,7 +469,7 @@ class ShutdownRoomRestServlet(RestServlet): ) failed_to_kick_users.append(user_id) - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( room_creator_requester, { "type": "m.room.message", @@ -491,9 +480,9 @@ class ShutdownRoomRestServlet(RestServlet): ratelimit=False, ) - aliases_for_room = yield self.store.get_aliases_for_room(room_id) + aliases_for_room = await self.store.get_aliases_for_room(room_id) - yield self.store.update_aliases_for_room( + await self.store.update_aliases_for_room( room_id, new_room_id, requester_user_id ) @@ -532,13 +521,12 @@ class ResetPasswordRestServlet(RestServlet): self.auth = hs.get_auth() self._set_password_handler = hs.get_set_password_handler() - @defer.inlineCallbacks - def on_POST(self, request, target_user_id): + async def on_POST(self, request, target_user_id): """Post request to allow an administrator reset password for a user. This needs user to have administrator access in Synapse. """ - requester = yield self.auth.get_user_by_req(request) - yield assert_user_is_admin(self.auth, requester.user) + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) UserID.from_string(target_user_id) @@ -546,7 +534,7 @@ class ResetPasswordRestServlet(RestServlet): assert_params_in_dict(params, ["new_password"]) new_password = params["new_password"] - yield self._set_password_handler.set_password( + await self._set_password_handler.set_password( target_user_id, new_password, requester ) return 200, {} @@ -572,12 +560,11 @@ class GetUsersPaginatedRestServlet(RestServlet): self.auth = hs.get_auth() self.handlers = hs.get_handlers() - @defer.inlineCallbacks - def on_GET(self, request, target_user_id): + async def on_GET(self, request, target_user_id): """Get request to get specific number of users from Synapse. This needs user to have administrator access in Synapse. """ - yield assert_requester_is_admin(self.auth, request) + await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(target_user_id) @@ -590,11 +577,10 @@ class GetUsersPaginatedRestServlet(RestServlet): logger.info("limit: %s, start: %s", limit, start) - ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) + ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit) return 200, ret - @defer.inlineCallbacks - def on_POST(self, request, target_user_id): + async def on_POST(self, request, target_user_id): """Post request to get specific number of users from Synapse.. This needs user to have administrator access in Synapse. Example: @@ -608,7 +594,7 @@ class GetUsersPaginatedRestServlet(RestServlet): Returns: 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - yield assert_requester_is_admin(self.auth, request) + await assert_requester_is_admin(self.auth, request) UserID.from_string(target_user_id) order = "name" # order by name in user table @@ -618,7 +604,7 @@ class GetUsersPaginatedRestServlet(RestServlet): start = params["start"] logger.info("limit: %s, start: %s", limit, start) - ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) + ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit) return 200, ret @@ -641,13 +627,12 @@ class SearchUsersRestServlet(RestServlet): self.auth = hs.get_auth() self.handlers = hs.get_handlers() - @defer.inlineCallbacks - def on_GET(self, request, target_user_id): + async def on_GET(self, request, target_user_id): """Get request to search user table for specific users according to search term. This needs user to have a administrator access in Synapse. """ - yield assert_requester_is_admin(self.auth, request) + await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(target_user_id) @@ -661,7 +646,7 @@ class SearchUsersRestServlet(RestServlet): term = parse_string(request, "term", required=True) logger.info("term: %s ", term) - ret = yield self.handlers.admin_handler.search_users(term) + ret = await self.handlers.admin_handler.search_users(term) return 200, ret @@ -676,15 +661,14 @@ class DeleteGroupAdminRestServlet(RestServlet): self.is_mine_id = hs.is_mine_id self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) - yield assert_user_is_admin(self.auth, requester.user) + async def on_POST(self, request, group_id): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) if not self.is_mine_id(group_id): raise SynapseError(400, "Can only delete local groups") - yield self.group_server.delete_group(group_id, requester.user.to_string()) + await self.group_server.delete_group(group_id, requester.user.to_string()) return 200, {} @@ -700,16 +684,15 @@ class AccountValidityRenewServlet(RestServlet): self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request): - yield assert_requester_is_admin(self.auth, request) + async def on_POST(self, request): + await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request) if "user_id" not in body: raise SynapseError(400, "Missing property 'user_id' in the request body") - expiration_ts = yield self.account_activity_handler.renew_account_for_user( + expiration_ts = await self.account_activity_handler.renew_account_for_user( body["user_id"], body.get("expiration_ts"), not body.get("enable_renewal_emails", True),