Add `create_requester` function

Wrap the `Requester` constructor with a function which provides sensible
defaults, and use it throughout
This commit is contained in:
Richard van der Hoff 2016-07-26 16:46:53 +01:00
parent 33d7776473
commit eb359eced4
11 changed files with 106 additions and 80 deletions

View File

@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import pymacaroons
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, UserID, get_domain_from_id
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.metrics import Measure
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import synapse.types
import pymacaroons from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -566,8 +566,7 @@ class Auth(object):
Args: Args:
request - An HTTP request with an access_token query parameter. request - An HTTP request with an access_token query parameter.
Returns: Returns:
defer.Deferred: resolves to a namedtuple including "user" (UserID) defer.Deferred: resolves to a ``synapse.types.Requester`` object
"access_token_id" (int), "is_guest" (bool)
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
@ -576,9 +575,7 @@ class Auth(object):
user_id = yield self._get_appservice_user_id(request.args) user_id = yield self._get_appservice_user_id(request.args)
if user_id: if user_id:
request.authenticated_entity = user_id request.authenticated_entity = user_id
defer.returnValue( defer.returnValue(synapse.types.create_requester(user_id))
Requester(UserID.from_string(user_id), "", False)
)
access_token = request.args["access_token"][0] access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_access_token(access_token, rights) user_info = yield self.get_user_by_access_token(access_token, rights)
@ -612,7 +609,8 @@ class Auth(object):
request.authenticated_entity = user.to_string() request.authenticated_entity = user.to_string()
defer.returnValue(Requester(user, token_id, is_guest)) defer.returnValue(synapse.types.create_requester(
user, token_id, is_guest, device_id))
except KeyError: except KeyError:
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",

View File

@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import LimitExceededError import synapse.types
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, Requester from synapse.api.errors import LimitExceededError
from synapse.types import UserID
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -124,7 +124,8 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more # and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having # of that decision-making and control local to the guest-having
# homeserver. # homeserver.
requester = Requester(target_user, "", True) requester = synapse.types.create_requester(
target_user, is_guest=True)
handler = self.hs.get_handlers().room_member_handler handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership( yield handler.update_membership(
requester, requester,

View File

@ -13,15 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from twisted.internet import defer from twisted.internet import defer
import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.types import UserID, Requester from synapse.types import UserID
from ._base import BaseHandler from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -165,7 +165,9 @@ class ProfileHandler(BaseHandler):
try: try:
# Assume the user isn't a guest because we don't let guests set # Assume the user isn't a guest because we don't let guests set
# profile or avatar data. # profile or avatar data.
requester = Requester(user, "", False) # XXX why are we recreating `requester` here for each room?
# what was wrong with the `requester` we were passed?
requester = synapse.types.create_requester(user)
yield handler.update_membership( yield handler.update_membership(
requester, requester,
user, user,

View File

@ -14,18 +14,19 @@
# limitations under the License. # limitations under the License.
"""Contains functions for registering clients.""" """Contains functions for registering clients."""
import logging
import urllib
from twisted.internet import defer from twisted.internet import defer
from synapse.types import UserID, Requester import synapse.types
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
) )
from ._base import BaseHandler
from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
from synapse.types import UserID
import logging from synapse.util.async import run_on_reactor
import urllib from ._base import BaseHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -410,8 +411,9 @@ class RegistrationHandler(BaseHandler):
if displayname is not None: if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname) logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler profile_handler = self.hs.get_handlers().profile_handler
requester = synapse.types.create_requester(user)
yield profile_handler.set_displayname( yield profile_handler.set_displayname(
user, Requester(user, token, False), displayname user, requester, displayname
) )
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))

View File

@ -14,24 +14,22 @@
# limitations under the License. # limitations under the License.
import logging
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from twisted.internet import defer from twisted.internet import defer
from unpaddedbase64 import decode_base64
from ._base import BaseHandler import synapse.types
from synapse.types import UserID, RoomID, Requester
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, Membership, EventTypes, Membership,
) )
from synapse.api.errors import AuthError, SynapseError, Codes from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import UserID, RoomID
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room from synapse.util.distributor import user_left_room, user_joined_room
from ._base import BaseHandler
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler):
) )
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else: else:
requester = Requester(target_user, None, False) requester = synapse.types.create_requester(target_user)
message_handler = self.hs.get_handlers().message_handler message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context) prev_event = message_handler.deduplicate_state_event(event, context)

View File

@ -13,18 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import simplejson as json
from canonicaljson import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID from synapse.types import UserID
from canonicaljson import encode_canonical_json
from ._base import client_v2_patterns from ._base import client_v2_patterns
import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -18,7 +18,38 @@ from synapse.api.errors import SynapseError
from collections import namedtuple from collections import namedtuple
Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"]) Requester = namedtuple("Requester",
["user", "access_token_id", "is_guest", "device_id"])
"""
Represents the user making a request
Attributes:
user (UserID): id of the user making the request
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
device_id (str|None): device_id which was set at authentication time
"""
def create_requester(user_id, access_token_id=None, is_guest=False,
device_id=None):
"""
Create a new ``Requester`` object
Args:
user_id (str|UserID): id of the user making the request
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
device_id (str|None): device_id which was set at authentication time
Returns:
Requester
"""
if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id)
return Requester(user_id, access_token_id, is_guest, device_id)
def get_domain_from_id(string): def get_domain_from_id(string):

