Pass client redirect URL into SAML mapping providers

This commit is contained in:
Richard van der Hoff 2020-01-11 01:01:53 +00:00
parent 47e63cc67a
commit dc69a1cf43
1 changed files with 11 additions and 4 deletions

View File

@ -114,10 +114,10 @@ class SamlHandler:
# the dict. # the dict.
self.expire_sessions() self.expire_sessions()
user_id = await self._map_saml_response_to_user(resp_bytes) user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
self._sso_auth_handler.complete_sso_login(user_id, request, relay_state) self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(self, resp_bytes): async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
try: try:
saml2_auth = self._saml_client.parse_authn_request_response( saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes, resp_bytes,
@ -185,7 +185,7 @@ class SamlHandler:
# Map saml response to user attributes using the configured mapping provider # Map saml response to user attributes using the configured mapping provider
for i in range(1000): for i in range(1000):
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, i saml2_auth, i, client_redirect_url=client_redirect_url,
) )
logger.debug( logger.debug(
@ -218,6 +218,8 @@ class SamlHandler:
500, "Unable to generate a Matrix ID from the SAML response" 500, "Unable to generate a Matrix ID from the SAML response"
) )
logger.info("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user( registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=displayname localpart=localpart, default_display_name=displayname
) )
@ -278,7 +280,10 @@ class DefaultSamlMappingProvider(object):
self._mxid_mapper = parsed_config.mxid_mapper self._mxid_mapper = parsed_config.mxid_mapper
def saml_response_to_user_attributes( def saml_response_to_user_attributes(
self, saml_response: saml2.response.AuthnResponse, failures: int = 0, self,
saml_response: saml2.response.AuthnResponse,
failures: int,
client_redirect_url: str,
) -> dict: ) -> dict:
"""Maps some text from a SAML response to attributes of a new user """Maps some text from a SAML response to attributes of a new user
@ -288,6 +293,8 @@ class DefaultSamlMappingProvider(object):
failures: How many times a call to this function with this failures: How many times a call to this function with this
saml_response has resulted in a failure saml_response has resulted in a failure
client_redirect_url: where the client wants to redirect to
Returns: Returns:
dict: A dict containing new user attributes. Possible keys: dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid * mxid_localpart (str): Required. The localpart of the user's mxid