Merge pull request #1096 from matrix-org/markjh/get_access_token
Add helper function for getting access_tokens from requests
This commit is contained in:
commit
dbff7e9436
|
@ -583,12 +583,15 @@ class Auth(object):
|
|||
"""
|
||||
# Can optionally look elsewhere in the request (e.g. headers)
|
||||
try:
|
||||
user_id = yield self._get_appservice_user_id(request.args)
|
||||
user_id = yield self._get_appservice_user_id(request)
|
||||
if user_id:
|
||||
request.authenticated_entity = user_id
|
||||
defer.returnValue(synapse.types.create_requester(user_id))
|
||||
|
||||
access_token = request.args["access_token"][0]
|
||||
access_token = get_access_token_from_request(
|
||||
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
||||
)
|
||||
|
||||
user_info = yield self.get_user_by_access_token(access_token, rights)
|
||||
user = user_info["user"]
|
||||
token_id = user_info["token_id"]
|
||||
|
@ -629,17 +632,19 @@ class Auth(object):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_appservice_user_id(self, request_args):
|
||||
def _get_appservice_user_id(self, request):
|
||||
app_service = yield self.store.get_app_service_by_token(
|
||||
request_args["access_token"][0]
|
||||
get_access_token_from_request(
|
||||
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
||||
)
|
||||
)
|
||||
if app_service is None:
|
||||
defer.returnValue(None)
|
||||
|
||||
if "user_id" not in request_args:
|
||||
if "user_id" not in request.args:
|
||||
defer.returnValue(app_service.sender)
|
||||
|
||||
user_id = request_args["user_id"][0]
|
||||
user_id = request.args["user_id"][0]
|
||||
if app_service.sender == user_id:
|
||||
defer.returnValue(app_service.sender)
|
||||
|
||||
|
@ -833,7 +838,9 @@ class Auth(object):
|
|||
@defer.inlineCallbacks
|
||||
def get_appservice_by_req(self, request):
|
||||
try:
|
||||
token = request.args["access_token"][0]
|
||||
token = get_access_token_from_request(
|
||||
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
||||
)
|
||||
service = yield self.store.get_app_service_by_token(token)
|
||||
if not service:
|
||||
logger.warn("Unrecognised appservice access token: %s" % (token,))
|
||||
|
@ -1142,3 +1149,40 @@ class Auth(object):
|
|||
"This server requires you to be a moderator in the room to"
|
||||
" edit its room list entry"
|
||||
)
|
||||
|
||||
|
||||
def has_access_token(request):
|
||||
"""Checks if the request has an access_token.
|
||||
|
||||
Returns:
|
||||
bool: False if no access_token was given, True otherwise.
|
||||
"""
|
||||
query_params = request.args.get("access_token")
|
||||
return bool(query_params)
|
||||
|
||||
|
||||
def get_access_token_from_request(request, token_not_found_http_status=401):
|
||||
"""Extracts the access_token from the request.
|
||||
|
||||
Args:
|
||||
request: The http request.
|
||||
token_not_found_http_status(int): The HTTP status code to set in the
|
||||
AuthError if the token isn't found. This is used in some of the
|
||||
legacy APIs to change the status code to 403 from the default of
|
||||
401 since some of the old clients depended on auth errors returning
|
||||
403.
|
||||
Returns:
|
||||
str: The access_token
|
||||
Raises:
|
||||
AuthError: If there isn't an access_token in the request.
|
||||
"""
|
||||
query_params = request.args.get("access_token")
|
||||
# Try to get the access_token from the query params.
|
||||
if not query_params:
|
||||
raise AuthError(
|
||||
token_not_found_http_status,
|
||||
"Missing access token.",
|
||||
errcode=Codes.MISSING_TOKEN
|
||||
)
|
||||
|
||||
return query_params[0]
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import AuthError, Codes
|
||||
from synapse.api.auth import get_access_token_from_request
|
||||
|
||||
from .base import ClientV1RestServlet, client_path_patterns
|
||||
|
||||
|
@ -37,13 +37,7 @@ class LogoutRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
try:
|
||||
access_token = request.args["access_token"][0]
|
||||
except KeyError:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
||||
errcode=Codes.MISSING_TOKEN
|
||||
)
|
||||
access_token = get_access_token_from_request(request)
|
||||
yield self.store.delete_access_token(access_token)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.auth import get_access_token_from_request
|
||||
from .base import ClientV1RestServlet, client_path_patterns
|
||||
import synapse.util.stringutils as stringutils
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
|
@ -296,12 +297,11 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _do_app_service(self, request, register_json, session):
|
||||
if "access_token" not in request.args:
|
||||
raise SynapseError(400, "Expected application service token.")
|
||||
as_token = get_access_token_from_request(request)
|
||||
|
||||
if "user" not in register_json:
|
||||
raise SynapseError(400, "Expected 'user' key.")
|
||||
|
||||
as_token = request.args["access_token"][0]
|
||||
user_localpart = register_json["user"].encode("utf-8")
|
||||
|
||||
handler = self.handlers.registration_handler
|
||||
|
@ -390,11 +390,9 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||
def on_POST(self, request):
|
||||
user_json = parse_json_object_from_request(request)
|
||||
|
||||
if "access_token" not in request.args:
|
||||
raise SynapseError(400, "Expected application service token.")
|
||||
|
||||
access_token = get_access_token_from_request(request)
|
||||
app_service = yield self.store.get_app_service_by_token(
|
||||
request.args["access_token"][0]
|
||||
access_token
|
||||
)
|
||||
if not app_service:
|
||||
raise SynapseError(403, "Invalid application service token.")
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
to ensure idempotency when performing PUTs using the REST API."""
|
||||
import logging
|
||||
|
||||
from synapse.api.auth import get_access_token_from_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -90,6 +92,6 @@ class HttpTransactionStore(object):
|
|||
return response
|
||||
|
||||
def _get_key(self, request):
|
||||
token = request.args["access_token"][0]
|
||||
token = get_access_token_from_request(request)
|
||||
path_without_txn_id = request.path.rsplit("/", 1)[0]
|
||||
return path_without_txn_id + "/" + token
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.auth import get_access_token_from_request, has_access_token
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
|
@ -131,7 +132,7 @@ class RegisterRestServlet(RestServlet):
|
|||
desired_username = body['username']
|
||||
|
||||
appservice = None
|
||||
if 'access_token' in request.args:
|
||||
if has_access_token(request):
|
||||
appservice = yield self.auth.get_appservice_by_req(request)
|
||||
|
||||
# fork off as soon as possible for ASes and shared secret auth which
|
||||
|
@ -143,10 +144,11 @@ class RegisterRestServlet(RestServlet):
|
|||
# 'user' key not 'username'). Since this is a new addition, we'll
|
||||
# fallback to 'username' if they gave one.
|
||||
desired_username = body.get("user", desired_username)
|
||||
access_token = get_access_token_from_request(request)
|
||||
|
||||
if isinstance(desired_username, basestring):
|
||||
result = yield self._do_appservice_registration(
|
||||
desired_username, request.args["access_token"][0], body
|
||||
desired_username, access_token, body
|
||||
)
|
||||
defer.returnValue((200, result)) # we throw for non 200 responses
|
||||
return
|
||||
|
|
|
@ -80,7 +80,7 @@ class ThirdPartyUserServlet(RestServlet):
|
|||
yield self.auth.get_user_by_req(request)
|
||||
|
||||
fields = request.args
|
||||
del fields["access_token"]
|
||||
fields.pop("access_token", None)
|
||||
|
||||
results = yield self.appservice_handler.query_3pe(
|
||||
ThirdPartyEntityKind.USER, protocol, fields
|
||||
|
@ -104,7 +104,7 @@ class ThirdPartyLocationServlet(RestServlet):
|
|||
yield self.auth.get_user_by_req(request)
|
||||
|
||||
fields = request.args
|
||||
del fields["access_token"]
|
||||
fields.pop("access_token", None)
|
||||
|
||||
results = yield self.appservice_handler.query_3pe(
|
||||
ThirdPartyEntityKind.LOCATION, protocol, fields
|
||||
|
|
Loading…
Reference in New Issue