Merge pull request #995 from matrix-org/rav/clean_up_cas_login

Clean up CAS login code
This commit is contained in:
David Baker 2016-08-09 10:21:56 +01:00 committed by GitHub
commit cd41c6ece2
1 changed files with 33 additions and 125 deletions

View File

@ -54,10 +54,6 @@ class LoginRestServlet(ClientV1RestServlet):
self.jwt_secret = hs.config.jwt_secret self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm self.jwt_algorithm = hs.config.jwt_algorithm
self.cas_enabled = hs.config.cas_enabled self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler() self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler() self.device_handler = self.hs.get_device_handler()
@ -110,17 +106,6 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.JWT_TYPE): LoginRestServlet.JWT_TYPE):
result = yield self.do_jwt_login(login_submission) result = yield self.do_jwt_login(login_submission)
defer.returnValue(result) defer.returnValue(result)
# TODO Delete this after all CAS clients switch to token login instead
elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
uri = "%s/proxyValidate" % (self.cas_server_url,)
args = {
"ticket": login_submission["ticket"],
"service": login_submission["service"]
}
body = yield self.http_client.get_raw(uri, args)
result = yield self.do_cas_login(body)
defer.returnValue(result)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission) result = yield self.do_token_login(login_submission)
defer.returnValue(result) defer.returnValue(result)
@ -191,51 +176,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
@defer.inlineCallbacks
def do_cas_login(self, cas_response_body):
user, attributes = self.parse_cas_response(cas_response_body)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
registered_user_id
)
)
result = {
"user_id": registered_user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
@defer.inlineCallbacks @defer.inlineCallbacks
def do_jwt_login(self, login_submission): def do_jwt_login(self, login_submission):
token = login_submission.get("token", None) token = login_submission.get("token", None)
@ -293,33 +233,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
return (user, attributes)
def _register_device(self, user_id, login_submission): def _register_device(self, user_id, login_submission):
"""Register a device for a user. """Register a device for a user.
@ -384,18 +297,6 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"})) defer.returnValue((200, {"status": "not_authenticated"}))
# TODO Delete this after all CAS clients switch to token login instead
class CasRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas", releases=())
def __init__(self, hs):
super(CasRestServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
def on_GET(self, request):
return (200, {"serverUrl": self.cas_server_url})
class CasRedirectServlet(ClientV1RestServlet): class CasRedirectServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas/redirect", releases=()) PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
@ -480,30 +381,39 @@ class CasTicketServlet(ClientV1RestServlet):
return urlparse.urlunparse(url_parts) return urlparse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body): def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body) user = None
if not root.tag.endswith("serviceResponse"): attributes = None
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) try:
if not root[0].tag.endswith("authenticationSuccess"): root = ET.fromstring(cas_response_body)
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) if not root.tag.endswith("serviceResponse"):
for child in root[0]: raise Exception("root of CAS response is not serviceResponse")
if child.tag.endswith("user"): success = (root[0].tag.endswith("authenticationSuccess"))
user = child.text for child in root[0]:
if child.tag.endswith("attributes"): if child.tag.endswith("user"):
attributes = {} user = child.text
for attribute in child: if child.tag.endswith("attributes"):
# ElementTree library expands the namespace in attribute tags attributes = {}
# to the full URL of the namespace. for attribute in child:
# See (https://docs.python.org/2/library/xml.etree.elementtree.html) # ElementTree library expands the namespace in
# We don't care about namespace here and it will always be encased in # attribute tags to the full URL of the namespace.
# curly braces, so we remove them. # We don't care about namespace here and it will always
if "}" in attribute.tag: # be encased in curly braces, so we remove them.
attributes[attribute.tag.split("}")[1]] = attribute.text tag = attribute.tag
else: if "}" in tag:
attributes[attribute.tag] = attribute.text tag = tag.split("}")[1]
if user is None or attributes is None: attributes[tag] = attribute.text
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) if user is None:
raise Exception("CAS response does not contain user")
return (user, attributes) if attributes is None:
raise Exception("CAS response does not contain attributes")
except Exception:
logger.error("Error parsing CAS response", exc_info=1)
raise LoginError(401, "Invalid CAS response",
errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(401, "Unsuccessful CAS response",
errcode=Codes.UNAUTHORIZED)
return user, attributes
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
@ -513,5 +423,3 @@ def register_servlets(hs, http_server):
if hs.config.cas_enabled: if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server) CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server)
CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server)