Merge pull request #995 from matrix-org/rav/clean_up_cas_login
Clean up CAS login code
This commit is contained in:
commit
cd41c6ece2
|
@ -54,10 +54,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
self.jwt_secret = hs.config.jwt_secret
|
self.jwt_secret = hs.config.jwt_secret
|
||||||
self.jwt_algorithm = hs.config.jwt_algorithm
|
self.jwt_algorithm = hs.config.jwt_algorithm
|
||||||
self.cas_enabled = hs.config.cas_enabled
|
self.cas_enabled = hs.config.cas_enabled
|
||||||
self.cas_server_url = hs.config.cas_server_url
|
|
||||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
|
||||||
self.servername = hs.config.server_name
|
|
||||||
self.http_client = hs.get_simple_http_client()
|
|
||||||
self.auth_handler = self.hs.get_auth_handler()
|
self.auth_handler = self.hs.get_auth_handler()
|
||||||
self.device_handler = self.hs.get_device_handler()
|
self.device_handler = self.hs.get_device_handler()
|
||||||
|
|
||||||
|
@ -110,17 +106,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
LoginRestServlet.JWT_TYPE):
|
LoginRestServlet.JWT_TYPE):
|
||||||
result = yield self.do_jwt_login(login_submission)
|
result = yield self.do_jwt_login(login_submission)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
# TODO Delete this after all CAS clients switch to token login instead
|
|
||||||
elif self.cas_enabled and (login_submission["type"] ==
|
|
||||||
LoginRestServlet.CAS_TYPE):
|
|
||||||
uri = "%s/proxyValidate" % (self.cas_server_url,)
|
|
||||||
args = {
|
|
||||||
"ticket": login_submission["ticket"],
|
|
||||||
"service": login_submission["service"]
|
|
||||||
}
|
|
||||||
body = yield self.http_client.get_raw(uri, args)
|
|
||||||
result = yield self.do_cas_login(body)
|
|
||||||
defer.returnValue(result)
|
|
||||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||||
result = yield self.do_token_login(login_submission)
|
result = yield self.do_token_login(login_submission)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
@ -191,51 +176,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
# TODO Delete this after all CAS clients switch to token login instead
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def do_cas_login(self, cas_response_body):
|
|
||||||
user, attributes = self.parse_cas_response(cas_response_body)
|
|
||||||
|
|
||||||
for required_attribute, required_value in self.cas_required_attributes.items():
|
|
||||||
# If required attribute was not in CAS Response - Forbidden
|
|
||||||
if required_attribute not in attributes:
|
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
# Also need to check value
|
|
||||||
if required_value is not None:
|
|
||||||
actual_value = attributes[required_attribute]
|
|
||||||
# If required attribute value does not match expected - Forbidden
|
|
||||||
if required_value != actual_value:
|
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
|
||||||
auth_handler = self.auth_handler
|
|
||||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
|
||||||
if registered_user_id:
|
|
||||||
access_token, refresh_token = (
|
|
||||||
yield auth_handler.get_login_tuple_for_user_id(
|
|
||||||
registered_user_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result = {
|
|
||||||
"user_id": registered_user_id, # may have changed
|
|
||||||
"access_token": access_token,
|
|
||||||
"refresh_token": refresh_token,
|
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
|
||||||
user_id, access_token = (
|
|
||||||
yield self.handlers.registration_handler.register(localpart=user)
|
|
||||||
)
|
|
||||||
result = {
|
|
||||||
"user_id": user_id, # may have changed
|
|
||||||
"access_token": access_token,
|
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
}
|
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_jwt_login(self, login_submission):
|
def do_jwt_login(self, login_submission):
|
||||||
token = login_submission.get("token", None)
|
token = login_submission.get("token", None)
|
||||||
|
@ -293,33 +233,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
# TODO Delete this after all CAS clients switch to token login instead
|
|
||||||
def parse_cas_response(self, cas_response_body):
|
|
||||||
root = ET.fromstring(cas_response_body)
|
|
||||||
if not root.tag.endswith("serviceResponse"):
|
|
||||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
|
||||||
if not root[0].tag.endswith("authenticationSuccess"):
|
|
||||||
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
|
|
||||||
for child in root[0]:
|
|
||||||
if child.tag.endswith("user"):
|
|
||||||
user = child.text
|
|
||||||
if child.tag.endswith("attributes"):
|
|
||||||
attributes = {}
|
|
||||||
for attribute in child:
|
|
||||||
# ElementTree library expands the namespace in attribute tags
|
|
||||||
# to the full URL of the namespace.
|
|
||||||
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
|
|
||||||
# We don't care about namespace here and it will always be encased in
|
|
||||||
# curly braces, so we remove them.
|
|
||||||
if "}" in attribute.tag:
|
|
||||||
attributes[attribute.tag.split("}")[1]] = attribute.text
|
|
||||||
else:
|
|
||||||
attributes[attribute.tag] = attribute.text
|
|
||||||
if user is None or attributes is None:
|
|
||||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
return (user, attributes)
|
|
||||||
|
|
||||||
def _register_device(self, user_id, login_submission):
|
def _register_device(self, user_id, login_submission):
|
||||||
"""Register a device for a user.
|
"""Register a device for a user.
|
||||||
|
|
||||||
|
@ -384,18 +297,6 @@ class SAML2RestServlet(ClientV1RestServlet):
|
||||||
defer.returnValue((200, {"status": "not_authenticated"}))
|
defer.returnValue((200, {"status": "not_authenticated"}))
|
||||||
|
|
||||||
|
|
||||||
# TODO Delete this after all CAS clients switch to token login instead
|
|
||||||
class CasRestServlet(ClientV1RestServlet):
|
|
||||||
PATTERNS = client_path_patterns("/login/cas", releases=())
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
|
||||||
super(CasRestServlet, self).__init__(hs)
|
|
||||||
self.cas_server_url = hs.config.cas_server_url
|
|
||||||
|
|
||||||
def on_GET(self, request):
|
|
||||||
return (200, {"serverUrl": self.cas_server_url})
|
|
||||||
|
|
||||||
|
|
||||||
class CasRedirectServlet(ClientV1RestServlet):
|
class CasRedirectServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
|
PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
|
||||||
|
|
||||||
|
@ -480,30 +381,39 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
return urlparse.urlunparse(url_parts)
|
return urlparse.urlunparse(url_parts)
|
||||||
|
|
||||||
def parse_cas_response(self, cas_response_body):
|
def parse_cas_response(self, cas_response_body):
|
||||||
root = ET.fromstring(cas_response_body)
|
user = None
|
||||||
if not root.tag.endswith("serviceResponse"):
|
attributes = None
|
||||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
try:
|
||||||
if not root[0].tag.endswith("authenticationSuccess"):
|
root = ET.fromstring(cas_response_body)
|
||||||
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
|
if not root.tag.endswith("serviceResponse"):
|
||||||
for child in root[0]:
|
raise Exception("root of CAS response is not serviceResponse")
|
||||||
if child.tag.endswith("user"):
|
success = (root[0].tag.endswith("authenticationSuccess"))
|
||||||
user = child.text
|
for child in root[0]:
|
||||||
if child.tag.endswith("attributes"):
|
if child.tag.endswith("user"):
|
||||||
attributes = {}
|
user = child.text
|
||||||
for attribute in child:
|
if child.tag.endswith("attributes"):
|
||||||
# ElementTree library expands the namespace in attribute tags
|
attributes = {}
|
||||||
# to the full URL of the namespace.
|
for attribute in child:
|
||||||
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
|
# ElementTree library expands the namespace in
|
||||||
# We don't care about namespace here and it will always be encased in
|
# attribute tags to the full URL of the namespace.
|
||||||
# curly braces, so we remove them.
|
# We don't care about namespace here and it will always
|
||||||
if "}" in attribute.tag:
|
# be encased in curly braces, so we remove them.
|
||||||
attributes[attribute.tag.split("}")[1]] = attribute.text
|
tag = attribute.tag
|
||||||
else:
|
if "}" in tag:
|
||||||
attributes[attribute.tag] = attribute.text
|
tag = tag.split("}")[1]
|
||||||
if user is None or attributes is None:
|
attributes[tag] = attribute.text
|
||||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
if user is None:
|
||||||
|
raise Exception("CAS response does not contain user")
|
||||||
return (user, attributes)
|
if attributes is None:
|
||||||
|
raise Exception("CAS response does not contain attributes")
|
||||||
|
except Exception:
|
||||||
|
logger.error("Error parsing CAS response", exc_info=1)
|
||||||
|
raise LoginError(401, "Invalid CAS response",
|
||||||
|
errcode=Codes.UNAUTHORIZED)
|
||||||
|
if not success:
|
||||||
|
raise LoginError(401, "Unsuccessful CAS response",
|
||||||
|
errcode=Codes.UNAUTHORIZED)
|
||||||
|
return user, attributes
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
|
@ -513,5 +423,3 @@ def register_servlets(hs, http_server):
|
||||||
if hs.config.cas_enabled:
|
if hs.config.cas_enabled:
|
||||||
CasRedirectServlet(hs).register(http_server)
|
CasRedirectServlet(hs).register(http_server)
|
||||||
CasTicketServlet(hs).register(http_server)
|
CasTicketServlet(hs).register(http_server)
|
||||||
CasRestServlet(hs).register(http_server)
|
|
||||||
# TODO PasswordResetRestServlet(hs).register(http_server)
|
|
||||||
|
|
Loading…
Reference in New Issue