Parse both user and attributes from CAS response

This commit is contained in:
Steven Hammerton 2015-10-12 10:52:43 +01:00
parent 782f7fb489
commit 7845f62c22
1 changed files with 38 additions and 26 deletions

View File

@ -125,6 +125,34 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_cas_login(self, cas_response_body): def do_cas_login(self, cas_response_body):
(user, attributes) = self.parse_cas_response(cas_response_body)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.login_with_cas_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body) root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"): if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
@ -133,33 +161,17 @@ class LoginRestServlet(ClientV1RestServlet):
for child in root[0]: for child in root[0]:
if child.tag.endswith("user"): if child.tag.endswith("user"):
user = child.text user = child.text
user_id = UserID.create(user, self.hs.hostname).to_string() if child.tag.endswith("attributes"):
auth_handler = self.handlers.auth_handler attributes = {}
user_exists = yield auth_handler.does_user_exist(user_id) for attribute in child:
if user_exists: if "}" in attribute.tag:
user_id, access_token, refresh_token = ( attributes[attribute.tag.split("}")[1]] = attribute.text
yield auth_handler.login_with_cas_user_id(user_id) else:
) attributes[attribute.tag] = attribute.text
result = { if user is None or attributes is None:
"user_id": user_id, # may have changed raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else: return (user, attributes)
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
class LoginFallbackRestServlet(ClientV1RestServlet): class LoginFallbackRestServlet(ClientV1RestServlet):