Parse both user and attributes from CAS response
This commit is contained in:
parent
782f7fb489
commit
7845f62c22
|
@ -125,6 +125,34 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_cas_login(self, cas_response_body):
|
def do_cas_login(self, cas_response_body):
|
||||||
|
(user, attributes) = self.parse_cas_response(cas_response_body)
|
||||||
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
|
auth_handler = self.handlers.auth_handler
|
||||||
|
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||||
|
if user_exists:
|
||||||
|
user_id, access_token, refresh_token = (
|
||||||
|
yield auth_handler.login_with_cas_user_id(user_id)
|
||||||
|
)
|
||||||
|
result = {
|
||||||
|
"user_id": user_id, # may have changed
|
||||||
|
"access_token": access_token,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"home_server": self.hs.hostname,
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
user_id, access_token = (
|
||||||
|
yield self.handlers.registration_handler.register(localpart=user)
|
||||||
|
)
|
||||||
|
result = {
|
||||||
|
"user_id": user_id, # may have changed
|
||||||
|
"access_token": access_token,
|
||||||
|
"home_server": self.hs.hostname,
|
||||||
|
}
|
||||||
|
|
||||||
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
|
def parse_cas_response(self, cas_response_body):
|
||||||
root = ET.fromstring(cas_response_body)
|
root = ET.fromstring(cas_response_body)
|
||||||
if not root.tag.endswith("serviceResponse"):
|
if not root.tag.endswith("serviceResponse"):
|
||||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||||
|
@ -133,33 +161,17 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
for child in root[0]:
|
for child in root[0]:
|
||||||
if child.tag.endswith("user"):
|
if child.tag.endswith("user"):
|
||||||
user = child.text
|
user = child.text
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
if child.tag.endswith("attributes"):
|
||||||
auth_handler = self.handlers.auth_handler
|
attributes = {}
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
for attribute in child:
|
||||||
if user_exists:
|
if "}" in attribute.tag:
|
||||||
user_id, access_token, refresh_token = (
|
attributes[attribute.tag.split("}")[1]] = attribute.text
|
||||||
yield auth_handler.login_with_cas_user_id(user_id)
|
else:
|
||||||
)
|
attributes[attribute.tag] = attribute.text
|
||||||
result = {
|
if user is None or attributes is None:
|
||||||
"user_id": user_id, # may have changed
|
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||||
"access_token": access_token,
|
|
||||||
"refresh_token": refresh_token,
|
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
return (user, attributes)
|
||||||
user_id, access_token = (
|
|
||||||
yield self.handlers.registration_handler.register(localpart=user)
|
|
||||||
)
|
|
||||||
result = {
|
|
||||||
"user_id": user_id, # may have changed
|
|
||||||
"access_token": access_token,
|
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
}
|
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
|
||||||
|
|
||||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
|
|
||||||
class LoginFallbackRestServlet(ClientV1RestServlet):
|
class LoginFallbackRestServlet(ClientV1RestServlet):
|
||||||
|
|
Loading…
Reference in New Issue