rest/client/v1/register: use the correct requester in createUser
Signed-off-by: Patrik Oldsberg <patrik.oldsberg@ericsson.com>
This commit is contained in:
parent
3de7c8a4d0
commit
7b5546d077
|
@ -19,7 +19,6 @@ import urllib
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import synapse.types
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
||||||
)
|
)
|
||||||
|
@ -370,7 +369,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_or_create_user(self, localpart, displayname, duration_in_ms,
|
def get_or_create_user(self, requester, localpart, displayname, duration_in_ms,
|
||||||
password_hash=None):
|
password_hash=None):
|
||||||
"""Creates a new user if the user does not exist,
|
"""Creates a new user if the user does not exist,
|
||||||
else revokes all previous access tokens and generates a new one.
|
else revokes all previous access tokens and generates a new one.
|
||||||
|
@ -417,9 +416,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
if displayname is not None:
|
if displayname is not None:
|
||||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||||
profile_handler = self.hs.get_handlers().profile_handler
|
profile_handler = self.hs.get_handlers().profile_handler
|
||||||
requester = synapse.types.create_requester(user)
|
|
||||||
yield profile_handler.set_displayname(
|
yield profile_handler.set_displayname(
|
||||||
user, requester, displayname
|
user, requester, displayname, by_admin=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((user_id, token))
|
defer.returnValue((user_id, token))
|
||||||
|
|
|
@ -22,6 +22,7 @@ from synapse.api.auth import get_access_token_from_request
|
||||||
from .base import ClientV1RestServlet, client_path_patterns
|
from .base import ClientV1RestServlet, client_path_patterns
|
||||||
import synapse.util.stringutils as stringutils
|
import synapse.util.stringutils as stringutils
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
from synapse.types import create_requester
|
||||||
|
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
|
||||||
|
@ -397,9 +398,10 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
if not app_service:
|
if not app_service:
|
||||||
raise SynapseError(403, "Invalid application service token.")
|
raise SynapseError(403, "Invalid application service token.")
|
||||||
|
|
||||||
logger.debug("creating user: %s", user_json)
|
requester = create_requester(app_service.sender)
|
||||||
|
|
||||||
response = yield self._do_create(user_json)
|
logger.debug("creating user: %s", user_json)
|
||||||
|
response = yield self._do_create(requester, user_json)
|
||||||
|
|
||||||
defer.returnValue((200, response))
|
defer.returnValue((200, response))
|
||||||
|
|
||||||
|
@ -407,7 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
return 403, {}
|
return 403, {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_create(self, user_json):
|
def _do_create(self, requester, user_json):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
if "localpart" not in user_json:
|
if "localpart" not in user_json:
|
||||||
|
@ -433,6 +435,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
handler = self.handlers.registration_handler
|
handler = self.handlers.registration_handler
|
||||||
user_id, token = yield handler.get_or_create_user(
|
user_id, token = yield handler.get_or_create_user(
|
||||||
|
requester=requester,
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
displayname=displayname,
|
displayname=displayname,
|
||||||
duration_in_ms=(duration_seconds * 1000),
|
duration_in_ms=(duration_seconds * 1000),
|
||||||
|
|
|
@ -17,7 +17,7 @@ from twisted.internet import defer
|
||||||
from .. import unittest
|
from .. import unittest
|
||||||
|
|
||||||
from synapse.handlers.register import RegistrationHandler
|
from synapse.handlers.register import RegistrationHandler
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID, create_requester
|
||||||
|
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
|
|
||||||
|
@ -57,8 +57,9 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
local_part = "someone"
|
local_part = "someone"
|
||||||
display_name = "someone"
|
display_name = "someone"
|
||||||
user_id = "@someone:test"
|
user_id = "@someone:test"
|
||||||
|
requester = create_requester("@as:test")
|
||||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||||
local_part, display_name, duration_ms)
|
requester, local_part, display_name, duration_ms)
|
||||||
self.assertEquals(result_user_id, user_id)
|
self.assertEquals(result_user_id, user_id)
|
||||||
self.assertEquals(result_token, 'secret')
|
self.assertEquals(result_token, 'secret')
|
||||||
|
|
||||||
|
@ -74,7 +75,8 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
local_part = "frank"
|
local_part = "frank"
|
||||||
display_name = "Frank"
|
display_name = "Frank"
|
||||||
user_id = "@frank:test"
|
user_id = "@frank:test"
|
||||||
|
requester = create_requester("@as:test")
|
||||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||||
local_part, display_name, duration_ms)
|
requester, local_part, display_name, duration_ms)
|
||||||
self.assertEquals(result_user_id, user_id)
|
self.assertEquals(result_user_id, user_id)
|
||||||
self.assertEquals(result_token, 'secret')
|
self.assertEquals(result_token, 'secret')
|
||||||
|
|
|
@ -31,33 +31,21 @@ class CreateUserServletTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
self.request.args = {}
|
self.request.args = {}
|
||||||
|
|
||||||
self.appservice = None
|
|
||||||
self.auth = Mock(get_appservice_by_req=Mock(
|
|
||||||
side_effect=lambda x: defer.succeed(self.appservice))
|
|
||||||
)
|
|
||||||
|
|
||||||
self.auth_result = (False, None, None, None)
|
|
||||||
self.auth_handler = Mock(
|
|
||||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
|
||||||
get_session_data=Mock(return_value=None)
|
|
||||||
)
|
|
||||||
self.registration_handler = Mock()
|
self.registration_handler = Mock()
|
||||||
self.identity_handler = Mock()
|
|
||||||
self.login_handler = Mock()
|
|
||||||
|
|
||||||
# do the dance to hook it up to the hs global
|
self.appservice = Mock(sender="@as:test")
|
||||||
self.handlers = Mock(
|
self.datastore = Mock(
|
||||||
auth_handler=self.auth_handler,
|
get_app_service_by_token=Mock(return_value=self.appservice)
|
||||||
|
)
|
||||||
|
|
||||||
|
# do the dance to hook things up to the hs global
|
||||||
|
handlers = Mock(
|
||||||
registration_handler=self.registration_handler,
|
registration_handler=self.registration_handler,
|
||||||
identity_handler=self.identity_handler,
|
|
||||||
login_handler=self.login_handler
|
|
||||||
)
|
)
|
||||||
self.hs = Mock()
|
self.hs = Mock()
|
||||||
self.hs.hostname = "supergbig~testing~thing.com"
|
self.hs.hostname = "superbig~testing~thing.com"
|
||||||
self.hs.get_auth = Mock(return_value=self.auth)
|
self.hs.get_datastore = Mock(return_value=self.datastore)
|
||||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
self.hs.get_handlers = Mock(return_value=handlers)
|
||||||
self.hs.config.enable_registration = True
|
|
||||||
# init the thing we're testing
|
|
||||||
self.servlet = CreateUserRestServlet(self.hs)
|
self.servlet = CreateUserRestServlet(self.hs)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
Loading…
Reference in New Issue