Merge pull request #6196 from matrix-org/erikj/await

Move rest/admin to use async/await.
This commit is contained in:
Erik Johnston 2019-10-18 11:53:02 +02:00 committed by GitHub
commit d98029ea89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 114 additions and 114 deletions

1
changelog.d/6196.misc Normal file
View File

@ -0,0 +1 @@
Port synapse.rest.admin module to use async/await.

View File

@ -23,8 +23,6 @@ import re
from six import text_type from six import text_type
from six.moves import http_client from six.moves import http_client
from twisted.internet import defer
import synapse import synapse
from synapse.api.constants import Membership, UserTypes from synapse.api.constants import Membership, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
@ -46,6 +44,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import UserAdminServlet from synapse.rest.admin.users import UserAdminServlet
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -59,15 +58,14 @@ class UsersRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
yield assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user") raise SynapseError(400, "Can only users a local user")
ret = yield self.handlers.admin_handler.get_users() ret = await self.handlers.admin_handler.get_users()
return 200, ret return 200, ret
@ -122,8 +120,7 @@ class UserRegisterServlet(RestServlet):
self.nonces[nonce] = int(self.reactor.seconds()) self.nonces[nonce] = int(self.reactor.seconds())
return 200, {"nonce": nonce} return 200, {"nonce": nonce}
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
self._clear_old_nonces() self._clear_old_nonces()
if not self.hs.config.registration_shared_secret: if not self.hs.config.registration_shared_secret:
@ -204,14 +201,14 @@ class UserRegisterServlet(RestServlet):
register = RegisterRestServlet(self.hs) register = RegisterRestServlet(self.hs)
user_id = yield register.registration_handler.register_user( user_id = await register.registration_handler.register_user(
localpart=body["username"].lower(), localpart=body["username"].lower(),
password=body["password"], password=body["password"],
admin=bool(admin), admin=bool(admin),
user_type=user_type, user_type=user_type,
) )
result = yield register._create_registration_details(user_id, body) result = await register._create_registration_details(user_id, body)
return 200, result return 200, result
@ -223,19 +220,18 @@ class WhoisRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
auth_user = requester.user auth_user = requester.user
if target_user != auth_user: if target_user != auth_user:
yield assert_user_is_admin(self.auth, auth_user) await assert_user_is_admin(self.auth, auth_user)
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only whois a local user") raise SynapseError(400, "Can only whois a local user")
ret = yield self.handlers.admin_handler.get_whois(target_user) ret = await self.handlers.admin_handler.get_whois(target_user)
return 200, ret return 200, ret
@ -255,9 +251,8 @@ class PurgeHistoryRestServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request, room_id, event_id):
def on_POST(self, request, room_id, event_id): await assert_requester_is_admin(self.auth, request)
yield assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True) body = parse_json_object_from_request(request, allow_empty_body=True)
@ -270,12 +265,12 @@ class PurgeHistoryRestServlet(RestServlet):
event_id = body.get("purge_up_to_event_id") event_id = body.get("purge_up_to_event_id")
if event_id is not None: if event_id is not None:
event = yield self.store.get_event(event_id) event = await self.store.get_event(event_id)
if event.room_id != room_id: if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.") raise SynapseError(400, "Event is for wrong room.")
token = yield self.store.get_topological_token_for_event(event_id) token = await self.store.get_topological_token_for_event(event_id)
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
elif "purge_up_to_ts" in body: elif "purge_up_to_ts" in body:
@ -285,12 +280,10 @@ class PurgeHistoryRestServlet(RestServlet):
400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON
) )
stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts)) stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
r = ( r = await self.store.get_room_event_after_stream_ordering(
yield self.store.get_room_event_after_stream_ordering( room_id, stream_ordering
room_id, stream_ordering
)
) )
if not r: if not r:
logger.warn( logger.warn(
@ -318,7 +311,7 @@ class PurgeHistoryRestServlet(RestServlet):
errcode=Codes.BAD_JSON, errcode=Codes.BAD_JSON,
) )
purge_id = yield self.pagination_handler.start_purge_history( purge_id = self.pagination_handler.start_purge_history(
room_id, token, delete_local_events=delete_local_events room_id, token, delete_local_events=delete_local_events
) )
@ -339,9 +332,8 @@ class PurgeHistoryStatusRestServlet(RestServlet):
self.pagination_handler = hs.get_pagination_handler() self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, purge_id):
def on_GET(self, request, purge_id): await assert_requester_is_admin(self.auth, request)
yield assert_requester_is_admin(self.auth, request)
purge_status = self.pagination_handler.get_purge_status(purge_id) purge_status = self.pagination_handler.get_purge_status(purge_id)
if purge_status is None: if purge_status is None:
@ -357,9 +349,8 @@ class DeactivateAccountRestServlet(RestServlet):
self._deactivate_account_handler = hs.get_deactivate_account_handler() self._deactivate_account_handler = hs.get_deactivate_account_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request, target_user_id):
def on_POST(self, request, target_user_id): await assert_requester_is_admin(self.auth, request)
yield assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True) body = parse_json_object_from_request(request, allow_empty_body=True)
erase = body.get("erase", False) erase = body.get("erase", False)
if not isinstance(erase, bool): if not isinstance(erase, bool):
@ -371,7 +362,7 @@ class DeactivateAccountRestServlet(RestServlet):
UserID.from_string(target_user_id) UserID.from_string(target_user_id)
result = yield self._deactivate_account_handler.deactivate_account( result = await self._deactivate_account_handler.deactivate_account(
target_user_id, erase target_user_id, erase
) )
if result: if result:
@ -405,10 +396,9 @@ class ShutdownRoomRestServlet(RestServlet):
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request, room_id):
def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user)
yield assert_user_is_admin(self.auth, requester.user)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert_params_in_dict(content, ["new_room_user_id"]) assert_params_in_dict(content, ["new_room_user_id"])
@ -419,7 +409,7 @@ class ShutdownRoomRestServlet(RestServlet):
message = content.get("message", self.DEFAULT_MESSAGE) message = content.get("message", self.DEFAULT_MESSAGE)
room_name = content.get("room_name", "Content Violation Notification") room_name = content.get("room_name", "Content Violation Notification")
info = yield self._room_creation_handler.create_room( info = await self._room_creation_handler.create_room(
room_creator_requester, room_creator_requester,
config={ config={
"preset": "public_chat", "preset": "public_chat",
@ -438,9 +428,9 @@ class ShutdownRoomRestServlet(RestServlet):
# This will work even if the room is already blocked, but that is # This will work even if the room is already blocked, but that is
# desirable in case the first attempt at blocking the room failed below. # desirable in case the first attempt at blocking the room failed below.
yield self.store.block_room(room_id, requester_user_id) await self.store.block_room(room_id, requester_user_id)
users = yield self.state.get_current_users_in_room(room_id) users = await self.state.get_current_users_in_room(room_id)
kicked_users = [] kicked_users = []
failed_to_kick_users = [] failed_to_kick_users = []
for user_id in users: for user_id in users:
@ -451,7 +441,7 @@ class ShutdownRoomRestServlet(RestServlet):
try: try:
target_requester = create_requester(user_id) target_requester = create_requester(user_id)
yield self.room_member_handler.update_membership( await self.room_member_handler.update_membership(
requester=target_requester, requester=target_requester,
target=target_requester.user, target=target_requester.user,
room_id=room_id, room_id=room_id,
@ -461,9 +451,9 @@ class ShutdownRoomRestServlet(RestServlet):
require_consent=False, require_consent=False,
) )
yield self.room_member_handler.forget(target_requester.user, room_id) await self.room_member_handler.forget(target_requester.user, room_id)
yield self.room_member_handler.update_membership( await self.room_member_handler.update_membership(
requester=target_requester, requester=target_requester,
target=target_requester.user, target=target_requester.user,
room_id=new_room_id, room_id=new_room_id,
@ -480,7 +470,7 @@ class ShutdownRoomRestServlet(RestServlet):
) )
failed_to_kick_users.append(user_id) failed_to_kick_users.append(user_id)
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
room_creator_requester, room_creator_requester,
{ {
"type": "m.room.message", "type": "m.room.message",
@ -491,9 +481,11 @@ class ShutdownRoomRestServlet(RestServlet):
ratelimit=False, ratelimit=False,
) )
aliases_for_room = yield self.store.get_aliases_for_room(room_id) aliases_for_room = await maybe_awaitable(
self.store.get_aliases_for_room(room_id)
)
yield self.store.update_aliases_for_room( await self.store.update_aliases_for_room(
room_id, new_room_id, requester_user_id room_id, new_room_id, requester_user_id
) )
@ -532,13 +524,12 @@ class ResetPasswordRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._set_password_handler = hs.get_set_password_handler() self._set_password_handler = hs.get_set_password_handler()
@defer.inlineCallbacks async def on_POST(self, request, target_user_id):
def on_POST(self, request, target_user_id):
"""Post request to allow an administrator reset password for a user. """Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse. This needs user to have administrator access in Synapse.
""" """
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
yield assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
UserID.from_string(target_user_id) UserID.from_string(target_user_id)
@ -546,7 +537,7 @@ class ResetPasswordRestServlet(RestServlet):
assert_params_in_dict(params, ["new_password"]) assert_params_in_dict(params, ["new_password"])
new_password = params["new_password"] new_password = params["new_password"]
yield self._set_password_handler.set_password( await self._set_password_handler.set_password(
target_user_id, new_password, requester target_user_id, new_password, requester
) )
return 200, {} return 200, {}
@ -572,12 +563,11 @@ class GetUsersPaginatedRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
@defer.inlineCallbacks async def on_GET(self, request, target_user_id):
def on_GET(self, request, target_user_id):
"""Get request to get specific number of users from Synapse. """Get request to get specific number of users from Synapse.
This needs user to have administrator access in Synapse. This needs user to have administrator access in Synapse.
""" """
yield assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(target_user_id) target_user = UserID.from_string(target_user_id)
@ -590,11 +580,10 @@ class GetUsersPaginatedRestServlet(RestServlet):
logger.info("limit: %s, start: %s", limit, start) logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
return 200, ret return 200, ret
@defer.inlineCallbacks async def on_POST(self, request, target_user_id):
def on_POST(self, request, target_user_id):
"""Post request to get specific number of users from Synapse.. """Post request to get specific number of users from Synapse..
This needs user to have administrator access in Synapse. This needs user to have administrator access in Synapse.
Example: Example:
@ -608,7 +597,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
Returns: Returns:
200 OK with json object {list[dict[str, Any]], count} or empty object. 200 OK with json object {list[dict[str, Any]], count} or empty object.
""" """
yield assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
UserID.from_string(target_user_id) UserID.from_string(target_user_id)
order = "name" # order by name in user table order = "name" # order by name in user table
@ -618,7 +607,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
start = params["start"] start = params["start"]
logger.info("limit: %s, start: %s", limit, start) logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
return 200, ret return 200, ret
@ -641,13 +630,12 @@ class SearchUsersRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
@defer.inlineCallbacks async def on_GET(self, request, target_user_id):
def on_GET(self, request, target_user_id):
"""Get request to search user table for specific users according to """Get request to search user table for specific users according to
search term. search term.
This needs user to have a administrator access in Synapse. This needs user to have a administrator access in Synapse.
""" """
yield assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(target_user_id) target_user = UserID.from_string(target_user_id)
@ -661,7 +649,7 @@ class SearchUsersRestServlet(RestServlet):
term = parse_string(request, "term", required=True) term = parse_string(request, "term", required=True)
logger.info("term: %s ", term) logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users(term) ret = await self.handlers.admin_handler.search_users(term)
return 200, ret return 200, ret
@ -676,15 +664,14 @@ class DeleteGroupAdminRestServlet(RestServlet):
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request, group_id):
def on_POST(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user)
yield assert_user_is_admin(self.auth, requester.user)
if not self.is_mine_id(group_id): if not self.is_mine_id(group_id):
raise SynapseError(400, "Can only delete local groups") raise SynapseError(400, "Can only delete local groups")
yield self.group_server.delete_group(group_id, requester.user.to_string()) await self.group_server.delete_group(group_id, requester.user.to_string())
return 200, {} return 200, {}
@ -700,16 +687,15 @@ class AccountValidityRenewServlet(RestServlet):
self.account_activity_handler = hs.get_account_validity_handler() self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): await assert_requester_is_admin(self.auth, request)
yield assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
if "user_id" not in body: if "user_id" not in body:
raise SynapseError(400, "Missing property 'user_id' in the request body") raise SynapseError(400, "Missing property 'user_id' in the request body")
expiration_ts = yield self.account_activity_handler.renew_account_for_user( expiration_ts = await self.account_activity_handler.renew_account_for_user(
body["user_id"], body["user_id"],
body.get("expiration_ts"), body.get("expiration_ts"),
not body.get("enable_renewal_emails", True), not body.get("enable_renewal_emails", True),

View File

@ -15,8 +15,6 @@
import re import re
from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
@ -42,8 +40,7 @@ def historical_admin_path_patterns(path_regex):
) )
@defer.inlineCallbacks async def assert_requester_is_admin(auth, request):
def assert_requester_is_admin(auth, request):
"""Verify that the requester is an admin user """Verify that the requester is an admin user
WARNING: MAKE SURE YOU YIELD ON THE RESULT! WARNING: MAKE SURE YOU YIELD ON THE RESULT!
@ -58,12 +55,11 @@ def assert_requester_is_admin(auth, request):
Raises: Raises:
AuthError if the requester is not an admin AuthError if the requester is not an admin
""" """
requester = yield auth.get_user_by_req(request) requester = await auth.get_user_by_req(request)
yield assert_user_is_admin(auth, requester.user) await assert_user_is_admin(auth, requester.user)
@defer.inlineCallbacks async def assert_user_is_admin(auth, user_id):
def assert_user_is_admin(auth, user_id):
"""Verify that the given user is an admin user """Verify that the given user is an admin user
WARNING: MAKE SURE YOU YIELD ON THE RESULT! WARNING: MAKE SURE YOU YIELD ON THE RESULT!
@ -79,6 +75,6 @@ def assert_user_is_admin(auth, user_id):
AuthError if the user is not an admin AuthError if the user is not an admin
""" """
is_admin = yield auth.is_server_admin(user_id) is_admin = await auth.is_server_admin(user_id)
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")

View File

@ -16,8 +16,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_integer from synapse.http.servlet import RestServlet, parse_integer
from synapse.rest.admin._base import ( from synapse.rest.admin._base import (
@ -40,12 +38,11 @@ class QuarantineMediaInRoom(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request, room_id):
def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user)
yield assert_user_is_admin(self.auth, requester.user)
num_quarantined = yield self.store.quarantine_media_ids_in_room( num_quarantined = await self.store.quarantine_media_ids_in_room(
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
@ -62,14 +59,13 @@ class ListMediaInRoom(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, room_id):
def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request) is_admin = await self.auth.is_server_admin(requester.user)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id) local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
return 200, {"local": local_mxcs, "remote": remote_mxcs} return 200, {"local": local_mxcs, "remote": remote_mxcs}
@ -81,14 +77,13 @@ class PurgeMediaCacheRestServlet(RestServlet):
self.media_repository = hs.get_media_repository() self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): await assert_requester_is_admin(self.auth, request)
yield assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True) before_ts = parse_integer(request, "before_ts", required=True)
logger.info("before_ts: %r", before_ts) logger.info("before_ts: %r", before_ts)
ret = yield self.media_repository.delete_old_remote_media(before_ts) ret = await self.media_repository.delete_old_remote_media(before_ts)
return 200, ret return 200, ret

View File

@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
import re import re
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -69,9 +67,8 @@ class SendServerNoticeServlet(RestServlet):
self.__class__.__name__, self.__class__.__name__,
) )
@defer.inlineCallbacks async def on_POST(self, request, txn_id=None):
def on_POST(self, request, txn_id=None): await assert_requester_is_admin(self.auth, request)
yield assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content")) assert_params_in_dict(body, ("user_id", "content"))
event_type = body.get("type", EventTypes.Message) event_type = body.get("type", EventTypes.Message)
@ -85,7 +82,7 @@ class SendServerNoticeServlet(RestServlet):
if not self.hs.is_mine_id(user_id): if not self.hs.is_mine_id(user_id):
raise SynapseError(400, "Server notices can only be sent to local users") raise SynapseError(400, "Server notices can only be sent to local users")
event = yield self.snm.send_notice( event = await self.snm.send_notice(
user_id=body["user_id"], user_id=body["user_id"],
type=event_type, type=event_type,
state_key=state_key, state_key=state_key,

View File

@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
import re import re
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -59,24 +57,22 @@ class UserAdminServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id): await assert_requester_is_admin(self.auth, request)
yield assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Only local users can be admins of this homeserver") raise SynapseError(400, "Only local users can be admins of this homeserver")
is_admin = yield self.handlers.admin_handler.get_user_server_admin(target_user) is_admin = await self.handlers.admin_handler.get_user_server_admin(target_user)
is_admin = bool(is_admin) is_admin = bool(is_admin)
return 200, {"admin": is_admin} return 200, {"admin": is_admin}
@defer.inlineCallbacks async def on_PUT(self, request, user_id):
def on_PUT(self, request, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user)
yield assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user auth_user = requester.user
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -93,7 +89,7 @@ class UserAdminServlet(RestServlet):
if target_user == auth_user and not set_admin_to: if target_user == auth_user and not set_admin_to:
raise SynapseError(400, "You may not demote yourself.") raise SynapseError(400, "You may not demote yourself.")
yield self.handlers.admin_handler.set_user_server_admin( await self.handlers.admin_handler.set_user_server_admin(
target_user, set_admin_to target_user, set_admin_to
) )

View File

@ -21,6 +21,8 @@ from typing import Dict, Sequence, Set, Union
from six.moves import range from six.moves import range
import attr
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import CancelledError from twisted.internet.defer import CancelledError
from twisted.python import failure from twisted.python import failure
@ -483,3 +485,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
deferred.addCallbacks(success_cb, failure_cb) deferred.addCallbacks(success_cb, failure_cb)
return new_d return new_d
@attr.s(slots=True, frozen=True)
class DoneAwaitable(object):
"""Simple awaitable that returns the provided value.
"""
value = attr.ib()
def __await__(self):
return self
def __iter__(self):
return self
def __next__(self):
raise StopIteration(self.value)
def maybe_awaitable(value):
"""Convert a value to an awaitable if not already an awaitable.
"""
if hasattr(value, "__await__"):
return value
return DoneAwaitable(value)