Add `create_requester` function
Wrap the `Requester` constructor with a function which provides sensible defaults, and use it throughout
This commit is contained in:
parent
33d7776473
commit
eb359eced4
|
@ -13,22 +13,22 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import pymacaroons
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
from signedjson.sign import verify_signed_json, SignatureVerifyException
|
from signedjson.sign import verify_signed_json, SignatureVerifyException
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
|
||||||
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
|
||||||
from synapse.types import Requester, UserID, get_domain_from_id
|
|
||||||
from synapse.util.logutils import log_function
|
|
||||||
from synapse.util.logcontext import preserve_context_over_fn
|
|
||||||
from synapse.util.metrics import Measure
|
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
|
|
||||||
import logging
|
import synapse.types
|
||||||
import pymacaroons
|
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||||
|
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
||||||
|
from synapse.types import UserID, get_domain_from_id
|
||||||
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
|
from synapse.util.logutils import log_function
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -566,8 +566,7 @@ class Auth(object):
|
||||||
Args:
|
Args:
|
||||||
request - An HTTP request with an access_token query parameter.
|
request - An HTTP request with an access_token query parameter.
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: resolves to a namedtuple including "user" (UserID)
|
defer.Deferred: resolves to a ``synapse.types.Requester`` object
|
||||||
"access_token_id" (int), "is_guest" (bool)
|
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if no user by that token exists or the token is invalid.
|
AuthError if no user by that token exists or the token is invalid.
|
||||||
"""
|
"""
|
||||||
|
@ -576,9 +575,7 @@ class Auth(object):
|
||||||
user_id = yield self._get_appservice_user_id(request.args)
|
user_id = yield self._get_appservice_user_id(request.args)
|
||||||
if user_id:
|
if user_id:
|
||||||
request.authenticated_entity = user_id
|
request.authenticated_entity = user_id
|
||||||
defer.returnValue(
|
defer.returnValue(synapse.types.create_requester(user_id))
|
||||||
Requester(UserID.from_string(user_id), "", False)
|
|
||||||
)
|
|
||||||
|
|
||||||
access_token = request.args["access_token"][0]
|
access_token = request.args["access_token"][0]
|
||||||
user_info = yield self.get_user_by_access_token(access_token, rights)
|
user_info = yield self.get_user_by_access_token(access_token, rights)
|
||||||
|
@ -612,7 +609,8 @@ class Auth(object):
|
||||||
|
|
||||||
request.authenticated_entity = user.to_string()
|
request.authenticated_entity = user.to_string()
|
||||||
|
|
||||||
defer.returnValue(Requester(user, token_id, is_guest))
|
defer.returnValue(synapse.types.create_requester(
|
||||||
|
user, token_id, is_guest, device_id))
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
||||||
|
|
|
@ -13,14 +13,14 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import LimitExceededError
|
import synapse.types
|
||||||
from synapse.api.constants import Membership, EventTypes
|
from synapse.api.constants import Membership, EventTypes
|
||||||
from synapse.types import UserID, Requester
|
from synapse.api.errors import LimitExceededError
|
||||||
|
from synapse.types import UserID
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -124,7 +124,8 @@ class BaseHandler(object):
|
||||||
# and having homeservers have their own users leave keeps more
|
# and having homeservers have their own users leave keeps more
|
||||||
# of that decision-making and control local to the guest-having
|
# of that decision-making and control local to the guest-having
|
||||||
# homeserver.
|
# homeserver.
|
||||||
requester = Requester(target_user, "", True)
|
requester = synapse.types.create_requester(
|
||||||
|
target_user, is_guest=True)
|
||||||
handler = self.hs.get_handlers().room_member_handler
|
handler = self.hs.get_handlers().room_member_handler
|
||||||
yield handler.update_membership(
|
yield handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
|
|
|
@ -13,15 +13,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.types
|
||||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||||
from synapse.types import UserID, Requester
|
from synapse.types import UserID
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -165,7 +165,9 @@ class ProfileHandler(BaseHandler):
|
||||||
try:
|
try:
|
||||||
# Assume the user isn't a guest because we don't let guests set
|
# Assume the user isn't a guest because we don't let guests set
|
||||||
# profile or avatar data.
|
# profile or avatar data.
|
||||||
requester = Requester(user, "", False)
|
# XXX why are we recreating `requester` here for each room?
|
||||||
|
# what was wrong with the `requester` we were passed?
|
||||||
|
requester = synapse.types.create_requester(user)
|
||||||
yield handler.update_membership(
|
yield handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
user,
|
user,
|
||||||
|
|
|
@ -14,18 +14,19 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Contains functions for registering clients."""
|
"""Contains functions for registering clients."""
|
||||||
|
import logging
|
||||||
|
import urllib
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.types import UserID, Requester
|
import synapse.types
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
||||||
)
|
)
|
||||||
from ._base import BaseHandler
|
|
||||||
from synapse.util.async import run_on_reactor
|
|
||||||
from synapse.http.client import CaptchaServerHttpClient
|
from synapse.http.client import CaptchaServerHttpClient
|
||||||
|
from synapse.types import UserID
|
||||||
import logging
|
from synapse.util.async import run_on_reactor
|
||||||
import urllib
|
from ._base import BaseHandler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -410,8 +411,9 @@ class RegistrationHandler(BaseHandler):
|
||||||
if displayname is not None:
|
if displayname is not None:
|
||||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||||
profile_handler = self.hs.get_handlers().profile_handler
|
profile_handler = self.hs.get_handlers().profile_handler
|
||||||
|
requester = synapse.types.create_requester(user)
|
||||||
yield profile_handler.set_displayname(
|
yield profile_handler.set_displayname(
|
||||||
user, Requester(user, token, False), displayname
|
user, requester, displayname
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((user_id, token))
|
defer.returnValue((user_id, token))
|
||||||
|
|
|
@ -14,24 +14,22 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from signedjson.key import decode_verify_key_bytes
|
||||||
|
from signedjson.sign import verify_signed_json
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from unpaddedbase64 import decode_base64
|
||||||
|
|
||||||
from ._base import BaseHandler
|
import synapse.types
|
||||||
|
|
||||||
from synapse.types import UserID, RoomID, Requester
|
|
||||||
from synapse.api.constants import (
|
from synapse.api.constants import (
|
||||||
EventTypes, Membership,
|
EventTypes, Membership,
|
||||||
)
|
)
|
||||||
from synapse.api.errors import AuthError, SynapseError, Codes
|
from synapse.api.errors import AuthError, SynapseError, Codes
|
||||||
|
from synapse.types import UserID, RoomID
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.distributor import user_left_room, user_joined_room
|
from synapse.util.distributor import user_left_room, user_joined_room
|
||||||
|
from ._base import BaseHandler
|
||||||
from signedjson.sign import verify_signed_json
|
|
||||||
from signedjson.key import decode_verify_key_bytes
|
|
||||||
|
|
||||||
from unpaddedbase64 import decode_base64
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
|
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
|
||||||
else:
|
else:
|
||||||
requester = Requester(target_user, None, False)
|
requester = synapse.types.create_requester(target_user)
|
||||||
|
|
||||||
message_handler = self.hs.get_handlers().message_handler
|
message_handler = self.hs.get_handlers().message_handler
|
||||||
prev_event = message_handler.deduplicate_state_event(event, context)
|
prev_event = message_handler.deduplicate_state_event(event, context)
|
||||||
|
|
|
@ -13,18 +13,16 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import simplejson as json
|
||||||
|
from canonicaljson import encode_canonical_json
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
|
||||||
|
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
import logging
|
|
||||||
import simplejson as json
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,38 @@ from synapse.api.errors import SynapseError
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
|
|
||||||
Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
|
Requester = namedtuple("Requester",
|
||||||
|
["user", "access_token_id", "is_guest", "device_id"])
|
||||||
|
"""
|
||||||
|
Represents the user making a request
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
user (UserID): id of the user making the request
|
||||||
|
access_token_id (int|None): *ID* of the access token used for this
|
||||||
|
request, or None if it came via the appservice API or similar
|
||||||
|
is_guest (bool): True if the user making this request is a guest user
|
||||||
|
device_id (str|None): device_id which was set at authentication time
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def create_requester(user_id, access_token_id=None, is_guest=False,
|
||||||
|
device_id=None):
|
||||||
|
"""
|
||||||
|
Create a new ``Requester`` object
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str|UserID): id of the user making the request
|
||||||
|
access_token_id (int|None): *ID* of the access token used for this
|
||||||
|
request, or None if it came via the appservice API or similar
|
||||||
|
is_guest (bool): True if the user making this request is a guest user
|
||||||
|
device_id (str|None): device_id which was set at authentication time
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Requester
|
||||||
|
"""
|
||||||
|
if not isinstance(user_id, UserID):
|
||||||
|
user_id = UserID.from_string(user_id)
|
||||||
|
return Requester(user_id, access_token_id, is_guest, device_id)
|
||||||
|
|
||||||
|
|
||||||
def get_domain_from_id(string):
|
def get_domain_from_id(string):
|
||||||
|
|
|
@ -19,11 +19,12 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
from mock import Mock, NonCallableMock
|
||||||
|
|
||||||
|
import synapse.types
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.handlers.profile import ProfileHandler
|
from synapse.handlers.profile import ProfileHandler
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
from tests.utils import setup_test_homeserver, requester_for_user
|
from tests.utils import setup_test_homeserver
|
||||||
|
|
||||||
|
|
||||||
class ProfileHandlers(object):
|
class ProfileHandlers(object):
|
||||||
|
@ -86,7 +87,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
def test_set_my_name(self):
|
def test_set_my_name(self):
|
||||||
yield self.handler.set_displayname(
|
yield self.handler.set_displayname(
|
||||||
self.frank,
|
self.frank,
|
||||||
requester_for_user(self.frank),
|
synapse.types.create_requester(self.frank),
|
||||||
"Frank Jr."
|
"Frank Jr."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -99,7 +100,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
def test_set_my_name_noauth(self):
|
def test_set_my_name_noauth(self):
|
||||||
d = self.handler.set_displayname(
|
d = self.handler.set_displayname(
|
||||||
self.frank,
|
self.frank,
|
||||||
requester_for_user(self.bob),
|
synapse.types.create_requester(self.bob),
|
||||||
"Frank Jr."
|
"Frank Jr."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -144,7 +145,8 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_set_my_avatar(self):
|
def test_set_my_avatar(self):
|
||||||
yield self.handler.set_avatar_url(
|
yield self.handler.set_avatar_url(
|
||||||
self.frank, requester_for_user(self.frank), "http://my.server/pic.gif"
|
self.frank, synapse.types.create_requester(self.frank),
|
||||||
|
"http://my.server/pic.gif"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
|
|
|
@ -13,15 +13,17 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.replication.resource import ReplicationResource
|
|
||||||
from synapse.types import Requester, UserID
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from tests import unittest
|
|
||||||
from tests.utils import setup_test_homeserver, requester_for_user
|
|
||||||
from mock import Mock, NonCallableMock
|
|
||||||
import json
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import json
|
||||||
|
|
||||||
|
from mock import Mock, NonCallableMock
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.types
|
||||||
|
from synapse.replication.resource import ReplicationResource
|
||||||
|
from synapse.types import UserID
|
||||||
|
from tests import unittest
|
||||||
|
from tests.utils import setup_test_homeserver
|
||||||
|
|
||||||
|
|
||||||
class ReplicationResourceCase(unittest.TestCase):
|
class ReplicationResourceCase(unittest.TestCase):
|
||||||
|
@ -61,7 +63,7 @@ class ReplicationResourceCase(unittest.TestCase):
|
||||||
def test_events_and_state(self):
|
def test_events_and_state(self):
|
||||||
get = self.get(events="-1", state="-1", timeout="0")
|
get = self.get(events="-1", state="-1", timeout="0")
|
||||||
yield self.hs.get_handlers().room_creation_handler.create_room(
|
yield self.hs.get_handlers().room_creation_handler.create_room(
|
||||||
Requester(self.user, "", False), {}
|
synapse.types.create_requester(self.user), {}
|
||||||
)
|
)
|
||||||
code, body = yield get
|
code, body = yield get
|
||||||
self.assertEquals(code, 200)
|
self.assertEquals(code, 200)
|
||||||
|
@ -144,7 +146,7 @@ class ReplicationResourceCase(unittest.TestCase):
|
||||||
def send_text_message(self, room_id, message):
|
def send_text_message(self, room_id, message):
|
||||||
handler = self.hs.get_handlers().message_handler
|
handler = self.hs.get_handlers().message_handler
|
||||||
event = yield handler.create_and_send_nonmember_event(
|
event = yield handler.create_and_send_nonmember_event(
|
||||||
requester_for_user(self.user),
|
synapse.types.create_requester(self.user),
|
||||||
{
|
{
|
||||||
"type": "m.room.message",
|
"type": "m.room.message",
|
||||||
"content": {"body": "message", "msgtype": "m.text"},
|
"content": {"body": "message", "msgtype": "m.text"},
|
||||||
|
@ -157,7 +159,7 @@ class ReplicationResourceCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_room(self):
|
def create_room(self):
|
||||||
result = yield self.hs.get_handlers().room_creation_handler.create_room(
|
result = yield self.hs.get_handlers().room_creation_handler.create_room(
|
||||||
Requester(self.user, "", False), {}
|
synapse.types.create_requester(self.user), {}
|
||||||
)
|
)
|
||||||
defer.returnValue(result["room_id"])
|
defer.returnValue(result["room_id"])
|
||||||
|
|
||||||
|
|
|
@ -14,17 +14,14 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Tests REST events for /profile paths."""
|
"""Tests REST events for /profile paths."""
|
||||||
from tests import unittest
|
from mock import Mock
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from mock import Mock
|
import synapse.types
|
||||||
|
|
||||||
from ....utils import MockHttpResource, setup_test_homeserver
|
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, AuthError
|
from synapse.api.errors import SynapseError, AuthError
|
||||||
from synapse.types import Requester, UserID
|
|
||||||
|
|
||||||
from synapse.rest.client.v1 import profile
|
from synapse.rest.client.v1 import profile
|
||||||
|
from tests import unittest
|
||||||
|
from ....utils import MockHttpResource, setup_test_homeserver
|
||||||
|
|
||||||
myid = "@1234ABCD:test"
|
myid = "@1234ABCD:test"
|
||||||
PATH_PREFIX = "/_matrix/client/api/v1"
|
PATH_PREFIX = "/_matrix/client/api/v1"
|
||||||
|
@ -52,7 +49,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_user_by_req(request=None, allow_guest=False):
|
def _get_user_by_req(request=None, allow_guest=False):
|
||||||
return Requester(UserID.from_string(myid), "", False)
|
return synapse.types.create_requester(myid)
|
||||||
|
|
||||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,6 @@ from synapse.storage.prepare_database import prepare_database
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.federation.transport import server
|
from synapse.federation.transport import server
|
||||||
from synapse.types import Requester
|
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
|
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
@ -512,7 +511,3 @@ class DeferredMockCallable(object):
|
||||||
"call(%s)" % _format_call(c[0], c[1]) for c in calls
|
"call(%s)" % _format_call(c[0], c[1]) for c in calls
|
||||||
])
|
])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def requester_for_user(user):
|
|
||||||
return Requester(user, None, False)
|
|
||||||
|
|
Loading…
Reference in New Issue