Merge pull request #478 from matrix-org/daniel/userobject
Introduce a User object I'm sick of passing around more and more things as tuple items around the whole world, and needing to edit every call site every time there is more information about a user. So pass them around together as an object. This object has incredibly poorly named fields because we have a convention that `user` indicates a UserID object, and `user_id` indicates a string. I tried to clean up the whole repo to fix this, but gave up. So instead, I introduce a second convention. A user_object is a User, and a user_id_object is a UserId. I may have cried a little bit.
This commit is contained in:
commit
42aa1f3f33
|
@ -22,7 +22,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
||||
from synapse.types import RoomID, UserID, EventID
|
||||
from synapse.types import Requester, RoomID, UserID, EventID
|
||||
from synapse.util.logutils import log_function
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
|
@ -534,7 +534,9 @@ class Auth(object):
|
|||
|
||||
request.authenticated_entity = user_id
|
||||
|
||||
defer.returnValue((UserID.from_string(user_id), "", False))
|
||||
defer.returnValue(
|
||||
Requester(UserID.from_string(user_id), "", False)
|
||||
)
|
||||
return
|
||||
except KeyError:
|
||||
pass # normal users won't have the user_id query parameter set.
|
||||
|
@ -564,7 +566,7 @@ class Auth(object):
|
|||
|
||||
request.authenticated_entity = user.to_string()
|
||||
|
||||
defer.returnValue((user, token_id, is_guest,))
|
||||
defer.returnValue(Requester(user, token_id, is_guest))
|
||||
except KeyError:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
||||
|
|
|
@ -31,8 +31,9 @@ class WhoisRestServlet(ClientV1RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
target_user = UserID.from_string(user_id)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(auth_user)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
auth_user = requester.user
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_admin and target_user != auth_user:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
|
|
@ -69,9 +69,9 @@ class ClientDirectoryServer(ClientV1RestServlet):
|
|||
|
||||
try:
|
||||
# try to auth as a user
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
try:
|
||||
user_id = user.to_string()
|
||||
user_id = requester.user.to_string()
|
||||
yield dir_handler.create_association(
|
||||
user_id, room_alias, room_id, servers
|
||||
)
|
||||
|
@ -116,8 +116,8 @@ class ClientDirectoryServer(ClientV1RestServlet):
|
|||
# fallback to default user behaviour if they aren't an AS
|
||||
pass
|
||||
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = requester.user
|
||||
is_admin = yield self.auth.is_server_admin(user)
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You need to be a server admin")
|
||||
|
|
|
@ -34,10 +34,11 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
auth_user, _, is_guest = yield self.auth.get_user_by_req(
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
allow_guest=True
|
||||
allow_guest=True,
|
||||
)
|
||||
is_guest = requester.is_guest
|
||||
room_id = None
|
||||
if is_guest:
|
||||
if "room_id" not in request.args:
|
||||
|
@ -56,9 +57,13 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||
as_client_event = "raw" not in request.args
|
||||
|
||||
chunk = yield handler.get_stream(
|
||||
auth_user.to_string(), pagin_config, timeout=timeout,
|
||||
as_client_event=as_client_event, affect_presence=(not is_guest),
|
||||
room_id=room_id, is_guest=is_guest
|
||||
requester.user.to_string(),
|
||||
pagin_config,
|
||||
timeout=timeout,
|
||||
as_client_event=as_client_event,
|
||||
affect_presence=(not is_guest),
|
||||
room_id=room_id,
|
||||
is_guest=is_guest,
|
||||
)
|
||||
except:
|
||||
logger.exception("Event stream failed")
|
||||
|
@ -80,9 +85,9 @@ class EventRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, event_id):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
handler = self.handlers.event_handler
|
||||
event = yield handler.get_event(auth_user, event_id)
|
||||
event = yield handler.get_event(requester.user, event_id)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
if event:
|
||||
|
|
|
@ -25,13 +25,13 @@ class InitialSyncRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
as_client_event = "raw" not in request.args
|
||||
pagination_config = PaginationConfig.from_request(request)
|
||||
handler = self.handlers.message_handler
|
||||
include_archived = request.args.get("archived", None) == ["true"]
|
||||
content = yield handler.snapshot_all_rooms(
|
||||
user_id=user.to_string(),
|
||||
user_id=requester.user.to_string(),
|
||||
pagin_config=pagination_config,
|
||||
as_client_event=as_client_event,
|
||||
include_archived=include_archived,
|
||||
|
|
|
@ -32,17 +32,17 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
state = yield self.handlers.presence_handler.get_state(
|
||||
target_user=user, auth_user=auth_user)
|
||||
target_user=user, auth_user=requester.user)
|
||||
|
||||
defer.returnValue((200, state))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
state = {}
|
||||
|
@ -64,7 +64,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
|||
raise SynapseError(400, "Unable to parse state")
|
||||
|
||||
yield self.handlers.presence_handler.set_state(
|
||||
target_user=user, auth_user=auth_user, state=state)
|
||||
target_user=user, auth_user=requester.user, state=state)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -77,13 +77,13 @@ class PresenceListRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
if not self.hs.is_mine(user):
|
||||
raise SynapseError(400, "User not hosted on this Home Server")
|
||||
|
||||
if auth_user != user:
|
||||
if requester.user != user:
|
||||
raise SynapseError(400, "Cannot get another user's presence list")
|
||||
|
||||
presence = yield self.handlers.presence_handler.get_presence_list(
|
||||
|
@ -97,13 +97,13 @@ class PresenceListRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
if not self.hs.is_mine(user):
|
||||
raise SynapseError(400, "User not hosted on this Home Server")
|
||||
|
||||
if auth_user != user:
|
||||
if requester.user != user:
|
||||
raise SynapseError(
|
||||
400, "Cannot modify another user's presence list")
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
try:
|
||||
|
@ -47,7 +47,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
|||
defer.returnValue((400, "Unable to parse name"))
|
||||
|
||||
yield self.handlers.profile_handler.set_displayname(
|
||||
user, auth_user, new_name)
|
||||
user, requester.user, new_name)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
try:
|
||||
|
@ -80,7 +80,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
|||
defer.returnValue((400, "Unable to parse name"))
|
||||
|
||||
yield self.handlers.profile_handler.set_avatar_url(
|
||||
user, auth_user, new_name)
|
||||
user, requester.user, new_name)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
except InvalidRuleException as e:
|
||||
raise SynapseError(400, e.message)
|
||||
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
|
||||
raise SynapseError(400, "rule_id may not contain slashes")
|
||||
|
@ -51,7 +51,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
content = _parse_json(request)
|
||||
|
||||
if 'attr' in spec:
|
||||
self.set_rule_attr(user.to_string(), spec, content)
|
||||
self.set_rule_attr(requester.user, spec, content)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
try:
|
||||
|
@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
|
||||
try:
|
||||
yield self.hs.get_datastore().add_push_rule(
|
||||
user_name=user.to_string(),
|
||||
user_name=requester.user.to_string(),
|
||||
rule_id=_namespaced_rule_id_from_spec(spec),
|
||||
priority_class=priority_class,
|
||||
conditions=conditions,
|
||||
|
@ -92,13 +92,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
def on_DELETE(self, request):
|
||||
spec = _rule_spec_from_path(request.postpath)
|
||||
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||
|
||||
try:
|
||||
yield self.hs.get_datastore().delete_push_rule(
|
||||
user.to_string(), namespaced_rule_id
|
||||
requester.user.to_string(), namespaced_rule_id
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
except StoreError as e:
|
||||
|
@ -109,7 +109,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = requester.user
|
||||
|
||||
# we build up the full structure and then decide which bits of it
|
||||
# to send which means doing unnecessary work sometimes but is
|
||||
|
|
|
@ -30,7 +30,8 @@ class PusherRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = requester.user
|
||||
|
||||
content = _parse_json(request)
|
||||
|
||||
|
@ -71,7 +72,7 @@ class PusherRestServlet(ClientV1RestServlet):
|
|||
try:
|
||||
yield pusher_pool.add_pusher(
|
||||
user_name=user.to_string(),
|
||||
access_token=token_id,
|
||||
access_token=requester.access_token_id,
|
||||
profile_tag=content['profile_tag'],
|
||||
kind=content['kind'],
|
||||
app_id=content['app_id'],
|
||||
|
|
|
@ -61,10 +61,14 @@ class RoomCreateRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
room_config = self.get_room_config(request)
|
||||
info = yield self.make_room(room_config, auth_user, None)
|
||||
info = yield self.make_room(
|
||||
room_config,
|
||||
requester.user,
|
||||
None,
|
||||
)
|
||||
room_config.update(info)
|
||||
defer.returnValue((200, info))
|
||||
|
||||
|
@ -124,15 +128,15 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_type, state_key):
|
||||
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
data = yield msg_handler.get_room_data(
|
||||
user_id=user.to_string(),
|
||||
user_id=requester.user.to_string(),
|
||||
room_id=room_id,
|
||||
event_type=event_type,
|
||||
state_key=state_key,
|
||||
is_guest=is_guest,
|
||||
is_guest=requester.is_guest,
|
||||
)
|
||||
|
||||
if not data:
|
||||
|
@ -143,7 +147,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
|
||||
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
content = _parse_json(request)
|
||||
|
||||
|
@ -151,7 +155,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
|||
"type": event_type,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": user.to_string(),
|
||||
"sender": requester.user.to_string(),
|
||||
}
|
||||
|
||||
if state_key is not None:
|
||||
|
@ -159,7 +163,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
|||
|
||||
msg_handler = self.handlers.message_handler
|
||||
yield msg_handler.create_and_send_event(
|
||||
event_dict, token_id=token_id, txn_id=txn_id,
|
||||
event_dict, token_id=requester.access_token_id, txn_id=txn_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
@ -175,7 +179,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, event_type, txn_id=None):
|
||||
user, token_id, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
content = _parse_json(request)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
|
@ -184,9 +188,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
|||
"type": event_type,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": user.to_string(),
|
||||
"sender": requester.user.to_string(),
|
||||
},
|
||||
token_id=token_id,
|
||||
token_id=requester.access_token_id,
|
||||
txn_id=txn_id,
|
||||
)
|
||||
|
||||
|
@ -220,9 +224,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_identifier, txn_id=None):
|
||||
user, token_id, is_guest = yield self.auth.get_user_by_req(
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
allow_guest=True
|
||||
allow_guest=True,
|
||||
)
|
||||
|
||||
# the identifier could be a room alias or a room id. Try one then the
|
||||
|
@ -241,24 +245,27 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
|||
|
||||
if is_room_alias:
|
||||
handler = self.handlers.room_member_handler
|
||||
ret_dict = yield handler.join_room_alias(user, identifier)
|
||||
ret_dict = yield handler.join_room_alias(
|
||||
requester.user,
|
||||
identifier,
|
||||
)
|
||||
defer.returnValue((200, ret_dict))
|
||||
else: # room id
|
||||
msg_handler = self.handlers.message_handler
|
||||
content = {"membership": Membership.JOIN}
|
||||
if is_guest:
|
||||
if requester.is_guest:
|
||||
content["kind"] = "guest"
|
||||
yield msg_handler.create_and_send_event(
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
"content": content,
|
||||
"room_id": identifier.to_string(),
|
||||
"sender": user.to_string(),
|
||||
"state_key": user.to_string(),
|
||||
"sender": requester.user.to_string(),
|
||||
"state_key": requester.user.to_string(),
|
||||
},
|
||||
token_id=token_id,
|
||||
token_id=requester.access_token_id,
|
||||
txn_id=txn_id,
|
||||
is_guest=is_guest,
|
||||
is_guest=requester.is_guest,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {"room_id": identifier.to_string()}))
|
||||
|
@ -296,11 +303,11 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
# TODO support Pagination stream API (limit/tokens)
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
handler = self.handlers.message_handler
|
||||
events = yield handler.get_state_events(
|
||||
room_id=room_id,
|
||||
user_id=user.to_string(),
|
||||
user_id=requester.user.to_string(),
|
||||
)
|
||||
|
||||
chunk = []
|
||||
|
@ -315,7 +322,8 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
|
|||
try:
|
||||
presence_handler = self.handlers.presence_handler
|
||||
presence_state = yield presence_handler.get_state(
|
||||
target_user=target_user, auth_user=user
|
||||
target_user=target_user,
|
||||
auth_user=requester.user,
|
||||
)
|
||||
event["content"].update(presence_state)
|
||||
except:
|
||||
|
@ -332,7 +340,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
pagination_config = PaginationConfig.from_request(
|
||||
request, default_limit=10,
|
||||
)
|
||||
|
@ -340,8 +348,8 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
|||
handler = self.handlers.message_handler
|
||||
msgs = yield handler.get_messages(
|
||||
room_id=room_id,
|
||||
user_id=user.to_string(),
|
||||
is_guest=is_guest,
|
||||
user_id=requester.user.to_string(),
|
||||
is_guest=requester.is_guest,
|
||||
pagin_config=pagination_config,
|
||||
as_client_event=as_client_event
|
||||
)
|
||||
|
@ -355,13 +363,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
handler = self.handlers.message_handler
|
||||
# Get all the current state for this room
|
||||
events = yield handler.get_state_events(
|
||||
room_id=room_id,
|
||||
user_id=user.to_string(),
|
||||
is_guest=is_guest,
|
||||
user_id=requester.user.to_string(),
|
||||
is_guest=requester.is_guest,
|
||||
)
|
||||
defer.returnValue((200, events))
|
||||
|
||||
|
@ -372,13 +380,13 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
pagination_config = PaginationConfig.from_request(request)
|
||||
content = yield self.handlers.message_handler.room_initial_sync(
|
||||
room_id=room_id,
|
||||
user_id=user.to_string(),
|
||||
user_id=requester.user.to_string(),
|
||||
pagin_config=pagination_config,
|
||||
is_guest=is_guest,
|
||||
is_guest=requester.is_guest,
|
||||
)
|
||||
defer.returnValue((200, content))
|
||||
|
||||
|
@ -394,12 +402,16 @@ class RoomEventContext(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_id):
|
||||
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
limit = int(request.args.get("limit", [10])[0])
|
||||
|
||||
results = yield self.handlers.room_context_handler.get_event_context(
|
||||
user, room_id, event_id, limit, is_guest
|
||||
requester.user,
|
||||
room_id,
|
||||
event_id,
|
||||
limit,
|
||||
requester.is_guest,
|
||||
)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
@ -429,14 +441,18 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, membership_action, txn_id=None):
|
||||
user, token_id, is_guest = yield self.auth.get_user_by_req(
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
allow_guest=True
|
||||
allow_guest=True,
|
||||
)
|
||||
user = requester.user
|
||||
|
||||
effective_membership_action = membership_action
|
||||
|
||||
if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}:
|
||||
if requester.is_guest and membership_action not in {
|
||||
Membership.JOIN,
|
||||
Membership.LEAVE
|
||||
}:
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
||||
content = _parse_json(request)
|
||||
|
@ -451,7 +467,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
content["medium"],
|
||||
content["address"],
|
||||
content["id_server"],
|
||||
token_id,
|
||||
requester.access_token_id,
|
||||
txn_id
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
@ -473,7 +489,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
msg_handler = self.handlers.message_handler
|
||||
|
||||
content = {"membership": unicode(effective_membership_action)}
|
||||
if is_guest:
|
||||
if requester.is_guest:
|
||||
content["kind"] = "guest"
|
||||
|
||||
yield msg_handler.create_and_send_event(
|
||||
|
@ -484,9 +500,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
"sender": user.to_string(),
|
||||
"state_key": state_key,
|
||||
},
|
||||
token_id=token_id,
|
||||
token_id=requester.access_token_id,
|
||||
txn_id=txn_id,
|
||||
is_guest=is_guest,
|
||||
is_guest=requester.is_guest,
|
||||
)
|
||||
|
||||
if membership_action == "forget":
|
||||
|
@ -524,7 +540,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, event_id, txn_id=None):
|
||||
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
content = _parse_json(request)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
|
@ -533,10 +549,10 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
|
|||
"type": EventTypes.Redaction,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": user.to_string(),
|
||||
"sender": requester.user.to_string(),
|
||||
"redacts": event_id,
|
||||
},
|
||||
token_id=token_id,
|
||||
token_id=requester.access_token_id,
|
||||
txn_id=txn_id,
|
||||
)
|
||||
|
||||
|
@ -564,7 +580,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, user_id):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
room_id = urllib.unquote(room_id)
|
||||
target_user = UserID.from_string(urllib.unquote(user_id))
|
||||
|
@ -576,14 +592,14 @@ class RoomTypingRestServlet(ClientV1RestServlet):
|
|||
if content["typing"]:
|
||||
yield typing_handler.started_typing(
|
||||
target_user=target_user,
|
||||
auth_user=auth_user,
|
||||
auth_user=requester.user,
|
||||
room_id=room_id,
|
||||
timeout=content.get("timeout", 30000),
|
||||
)
|
||||
else:
|
||||
yield typing_handler.stopped_typing(
|
||||
target_user=target_user,
|
||||
auth_user=auth_user,
|
||||
auth_user=requester.user,
|
||||
room_id=room_id,
|
||||
)
|
||||
|
||||
|
@ -597,12 +613,16 @@ class SearchRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
content = _parse_json(request)
|
||||
|
||||
batch = request.args.get("next_batch", [None])[0]
|
||||
results = yield self.handlers.search_handler.search(auth_user, content, batch)
|
||||
results = yield self.handlers.search_handler.search(
|
||||
requester.user,
|
||||
content,
|
||||
batch,
|
||||
)
|
||||
|
||||
defer.returnValue((200, results))
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
turnUris = self.hs.config.turn_uris
|
||||
turnSecret = self.hs.config.turn_shared_secret
|
||||
|
@ -37,7 +37,7 @@ class VoipRestServlet(ClientV1RestServlet):
|
|||
defer.returnValue((200, {}))
|
||||
|
||||
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
|
||||
username = "%d:%s" % (expiry, auth_user.to_string())
|
||||
username = "%d:%s" % (expiry, requester.user.to_string())
|
||||
|
||||
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
|
||||
# We need to use standard padded base64 encoding here
|
||||
|
|
|
@ -55,10 +55,11 @@ class PasswordRestServlet(RestServlet):
|
|||
|
||||
if LoginType.PASSWORD in result:
|
||||
# if using password, they should also be logged in
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
if auth_user.to_string() != result[LoginType.PASSWORD]:
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
if requester_user_id.to_string() != result[LoginType.PASSWORD]:
|
||||
raise LoginError(400, "", Codes.UNKNOWN)
|
||||
user_id = auth_user.to_string()
|
||||
user_id = requester_user_id
|
||||
elif LoginType.EMAIL_IDENTITY in result:
|
||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||
if 'medium' not in threepid or 'address' not in threepid:
|
||||
|
@ -102,10 +103,10 @@ class ThreepidRestServlet(RestServlet):
|
|||
def on_GET(self, request):
|
||||
yield run_on_reactor()
|
||||
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
threepids = yield self.hs.get_datastore().user_get_threepids(
|
||||
auth_user.to_string()
|
||||
requester.user.to_string()
|
||||
)
|
||||
|
||||
defer.returnValue((200, {'threepids': threepids}))
|
||||
|
@ -120,7 +121,8 @@ class ThreepidRestServlet(RestServlet):
|
|||
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
|
||||
threePidCreds = body['threePidCreds']
|
||||
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
|
||||
|
||||
|
@ -135,7 +137,7 @@ class ThreepidRestServlet(RestServlet):
|
|||
raise SynapseError(500, "Invalid response from ID Server")
|
||||
|
||||
yield self.auth_handler.add_threepid(
|
||||
auth_user.to_string(),
|
||||
user_id,
|
||||
threepid['medium'],
|
||||
threepid['address'],
|
||||
threepid['validated_at'],
|
||||
|
@ -144,10 +146,10 @@ class ThreepidRestServlet(RestServlet):
|
|||
if 'bind' in body and body['bind']:
|
||||
logger.debug(
|
||||
"Binding emails %s to %s",
|
||||
threepid, auth_user.to_string()
|
||||
threepid, user_id
|
||||
)
|
||||
yield self.identity_handler.bind_threepid(
|
||||
threePidCreds, auth_user.to_string()
|
||||
threePidCreds, user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
|
|
@ -43,8 +43,8 @@ class AccountDataServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id, account_data_type):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
if user_id != auth_user.to_string():
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
if user_id != requester.user.to_string():
|
||||
raise AuthError(403, "Cannot add account data for other users.")
|
||||
|
||||
try:
|
||||
|
@ -82,8 +82,8 @@ class RoomAccountDataServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id, room_id, account_data_type):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
if user_id != auth_user.to_string():
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
if user_id != requester.user.to_string():
|
||||
raise AuthError(403, "Cannot add account data for other users.")
|
||||
|
||||
try:
|
||||
|
|
|
@ -40,9 +40,9 @@ class GetFilterRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id, filter_id):
|
||||
target_user = UserID.from_string(user_id)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
if target_user != auth_user:
|
||||
if target_user != requester.user:
|
||||
raise AuthError(403, "Cannot get filters for other users")
|
||||
|
||||
if not self.hs.is_mine(target_user):
|
||||
|
@ -76,9 +76,9 @@ class CreateFilterRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id):
|
||||
target_user = UserID.from_string(user_id)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
if target_user != auth_user:
|
||||
if target_user != requester.user:
|
||||
raise AuthError(403, "Cannot create filters for other users")
|
||||
|
||||
if not self.hs.is_mine(target_user):
|
||||
|
|
|
@ -64,8 +64,8 @@ class KeyUploadServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, device_id):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
user_id = auth_user.to_string()
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
# TODO: Check that the device_id matches that in the authentication
|
||||
# or derive the device_id from the authentication instead.
|
||||
try:
|
||||
|
@ -78,8 +78,8 @@ class KeyUploadServlet(RestServlet):
|
|||
device_keys = body.get("device_keys", None)
|
||||
if device_keys:
|
||||
logger.info(
|
||||
"Updating device_keys for device %r for user %r at %d",
|
||||
device_id, auth_user, time_now
|
||||
"Updating device_keys for device %r for user %s at %d",
|
||||
device_id, user_id, time_now
|
||||
)
|
||||
# TODO: Sign the JSON with the server key
|
||||
yield self.store.set_e2e_device_keys(
|
||||
|
@ -109,8 +109,8 @@ class KeyUploadServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, device_id):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
user_id = auth_user.to_string()
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||
|
@ -182,8 +182,8 @@ class KeyQueryServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id, device_id):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user_id = auth_user.to_string()
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
auth_user_id = requester.user.to_string()
|
||||
user_id = user_id if user_id else auth_user_id
|
||||
device_ids = [device_id] if device_id else []
|
||||
result = yield self.handle_request(
|
||||
|
|
|
@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, receipt_type, event_id):
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
if receipt_type != "m.read":
|
||||
raise SynapseError(400, "Receipt type must be 'm.read'")
|
||||
|
@ -48,7 +48,7 @@ class ReceiptRestServlet(RestServlet):
|
|||
yield self.receipts_handler.received_client_receipt(
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id=user.to_string(),
|
||||
user_id=requester.user.to_string(),
|
||||
event_id=event_id
|
||||
)
|
||||
|
||||
|
|
|
@ -85,9 +85,10 @@ class SyncRestServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
user, token_id, is_guest = yield self.auth.get_user_by_req(
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request, allow_guest=True
|
||||
)
|
||||
user = requester.user
|
||||
|
||||
timeout = parse_integer(request, "timeout", default=0)
|
||||
since = parse_string(request, "since")
|
||||
|
@ -123,7 +124,7 @@ class SyncRestServlet(RestServlet):
|
|||
sync_config = SyncConfig(
|
||||
user=user,
|
||||
filter=filter,
|
||||
is_guest=is_guest,
|
||||
is_guest=requester.is_guest,
|
||||
)
|
||||
|
||||
if since is not None:
|
||||
|
@ -146,15 +147,15 @@ class SyncRestServlet(RestServlet):
|
|||
time_now = self.clock.time_msec()
|
||||
|
||||
joined = self.encode_joined(
|
||||
sync_result.joined, filter, time_now, token_id
|
||||
sync_result.joined, filter, time_now, requester.access_token_id
|
||||
)
|
||||
|
||||
invited = self.encode_invited(
|
||||
sync_result.invited, filter, time_now, token_id
|
||||
sync_result.invited, filter, time_now, requester.access_token_id
|
||||
)
|
||||
|
||||
archived = self.encode_archived(
|
||||
sync_result.archived, filter, time_now, token_id
|
||||
sync_result.archived, filter, time_now, requester.access_token_id
|
||||
)
|
||||
|
||||
response_content = {
|
||||
|
|
|
@ -42,8 +42,8 @@ class TagListServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id, room_id):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
if user_id != auth_user.to_string():
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
if user_id != requester.user.to_string():
|
||||
raise AuthError(403, "Cannot get tags for other users.")
|
||||
|
||||
tags = yield self.store.get_tags_for_room(user_id, room_id)
|
||||
|
@ -68,8 +68,8 @@ class TagServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id, room_id, tag):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
if user_id != auth_user.to_string():
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
if user_id != requester.user.to_string():
|
||||
raise AuthError(403, "Cannot add tags for other users.")
|
||||
|
||||
try:
|
||||
|
@ -88,8 +88,8 @@ class TagServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, user_id, room_id, tag):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
if user_id != auth_user.to_string():
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
if user_id != requester.user.to_string():
|
||||
raise AuthError(403, "Cannot add tags for other users.")
|
||||
|
||||
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
|
||||
|
|
|
@ -66,11 +66,11 @@ class ContentRepoResource(resource.Resource):
|
|||
@defer.inlineCallbacks
|
||||
def map_request_to_name(self, request):
|
||||
# auth the user
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
# namespace all file uploads on the user
|
||||
prefix = base64.urlsafe_b64encode(
|
||||
auth_user.to_string()
|
||||
requester.user.to_string()
|
||||
).replace('=', '')
|
||||
|
||||
# use a random string for the main portion
|
||||
|
@ -94,7 +94,7 @@ class ContentRepoResource(resource.Resource):
|
|||
file_name = prefix + main_part + suffix
|
||||
file_path = os.path.join(self.directory, file_name)
|
||||
logger.info("User %s is uploading a file to path %s",
|
||||
auth_user.to_string(),
|
||||
request.user.user_id.to_string(),
|
||||
file_path)
|
||||
|
||||
# keep trying to make a non-clashing file, with a sensible max attempts
|
||||
|
|
|
@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource):
|
|||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_POST(self, request):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
content_length = request.getHeader("Content-Length")
|
||||
|
@ -110,7 +110,7 @@ class UploadResource(BaseMediaResource):
|
|||
|
||||
content_uri = yield self.create_content(
|
||||
media_type, upload_name, request.content.read(),
|
||||
content_length, auth_user
|
||||
content_length, requester.user
|
||||
)
|
||||
|
||||
respond_with_json(
|
||||
|
|
|
@ -18,6 +18,9 @@ from synapse.api.errors import SynapseError
|
|||
from collections import namedtuple
|
||||
|
||||
|
||||
Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
|
||||
|
||||
|
||||
class DomainSpecificString(
|
||||
namedtuple("DomainSpecificString", ("localpart", "domain"))
|
||||
):
|
||||
|
|
|
@ -51,8 +51,8 @@ class AuthTestCase(unittest.TestCase):
|
|||
request = Mock(args={})
|
||||
request.args["access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||
(user, _, _) = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(user.to_string(), self.test_user)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||
|
||||
def test_get_user_by_req_user_bad_token(self):
|
||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||
|
@ -86,8 +86,8 @@ class AuthTestCase(unittest.TestCase):
|
|||
request = Mock(args={})
|
||||
request.args["access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||
(user, _, _) = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(user.to_string(), self.test_user)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||
|
||||
def test_get_user_by_req_appservice_bad_token(self):
|
||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||
|
@ -121,8 +121,8 @@ class AuthTestCase(unittest.TestCase):
|
|||
request.args["access_token"] = [self.test_token]
|
||||
request.args["user_id"] = [masquerading_user_id]
|
||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||
(user, _, _) = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(user.to_string(), masquerading_user_id)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(requester.user.to_string(), masquerading_user_id)
|
||||
|
||||
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
||||
masquerading_user_id = "@doppelganger:matrix.org"
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""Tests REST events for /presence paths."""
|
||||
|
||||
from tests import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -26,7 +25,7 @@ from synapse.api.constants import PresenceState
|
|||
from synapse.handlers.presence import PresenceHandler
|
||||
from synapse.rest.client.v1 import presence
|
||||
from synapse.rest.client.v1 import events
|
||||
from synapse.types import UserID
|
||||
from synapse.types import Requester, UserID
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
from collections import namedtuple
|
||||
|
@ -301,7 +300,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
|||
hs.get_clock().time_msec.return_value = 1000000
|
||||
|
||||
def _get_user_by_req(req=None, allow_guest=False):
|
||||
return (UserID.from_string(myid), "", False)
|
||||
return Requester(UserID.from_string(myid), "", False)
|
||||
|
||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||
|
||||
|
|
|
@ -14,16 +14,15 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""Tests REST events for /profile paths."""
|
||||
|
||||
from tests import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from mock import Mock, NonCallableMock
|
||||
from mock import Mock
|
||||
|
||||
from ....utils import MockHttpResource, setup_test_homeserver
|
||||
|
||||
from synapse.api.errors import SynapseError, AuthError
|
||||
from synapse.types import UserID
|
||||
from synapse.types import Requester, UserID
|
||||
|
||||
from synapse.rest.client.v1 import profile
|
||||
|
||||
|
@ -53,7 +52,7 @@ class ProfileTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
def _get_user_by_req(request=None, allow_guest=False):
|
||||
return (UserID.from_string(myid), "", False)
|
||||
return Requester(UserID.from_string(myid), "", False)
|
||||
|
||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||
|
||||
|
|
Loading…
Reference in New Issue