Change the way we specify if we require auth or not

This commit is contained in:
Erik Johnston 2016-08-05 16:17:04 +01:00
parent 32fc39fd4c
commit 597c79be10
1 changed files with 56 additions and 41 deletions

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import parse_json_object_from_request, parse_string from synapse.http.servlet import parse_json_object_from_request
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
import functools import functools
@ -60,6 +60,16 @@ class TransportLayerServer(JsonResource):
) )
class AuthenticationError(SynapseError):
"""There was a problem authenticating the request"""
pass
class NoAuthenticationError(AuthenticationError):
"""The request had no authentication information"""
pass
class Authenticator(object): class Authenticator(object):
def __init__(self, hs): def __init__(self, hs):
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
@ -67,7 +77,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks @defer.inlineCallbacks
def authenticate_request(self, request): def authenticate_request(self, request, content):
json_request = { json_request = {
"method": request.method, "method": request.method,
"uri": request.uri, "uri": request.uri,
@ -75,17 +85,10 @@ class Authenticator(object):
"signatures": {}, "signatures": {},
} }
content = None if content is not None:
origin = None
if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
try:
content_bytes = request.content.read()
content = json.loads(content_bytes)
json_request["content"] = content json_request["content"] = content
except:
raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON) origin = None
def parse_auth_header(header_str): def parse_auth_header(header_str):
try: try:
@ -103,14 +106,14 @@ class Authenticator(object):
sig = strip_quotes(param_dict["sig"]) sig = strip_quotes(param_dict["sig"])
return (origin, key, sig) return (origin, key, sig)
except: except:
raise SynapseError( raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED 400, "Malformed Authorization header", Codes.UNAUTHORIZED
) )
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers: if not auth_headers:
raise SynapseError( raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED, 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
) )
@ -121,7 +124,7 @@ class Authenticator(object):
json_request["signatures"].setdefault(origin, {})[key] = sig json_request["signatures"].setdefault(origin, {})[key] = sig
if not json_request["signatures"]: if not json_request["signatures"]:
raise SynapseError( raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED, 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
) )
@ -130,10 +133,12 @@ class Authenticator(object):
logger.info("Request from %s", origin) logger.info("Request from %s", origin)
request.authenticated_entity = origin request.authenticated_entity = origin
defer.returnValue((origin, content)) defer.returnValue(origin)
class BaseFederationServlet(object): class BaseFederationServlet(object):
REQUIRE_AUTH = True
def __init__(self, handler, authenticator, ratelimiter, server_name, def __init__(self, handler, authenticator, ratelimiter, server_name,
room_list_handler): room_list_handler):
self.handler = handler self.handler = handler
@ -141,29 +146,46 @@ class BaseFederationServlet(object):
self.ratelimiter = ratelimiter self.ratelimiter = ratelimiter
self.room_list_handler = room_list_handler self.room_list_handler = room_list_handler
def _wrap(self, code): def _wrap(self, func):
authenticator = self.authenticator authenticator = self.authenticator
ratelimiter = self.ratelimiter ratelimiter = self.ratelimiter
@defer.inlineCallbacks @defer.inlineCallbacks
@functools.wraps(code) @functools.wraps(func)
def new_code(request, *args, **kwargs): def new_func(request, *args, **kwargs):
content = None
if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
content = parse_json_object_from_request(request)
try: try:
(origin, content) = yield authenticator.authenticate_request(request) origin = yield authenticator.authenticate_request(request, content)
with ratelimiter.ratelimit(origin) as d: except NoAuthenticationError:
yield d origin = None
response = yield code( if self.REQUIRE_AUTH:
origin, content, request.args, *args, **kwargs logger.exception("authenticate_request failed")
) raise
except: except:
logger.exception("authenticate_request failed") logger.exception("authenticate_request failed")
raise raise
if origin:
with ratelimiter.ratelimit(origin) as d:
yield d
response = yield func(
origin, content, request.args, *args, **kwargs
)
else:
response = yield func(
origin, content, request.args, *args, **kwargs
)
defer.returnValue(response) defer.returnValue(response)
# Extra logic that functools.wraps() doesn't finish # Extra logic that functools.wraps() doesn't finish
new_code.__self__ = code.__self__ new_func.__self__ = func.__self__
return new_code return new_func
def register(self, server): def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH + "$") pattern = re.compile("^" + PREFIX + self.PATH + "$")
@ -429,9 +451,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
class On3pidBindServlet(BaseFederationServlet): class On3pidBindServlet(BaseFederationServlet):
PATH = "/3pid/onbind" PATH = "/3pid/onbind"
REQUIRE_AUTH = False
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, origin, content, query):
content = parse_json_object_from_request(request)
if "invites" in content: if "invites" in content:
last_exception = None last_exception = None
for invite in content["invites"]: for invite in content["invites"]:
@ -453,11 +476,6 @@ class On3pidBindServlet(BaseFederationServlet):
raise last_exception raise last_exception
defer.returnValue((200, {})) defer.returnValue((200, {}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
class OpenIdUserInfo(BaseFederationServlet): class OpenIdUserInfo(BaseFederationServlet):
""" """
@ -478,9 +496,11 @@ class OpenIdUserInfo(BaseFederationServlet):
PATH = "/openid/userinfo" PATH = "/openid/userinfo"
REQUIRE_AUTH = False
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, origin, content, query):
token = parse_string(request, "access_token") token = query.get("access_token", [None])[0]
if token is None: if token is None:
defer.returnValue((401, { defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required" "errcode": "M_MISSING_TOKEN", "error": "Access Token required"
@ -497,11 +517,6 @@ class OpenIdUserInfo(BaseFederationServlet):
defer.returnValue((200, {"sub": user_id})) defer.returnValue((200, {"sub": user_id}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
class PublicRoomList(BaseFederationServlet): class PublicRoomList(BaseFederationServlet):
""" """