View File

@ -19,11 +19,12 @@ from twisted.internet import defer
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
import synapse.types
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.handlers.profile import ProfileHandler from synapse.handlers.profile import ProfileHandler
from synapse.types import UserID from synapse.types import UserID
from tests.utils import setup_test_homeserver, requester_for_user from tests.utils import setup_test_homeserver
class ProfileHandlers(object): class ProfileHandlers(object):
@ -86,7 +87,7 @@ class ProfileTestCase(unittest.TestCase):
def test_set_my_name(self): def test_set_my_name(self):
yield self.handler.set_displayname( yield self.handler.set_displayname(
self.frank, self.frank,
requester_for_user(self.frank), synapse.types.create_requester(self.frank),
"Frank Jr." "Frank Jr."
) )
@ -99,7 +100,7 @@ class ProfileTestCase(unittest.TestCase):
def test_set_my_name_noauth(self): def test_set_my_name_noauth(self):
d = self.handler.set_displayname( d = self.handler.set_displayname(
self.frank, self.frank,
requester_for_user(self.bob), synapse.types.create_requester(self.bob),
"Frank Jr." "Frank Jr."
) )
@ -144,7 +145,8 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_avatar(self): def test_set_my_avatar(self):
yield self.handler.set_avatar_url( yield self.handler.set_avatar_url(
self.frank, requester_for_user(self.frank), "http://my.server/pic.gif" self.frank, synapse.types.create_requester(self.frank),
"http://my.server/pic.gif"
) )
self.assertEquals( self.assertEquals(

View File

@ -13,15 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.replication.resource import ReplicationResource
from synapse.types import Requester, UserID
from twisted.internet import defer
from tests import unittest
from tests.utils import setup_test_homeserver, requester_for_user
from mock import Mock, NonCallableMock
import json
import contextlib import contextlib
import json
from mock import Mock, NonCallableMock
from twisted.internet import defer
import synapse.types
from synapse.replication.resource import ReplicationResource
from synapse.types import UserID
from tests import unittest
from tests.utils import setup_test_homeserver
class ReplicationResourceCase(unittest.TestCase): class ReplicationResourceCase(unittest.TestCase):
@ -61,7 +63,7 @@ class ReplicationResourceCase(unittest.TestCase):
def test_events_and_state(self): def test_events_and_state(self):
get = self.get(events="-1", state="-1", timeout="0") get = self.get(events="-1", state="-1", timeout="0")
yield self.hs.get_handlers().room_creation_handler.create_room( yield self.hs.get_handlers().room_creation_handler.create_room(
Requester(self.user, "", False), {} synapse.types.create_requester(self.user), {}
) )
code, body = yield get code, body = yield get
self.assertEquals(code, 200) self.assertEquals(code, 200)
@ -144,7 +146,7 @@ class ReplicationResourceCase(unittest.TestCase):
def send_text_message(self, room_id, message): def send_text_message(self, room_id, message):
handler = self.hs.get_handlers().message_handler handler = self.hs.get_handlers().message_handler
event = yield handler.create_and_send_nonmember_event( event = yield handler.create_and_send_nonmember_event(
requester_for_user(self.user), synapse.types.create_requester(self.user),
{ {
"type": "m.room.message", "type": "m.room.message",
"content": {"body": "message", "msgtype": "m.text"}, "content": {"body": "message", "msgtype": "m.text"},
@ -157,7 +159,7 @@ class ReplicationResourceCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room(self): def create_room(self):
result = yield self.hs.get_handlers().room_creation_handler.create_room( result = yield self.hs.get_handlers().room_creation_handler.create_room(
Requester(self.user, "", False), {} synapse.types.create_requester(self.user), {}
) )
defer.returnValue(result["room_id"]) defer.returnValue(result["room_id"])

View File

@ -14,17 +14,14 @@
# limitations under the License. # limitations under the License.
"""Tests REST events for /profile paths.""" """Tests REST events for /profile paths."""
from tests import unittest from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from mock import Mock import synapse.types
from ....utils import MockHttpResource, setup_test_homeserver
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.types import Requester, UserID
from synapse.rest.client.v1 import profile from synapse.rest.client.v1 import profile
from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver
myid = "@1234ABCD:test" myid = "@1234ABCD:test"
PATH_PREFIX = "/_matrix/client/api/v1" PATH_PREFIX = "/_matrix/client/api/v1"
@ -52,7 +49,7 @@ class ProfileTestCase(unittest.TestCase):
) )
def _get_user_by_req(request=None, allow_guest=False): def _get_user_by_req(request=None, allow_guest=False):
return Requester(UserID.from_string(myid), "", False) return synapse.types.create_requester(myid)
hs.get_v1auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req

View File

@ -20,7 +20,6 @@ from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.federation.transport import server from synapse.federation.transport import server
from synapse.types import Requester
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
@ -512,7 +511,3 @@ class DeferredMockCallable(object):
"call(%s)" % _format_call(c[0], c[1]) for c in calls "call(%s)" % _format_call(c[0], c[1]) for c in calls
]) ])
) )
def requester_for_user(user):
return Requester(user, None, False)