Merge pull request #1096 from matrix-org/markjh/get_access_token

Add helper function for getting access_tokens from requests
This commit is contained in:
Mark Haines 2016-09-09 17:09:27 +01:00 committed by GitHub
commit dbff7e9436
6 changed files with 67 additions and 27 deletions

View File

@ -583,12 +583,15 @@ class Auth(object):
"""
# Can optionally look elsewhere in the request (e.g. headers)
try:
user_id = yield self._get_appservice_user_id(request.args)
user_id = yield self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
defer.returnValue(synapse.types.create_requester(user_id))
access_token = request.args["access_token"][0]
access_token = get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"]
token_id = user_info["token_id"]
@ -629,17 +632,19 @@ class Auth(object):
)
@defer.inlineCallbacks
def _get_appservice_user_id(self, request_args):
def _get_appservice_user_id(self, request):
app_service = yield self.store.get_app_service_by_token(
request_args["access_token"][0]
get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
)
if app_service is None:
defer.returnValue(None)
if "user_id" not in request_args:
if "user_id" not in request.args:
defer.returnValue(app_service.sender)
user_id = request_args["user_id"][0]
user_id = request.args["user_id"][0]
if app_service.sender == user_id:
defer.returnValue(app_service.sender)
@ -833,7 +838,9 @@ class Auth(object):
@defer.inlineCallbacks
def get_appservice_by_req(self, request):
try:
token = request.args["access_token"][0]
token = get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
service = yield self.store.get_app_service_by_token(token)
if not service:
logger.warn("Unrecognised appservice access token: %s" % (token,))
@ -1142,3 +1149,40 @@ class Auth(object):
"This server requires you to be a moderator in the room to"
" edit its room list entry"
)
def has_access_token(request):
"""Checks if the request has an access_token.
Returns:
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
return bool(query_params)
def get_access_token_from_request(request, token_not_found_http_status=401):
"""Extracts the access_token from the request.
Args:
request: The http request.
token_not_found_http_status(int): The HTTP status code to set in the
AuthError if the token isn't found. This is used in some of the
legacy APIs to change the status code to 403 from the default of
401 since some of the old clients depended on auth errors returning
403.
Returns:
str: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
query_params = request.args.get("access_token")
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)
return query_params[0]

View File

@ -15,7 +15,7 @@
from twisted.internet import defer
from synapse.api.errors import AuthError, Codes
from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns
@ -37,13 +37,7 @@ class LogoutRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
try:
access_token = request.args["access_token"][0]
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
errcode=Codes.MISSING_TOKEN
)
access_token = get_access_token_from_request(request)
yield self.store.delete_access_token(access_token)
defer.returnValue((200, {}))

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType
from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils
from synapse.http.servlet import parse_json_object_from_request
@ -296,12 +297,11 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_app_service(self, request, register_json, session):
if "access_token" not in request.args:
raise SynapseError(400, "Expected application service token.")
as_token = get_access_token_from_request(request)
if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.")
as_token = request.args["access_token"][0]
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
@ -390,11 +390,9 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request):
user_json = parse_json_object_from_request(request)
if "access_token" not in request.args:
raise SynapseError(400, "Expected application service token.")
access_token = get_access_token_from_request(request)
app_service = yield self.store.get_app_service_by_token(
request.args["access_token"][0]
access_token
)
if not app_service:
raise SynapseError(403, "Invalid application service token.")

View File

@ -17,6 +17,8 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
from synapse.api.auth import get_access_token_from_request
logger = logging.getLogger(__name__)
@ -90,6 +92,6 @@ class HttpTransactionStore(object):
return response
def _get_key(self, request):
token = request.args["access_token"][0]
token = get_access_token_from_request(request)
path_without_txn_id = request.path.rsplit("/", 1)[0]
return path_without_txn_id + "/" + token

View File

@ -15,6 +15,7 @@
from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -131,7 +132,7 @@ class RegisterRestServlet(RestServlet):
desired_username = body['username']
appservice = None
if 'access_token' in request.args:
if has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
@ -143,10 +144,11 @@ class RegisterRestServlet(RestServlet):
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
desired_username = body.get("user", desired_username)
access_token = get_access_token_from_request(request)
if isinstance(desired_username, basestring):
result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0], body
desired_username, access_token, body
)
defer.returnValue((200, result)) # we throw for non 200 responses
return

View File

@ -80,7 +80,7 @@ class ThirdPartyUserServlet(RestServlet):
yield self.auth.get_user_by_req(request)
fields = request.args
del fields["access_token"]
fields.pop("access_token", None)
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
@ -104,7 +104,7 @@ class ThirdPartyLocationServlet(RestServlet):
yield self.auth.get_user_by_req(request)
fields = request.args
del fields["access_token"]
fields.pop("access_token", None)
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields