Allow additional SSO properties to be passed to the client (#8413)
This commit is contained in:
parent
ceafb5a1c6
commit
8b40843392
|
@ -0,0 +1 @@
|
||||||
|
Support passing additional single sign-on parameters to the client.
|
|
@ -1748,6 +1748,14 @@ oidc_config:
|
||||||
#
|
#
|
||||||
#display_name_template: "{{ user.given_name }} {{ user.last_name }}"
|
#display_name_template: "{{ user.given_name }} {{ user.last_name }}"
|
||||||
|
|
||||||
|
# Jinja2 templates for extra attributes to send back to the client during
|
||||||
|
# login.
|
||||||
|
#
|
||||||
|
# Note that these are non-standard and clients will ignore them without modifications.
|
||||||
|
#
|
||||||
|
#extra_attributes:
|
||||||
|
#birthdate: "{{ user.birthdate }}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Enable CAS for registration and login.
|
# Enable CAS for registration and login.
|
||||||
|
|
|
@ -57,7 +57,7 @@ A custom mapping provider must specify the following methods:
|
||||||
- This method must return a string, which is the unique identifier for the
|
- This method must return a string, which is the unique identifier for the
|
||||||
user. Commonly the ``sub`` claim of the response.
|
user. Commonly the ``sub`` claim of the response.
|
||||||
* `map_user_attributes(self, userinfo, token)`
|
* `map_user_attributes(self, userinfo, token)`
|
||||||
- This method should be async.
|
- This method must be async.
|
||||||
- Arguments:
|
- Arguments:
|
||||||
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
|
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
|
||||||
information from.
|
information from.
|
||||||
|
@ -66,6 +66,18 @@ A custom mapping provider must specify the following methods:
|
||||||
- Returns a dictionary with two keys:
|
- Returns a dictionary with two keys:
|
||||||
- localpart: A required string, used to generate the Matrix ID.
|
- localpart: A required string, used to generate the Matrix ID.
|
||||||
- displayname: An optional string, the display name for the user.
|
- displayname: An optional string, the display name for the user.
|
||||||
|
* `get_extra_attributes(self, userinfo, token)`
|
||||||
|
- This method must be async.
|
||||||
|
- Arguments:
|
||||||
|
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
|
||||||
|
information from.
|
||||||
|
- `token` - A dictionary which includes information necessary to make
|
||||||
|
further requests to the OpenID provider.
|
||||||
|
- Returns a dictionary that is suitable to be serialized to JSON. This
|
||||||
|
will be returned as part of the response during a successful login.
|
||||||
|
|
||||||
|
Note that care should be taken to not overwrite any of the parameters
|
||||||
|
usually returned as part of the [login response](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login).
|
||||||
|
|
||||||
### Default OpenID Mapping Provider
|
### Default OpenID Mapping Provider
|
||||||
|
|
||||||
|
|
|
@ -243,6 +243,22 @@ for the room are in flight:
|
||||||
|
|
||||||
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$
|
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$
|
||||||
|
|
||||||
|
Additionally, the following endpoints should be included if Synapse is configured
|
||||||
|
to use SSO (you only need to include the ones for whichever SSO provider you're
|
||||||
|
using):
|
||||||
|
|
||||||
|
# OpenID Connect requests.
|
||||||
|
^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
|
||||||
|
^/_synapse/oidc/callback$
|
||||||
|
|
||||||
|
# SAML requests.
|
||||||
|
^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
|
||||||
|
^/_matrix/saml2/authn_response$
|
||||||
|
|
||||||
|
# CAS requests.
|
||||||
|
^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$
|
||||||
|
^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$
|
||||||
|
|
||||||
Note that a HTTP listener with `client` and `federation` resources must be
|
Note that a HTTP listener with `client` and `federation` resources must be
|
||||||
configured in the `worker_listeners` option in the worker config.
|
configured in the `worker_listeners` option in the worker config.
|
||||||
|
|
||||||
|
|
|
@ -204,6 +204,14 @@ class OIDCConfig(Config):
|
||||||
# If unset, no displayname will be set.
|
# If unset, no displayname will be set.
|
||||||
#
|
#
|
||||||
#display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
|
#display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
|
||||||
|
|
||||||
|
# Jinja2 templates for extra attributes to send back to the client during
|
||||||
|
# login.
|
||||||
|
#
|
||||||
|
# Note that these are non-standard and clients will ignore them without modifications.
|
||||||
|
#
|
||||||
|
#extra_attributes:
|
||||||
|
#birthdate: "{{{{ user.birthdate }}}}"
|
||||||
""".format(
|
""".format(
|
||||||
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
|
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
|
||||||
)
|
)
|
||||||
|
|
|
@ -137,6 +137,15 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class SsoLoginExtraAttributes:
|
||||||
|
"""Data we track about SAML2 sessions"""
|
||||||
|
|
||||||
|
# time the session was created, in milliseconds
|
||||||
|
creation_time = attr.ib(type=int)
|
||||||
|
extra_attributes = attr.ib(type=JsonDict)
|
||||||
|
|
||||||
|
|
||||||
class AuthHandler(BaseHandler):
|
class AuthHandler(BaseHandler):
|
||||||
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||||
|
|
||||||
|
@ -239,6 +248,10 @@ class AuthHandler(BaseHandler):
|
||||||
# cast to tuple for use with str.startswith
|
# cast to tuple for use with str.startswith
|
||||||
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
|
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
|
||||||
|
|
||||||
|
# A mapping of user ID to extra attributes to include in the login
|
||||||
|
# response.
|
||||||
|
self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes]
|
||||||
|
|
||||||
async def validate_user_via_ui_auth(
|
async def validate_user_via_ui_auth(
|
||||||
self,
|
self,
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
|
@ -1165,6 +1178,7 @@ class AuthHandler(BaseHandler):
|
||||||
registered_user_id: str,
|
registered_user_id: str,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
|
extra_attributes: Optional[JsonDict] = None,
|
||||||
):
|
):
|
||||||
"""Having figured out a mxid for this user, complete the HTTP request
|
"""Having figured out a mxid for this user, complete the HTTP request
|
||||||
|
|
||||||
|
@ -1173,6 +1187,8 @@ class AuthHandler(BaseHandler):
|
||||||
request: The request to complete.
|
request: The request to complete.
|
||||||
client_redirect_url: The URL to which to redirect the user at the end of the
|
client_redirect_url: The URL to which to redirect the user at the end of the
|
||||||
process.
|
process.
|
||||||
|
extra_attributes: Extra attributes which will be passed to the client
|
||||||
|
during successful login. Must be JSON serializable.
|
||||||
"""
|
"""
|
||||||
# If the account has been deactivated, do not proceed with the login
|
# If the account has been deactivated, do not proceed with the login
|
||||||
# flow.
|
# flow.
|
||||||
|
@ -1181,19 +1197,30 @@ class AuthHandler(BaseHandler):
|
||||||
respond_with_html(request, 403, self._sso_account_deactivated_template)
|
respond_with_html(request, 403, self._sso_account_deactivated_template)
|
||||||
return
|
return
|
||||||
|
|
||||||
self._complete_sso_login(registered_user_id, request, client_redirect_url)
|
self._complete_sso_login(
|
||||||
|
registered_user_id, request, client_redirect_url, extra_attributes
|
||||||
|
)
|
||||||
|
|
||||||
def _complete_sso_login(
|
def _complete_sso_login(
|
||||||
self,
|
self,
|
||||||
registered_user_id: str,
|
registered_user_id: str,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
|
extra_attributes: Optional[JsonDict] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
The synchronous portion of complete_sso_login.
|
The synchronous portion of complete_sso_login.
|
||||||
|
|
||||||
This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
|
This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
|
||||||
"""
|
"""
|
||||||
|
# Store any extra attributes which will be passed in the login response.
|
||||||
|
# Note that this is per-user so it may overwrite a previous value, this
|
||||||
|
# is considered OK since the newest SSO attributes should be most valid.
|
||||||
|
if extra_attributes:
|
||||||
|
self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
|
||||||
|
self._clock.time_msec(), extra_attributes,
|
||||||
|
)
|
||||||
|
|
||||||
# Create a login token
|
# Create a login token
|
||||||
login_token = self.macaroon_gen.generate_short_term_login_token(
|
login_token = self.macaroon_gen.generate_short_term_login_token(
|
||||||
registered_user_id
|
registered_user_id
|
||||||
|
@ -1226,6 +1253,37 @@ class AuthHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
respond_with_html(request, 200, html)
|
respond_with_html(request, 200, html)
|
||||||
|
|
||||||
|
async def _sso_login_callback(self, login_result: JsonDict) -> None:
|
||||||
|
"""
|
||||||
|
A login callback which might add additional attributes to the login response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
login_result: The data to be sent to the client. Includes the user
|
||||||
|
ID and access token.
|
||||||
|
"""
|
||||||
|
# Expire attributes before processing. Note that there shouldn't be any
|
||||||
|
# valid logins that still have extra attributes.
|
||||||
|
self._expire_sso_extra_attributes()
|
||||||
|
|
||||||
|
extra_attributes = self._extra_attributes.get(login_result["user_id"])
|
||||||
|
if extra_attributes:
|
||||||
|
login_result.update(extra_attributes.extra_attributes)
|
||||||
|
|
||||||
|
def _expire_sso_extra_attributes(self) -> None:
|
||||||
|
"""
|
||||||
|
Iterate through the mapping of user IDs to extra attributes and remove any that are no longer valid.
|
||||||
|
"""
|
||||||
|
# TODO This should match the amount of time the macaroon is valid for.
|
||||||
|
LOGIN_TOKEN_EXPIRATION_TIME = 2 * 60 * 1000
|
||||||
|
expire_before = self._clock.time_msec() - LOGIN_TOKEN_EXPIRATION_TIME
|
||||||
|
to_expire = set()
|
||||||
|
for user_id, data in self._extra_attributes.items():
|
||||||
|
if data.creation_time < expire_before:
|
||||||
|
to_expire.add(user_id)
|
||||||
|
for user_id in to_expire:
|
||||||
|
logger.debug("Expiring extra attributes for user %s", user_id)
|
||||||
|
del self._extra_attributes[user_id]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_query_param_to_url(url: str, param_name: str, param: Any):
|
def add_query_param_to_url(url: str, param_name: str, param: Any):
|
||||||
url_parts = list(urllib.parse.urlparse(url))
|
url_parts = list(urllib.parse.urlparse(url))
|
||||||
|
|
|
@ -37,7 +37,7 @@ from synapse.config import ConfigError
|
||||||
from synapse.http.server import respond_with_html
|
from synapse.http.server import respond_with_html
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -707,6 +707,15 @@ class OidcHandler:
|
||||||
self._render_error(request, "mapping_error", str(e))
|
self._render_error(request, "mapping_error", str(e))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Mapping providers might not have get_extra_attributes: only call this
|
||||||
|
# method if it exists.
|
||||||
|
extra_attributes = None
|
||||||
|
get_extra_attributes = getattr(
|
||||||
|
self._user_mapping_provider, "get_extra_attributes", None
|
||||||
|
)
|
||||||
|
if get_extra_attributes:
|
||||||
|
extra_attributes = await get_extra_attributes(userinfo, token)
|
||||||
|
|
||||||
# and finally complete the login
|
# and finally complete the login
|
||||||
if ui_auth_session_id:
|
if ui_auth_session_id:
|
||||||
await self._auth_handler.complete_sso_ui_auth(
|
await self._auth_handler.complete_sso_ui_auth(
|
||||||
|
@ -714,7 +723,7 @@ class OidcHandler:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self._auth_handler.complete_sso_login(
|
await self._auth_handler.complete_sso_login(
|
||||||
user_id, request, client_redirect_url
|
user_id, request, client_redirect_url, extra_attributes
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate_oidc_session_token(
|
def _generate_oidc_session_token(
|
||||||
|
@ -984,7 +993,7 @@ class OidcMappingProvider(Generic[C]):
|
||||||
async def map_user_attributes(
|
async def map_user_attributes(
|
||||||
self, userinfo: UserInfo, token: Token
|
self, userinfo: UserInfo, token: Token
|
||||||
) -> UserAttribute:
|
) -> UserAttribute:
|
||||||
"""Map a ``UserInfo`` objects into user attributes.
|
"""Map a `UserInfo` object into user attributes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
userinfo: An object representing the user given by the OIDC provider
|
userinfo: An object representing the user given by the OIDC provider
|
||||||
|
@ -995,6 +1004,18 @@ class OidcMappingProvider(Generic[C]):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
|
||||||
|
"""Map a `UserInfo` object into additional attributes passed to the client during login.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
userinfo: An object representing the user given by the OIDC provider
|
||||||
|
token: A dict with the tokens returned by the provider
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict containing additional attributes. Must be JSON serializable.
|
||||||
|
"""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
# Used to clear out "None" values in templates
|
# Used to clear out "None" values in templates
|
||||||
def jinja_finalize(thing):
|
def jinja_finalize(thing):
|
||||||
|
@ -1009,6 +1030,7 @@ class JinjaOidcMappingConfig:
|
||||||
subject_claim = attr.ib() # type: str
|
subject_claim = attr.ib() # type: str
|
||||||
localpart_template = attr.ib() # type: Template
|
localpart_template = attr.ib() # type: Template
|
||||||
display_name_template = attr.ib() # type: Optional[Template]
|
display_name_template = attr.ib() # type: Optional[Template]
|
||||||
|
extra_attributes = attr.ib() # type: Dict[str, Template]
|
||||||
|
|
||||||
|
|
||||||
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
||||||
|
@ -1047,10 +1069,28 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
||||||
% (e,)
|
% (e,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
extra_attributes = {} # type Dict[str, Template]
|
||||||
|
if "extra_attributes" in config:
|
||||||
|
extra_attributes_config = config.get("extra_attributes") or {}
|
||||||
|
if not isinstance(extra_attributes_config, dict):
|
||||||
|
raise ConfigError(
|
||||||
|
"oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
|
||||||
|
)
|
||||||
|
|
||||||
|
for key, value in extra_attributes_config.items():
|
||||||
|
try:
|
||||||
|
extra_attributes[key] = env.from_string(value)
|
||||||
|
except Exception as e:
|
||||||
|
raise ConfigError(
|
||||||
|
"invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
|
||||||
|
% (key, e)
|
||||||
|
)
|
||||||
|
|
||||||
return JinjaOidcMappingConfig(
|
return JinjaOidcMappingConfig(
|
||||||
subject_claim=subject_claim,
|
subject_claim=subject_claim,
|
||||||
localpart_template=localpart_template,
|
localpart_template=localpart_template,
|
||||||
display_name_template=display_name_template,
|
display_name_template=display_name_template,
|
||||||
|
extra_attributes=extra_attributes,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_remote_user_id(self, userinfo: UserInfo) -> str:
|
def get_remote_user_id(self, userinfo: UserInfo) -> str:
|
||||||
|
@ -1071,3 +1111,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
||||||
display_name = None
|
display_name = None
|
||||||
|
|
||||||
return UserAttribute(localpart=localpart, display_name=display_name)
|
return UserAttribute(localpart=localpart, display_name=display_name)
|
||||||
|
|
||||||
|
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
|
||||||
|
extras = {} # type: Dict[str, str]
|
||||||
|
for key, template in self._config.extra_attributes.items():
|
||||||
|
try:
|
||||||
|
extras[key] = template.render(user=userinfo).strip()
|
||||||
|
except Exception as e:
|
||||||
|
# Log an error and skip this value (don't break login for this).
|
||||||
|
logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e))
|
||||||
|
return extras
|
||||||
|
|
|
@ -284,9 +284,7 @@ class LoginRestServlet(RestServlet):
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
login_submission: JsonDict,
|
login_submission: JsonDict,
|
||||||
callback: Optional[
|
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
|
||||||
Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
|
|
||||||
] = None,
|
|
||||||
create_non_existent_users: bool = False,
|
create_non_existent_users: bool = False,
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
"""Called when we've successfully authed the user and now need to
|
"""Called when we've successfully authed the user and now need to
|
||||||
|
@ -299,12 +297,12 @@ class LoginRestServlet(RestServlet):
|
||||||
Args:
|
Args:
|
||||||
user_id: ID of the user to register.
|
user_id: ID of the user to register.
|
||||||
login_submission: Dictionary of login information.
|
login_submission: Dictionary of login information.
|
||||||
callback: Callback function to run after registration.
|
callback: Callback function to run after login.
|
||||||
create_non_existent_users: Whether to create the user if they don't
|
create_non_existent_users: Whether to create the user if they don't
|
||||||
exist. Defaults to False.
|
exist. Defaults to False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
result: Dictionary of account information after successful registration.
|
result: Dictionary of account information after successful login.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Before we actually log them in we check if they've already logged in
|
# Before we actually log them in we check if they've already logged in
|
||||||
|
@ -339,14 +337,24 @@ class LoginRestServlet(RestServlet):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Handle the final stage of SSO login.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
login_submission: The JSON request body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The body of the JSON response.
|
||||||
|
"""
|
||||||
token = login_submission["token"]
|
token = login_submission["token"]
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
|
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
token
|
token
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await self._complete_login(user_id, login_submission)
|
return await self._complete_login(
|
||||||
return result
|
user_id, login_submission, self.auth_handler._sso_login_callback
|
||||||
|
)
|
||||||
|
|
||||||
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
||||||
token = login_submission.get("token", None)
|
token = login_submission.get("token", None)
|
||||||
|
|
|
@ -21,7 +21,6 @@ from mock import Mock, patch
|
||||||
import attr
|
import attr
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.web._newclient import ResponseDone
|
from twisted.web._newclient import ResponseDone
|
||||||
|
|
||||||
|
@ -87,6 +86,13 @@ class TestMappingProvider(OidcMappingProvider):
|
||||||
async def map_user_attributes(self, userinfo, token):
|
async def map_user_attributes(self, userinfo, token):
|
||||||
return {"localpart": userinfo["username"], "display_name": None}
|
return {"localpart": userinfo["username"], "display_name": None}
|
||||||
|
|
||||||
|
# Do not include get_extra_attributes to test backwards compatibility paths.
|
||||||
|
|
||||||
|
|
||||||
|
class TestMappingProviderExtra(TestMappingProvider):
|
||||||
|
async def get_extra_attributes(self, userinfo, token):
|
||||||
|
return {"phone": userinfo["phone"]}
|
||||||
|
|
||||||
|
|
||||||
def simple_async_mock(return_value=None, raises=None):
|
def simple_async_mock(return_value=None, raises=None):
|
||||||
# AsyncMock is not available in python3.5, this mimics part of its behaviour
|
# AsyncMock is not available in python3.5, this mimics part of its behaviour
|
||||||
|
@ -126,7 +132,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
config["public_baseurl"] = BASE_URL
|
config["public_baseurl"] = BASE_URL
|
||||||
oidc_config = config.get("oidc_config", {})
|
oidc_config = {}
|
||||||
oidc_config["enabled"] = True
|
oidc_config["enabled"] = True
|
||||||
oidc_config["client_id"] = CLIENT_ID
|
oidc_config["client_id"] = CLIENT_ID
|
||||||
oidc_config["client_secret"] = CLIENT_SECRET
|
oidc_config["client_secret"] = CLIENT_SECRET
|
||||||
|
@ -135,6 +141,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
oidc_config["user_mapping_provider"] = {
|
oidc_config["user_mapping_provider"] = {
|
||||||
"module": __name__ + ".TestMappingProvider",
|
"module": __name__ + ".TestMappingProvider",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Update this config with what's in the default config so that
|
||||||
|
# override_config works as expected.
|
||||||
|
oidc_config.update(config.get("oidc_config", {}))
|
||||||
config["oidc_config"] = oidc_config
|
config["oidc_config"] = oidc_config
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
|
@ -165,11 +175,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
|
self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
|
||||||
|
|
||||||
@override_config({"oidc_config": {"discover": True}})
|
@override_config({"oidc_config": {"discover": True}})
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_discovery(self):
|
def test_discovery(self):
|
||||||
"""The handler should discover the endpoints from OIDC discovery document."""
|
"""The handler should discover the endpoints from OIDC discovery document."""
|
||||||
# This would throw if some metadata were invalid
|
# This would throw if some metadata were invalid
|
||||||
metadata = yield defer.ensureDeferred(self.handler.load_metadata())
|
metadata = self.get_success(self.handler.load_metadata())
|
||||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||||
|
|
||||||
self.assertEqual(metadata.issuer, ISSUER)
|
self.assertEqual(metadata.issuer, ISSUER)
|
||||||
|
@ -181,43 +190,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# subsequent calls should be cached
|
# subsequent calls should be cached
|
||||||
self.http_client.reset_mock()
|
self.http_client.reset_mock()
|
||||||
yield defer.ensureDeferred(self.handler.load_metadata())
|
self.get_success(self.handler.load_metadata())
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
@override_config({"oidc_config": COMMON_CONFIG})
|
@override_config({"oidc_config": COMMON_CONFIG})
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_no_discovery(self):
|
def test_no_discovery(self):
|
||||||
"""When discovery is disabled, it should not try to load from discovery document."""
|
"""When discovery is disabled, it should not try to load from discovery document."""
|
||||||
yield defer.ensureDeferred(self.handler.load_metadata())
|
self.get_success(self.handler.load_metadata())
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
@override_config({"oidc_config": COMMON_CONFIG})
|
@override_config({"oidc_config": COMMON_CONFIG})
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_load_jwks(self):
|
def test_load_jwks(self):
|
||||||
"""JWKS loading is done once (then cached) if used."""
|
"""JWKS loading is done once (then cached) if used."""
|
||||||
jwks = yield defer.ensureDeferred(self.handler.load_jwks())
|
jwks = self.get_success(self.handler.load_jwks())
|
||||||
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
||||||
self.assertEqual(jwks, {"keys": []})
|
self.assertEqual(jwks, {"keys": []})
|
||||||
|
|
||||||
# subsequent calls should be cached…
|
# subsequent calls should be cached…
|
||||||
self.http_client.reset_mock()
|
self.http_client.reset_mock()
|
||||||
yield defer.ensureDeferred(self.handler.load_jwks())
|
self.get_success(self.handler.load_jwks())
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
# …unless forced
|
# …unless forced
|
||||||
self.http_client.reset_mock()
|
self.http_client.reset_mock()
|
||||||
yield defer.ensureDeferred(self.handler.load_jwks(force=True))
|
self.get_success(self.handler.load_jwks(force=True))
|
||||||
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
||||||
|
|
||||||
# Throw if the JWKS uri is missing
|
# Throw if the JWKS uri is missing
|
||||||
with self.metadata_edit({"jwks_uri": None}):
|
with self.metadata_edit({"jwks_uri": None}):
|
||||||
with self.assertRaises(RuntimeError):
|
self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
|
||||||
yield defer.ensureDeferred(self.handler.load_jwks(force=True))
|
|
||||||
|
|
||||||
# Return empty key set if JWKS are not used
|
# Return empty key set if JWKS are not used
|
||||||
self.handler._scopes = [] # not asking the openid scope
|
self.handler._scopes = [] # not asking the openid scope
|
||||||
self.http_client.get_json.reset_mock()
|
self.http_client.get_json.reset_mock()
|
||||||
jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
|
jwks = self.get_success(self.handler.load_jwks(force=True))
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
self.assertEqual(jwks, {"keys": []})
|
self.assertEqual(jwks, {"keys": []})
|
||||||
|
|
||||||
|
@ -299,11 +305,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
# This should not throw
|
# This should not throw
|
||||||
self.handler._validate_metadata()
|
self.handler._validate_metadata()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_redirect_request(self):
|
def test_redirect_request(self):
|
||||||
"""The redirect request has the right arguments & generates a valid session cookie."""
|
"""The redirect request has the right arguments & generates a valid session cookie."""
|
||||||
req = Mock(spec=["addCookie"])
|
req = Mock(spec=["addCookie"])
|
||||||
url = yield defer.ensureDeferred(
|
url = self.get_success(
|
||||||
self.handler.handle_redirect_request(req, b"http://client/redirect")
|
self.handler.handle_redirect_request(req, b"http://client/redirect")
|
||||||
)
|
)
|
||||||
url = urlparse(url)
|
url = urlparse(url)
|
||||||
|
@ -343,20 +348,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(params["nonce"], [nonce])
|
self.assertEqual(params["nonce"], [nonce])
|
||||||
self.assertEqual(redirect, "http://client/redirect")
|
self.assertEqual(redirect, "http://client/redirect")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_callback_error(self):
|
def test_callback_error(self):
|
||||||
"""Errors from the provider returned in the callback are displayed."""
|
"""Errors from the provider returned in the callback are displayed."""
|
||||||
self.handler._render_error = Mock()
|
self.handler._render_error = Mock()
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args[b"error"] = [b"invalid_client"]
|
request.args[b"error"] = [b"invalid_client"]
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_client", "")
|
self.assertRenderedError("invalid_client", "")
|
||||||
|
|
||||||
request.args[b"error_description"] = [b"some description"]
|
request.args[b"error_description"] = [b"some description"]
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_client", "some description")
|
self.assertRenderedError("invalid_client", "some description")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_callback(self):
|
def test_callback(self):
|
||||||
"""Code callback works and display errors if something went wrong.
|
"""Code callback works and display errors if something went wrong.
|
||||||
|
|
||||||
|
@ -377,7 +380,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"sub": "foo",
|
"sub": "foo",
|
||||||
"preferred_username": "bar",
|
"preferred_username": "bar",
|
||||||
}
|
}
|
||||||
user_id = UserID("foo", "domain.org")
|
user_id = "@foo:domain.org"
|
||||||
self.handler._render_error = Mock(return_value=None)
|
self.handler._render_error = Mock(return_value=None)
|
||||||
self.handler._exchange_code = simple_async_mock(return_value=token)
|
self.handler._exchange_code = simple_async_mock(return_value=token)
|
||||||
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||||
|
@ -394,13 +397,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
client_redirect_url = "http://client/redirect"
|
client_redirect_url = "http://client/redirect"
|
||||||
user_agent = "Browser"
|
user_agent = "Browser"
|
||||||
ip_address = "10.0.0.1"
|
ip_address = "10.0.0.1"
|
||||||
session = self.handler._generate_oidc_session_token(
|
request.getCookie.return_value = self.handler._generate_oidc_session_token(
|
||||||
state=state,
|
state=state,
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
client_redirect_url=client_redirect_url,
|
client_redirect_url=client_redirect_url,
|
||||||
ui_auth_session_id=None,
|
ui_auth_session_id=None,
|
||||||
)
|
)
|
||||||
request.getCookie.return_value = session
|
|
||||||
|
|
||||||
request.args = {}
|
request.args = {}
|
||||||
request.args[b"code"] = [code.encode("utf-8")]
|
request.args[b"code"] = [code.encode("utf-8")]
|
||||||
|
@ -410,10 +412,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
|
request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
|
||||||
request.getClientIP.return_value = ip_address
|
request.getClientIP.return_value = ip_address
|
||||||
|
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user_id, request, client_redirect_url,
|
user_id, request, client_redirect_url, {},
|
||||||
)
|
)
|
||||||
self.handler._exchange_code.assert_called_once_with(code)
|
self.handler._exchange_code.assert_called_once_with(code)
|
||||||
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
||||||
|
@ -427,13 +429,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.handler._map_userinfo_to_user = simple_async_mock(
|
self.handler._map_userinfo_to_user = simple_async_mock(
|
||||||
raises=MappingException()
|
raises=MappingException()
|
||||||
)
|
)
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("mapping_error")
|
self.assertRenderedError("mapping_error")
|
||||||
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
|
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
|
||||||
|
|
||||||
# Handle ID token errors
|
# Handle ID token errors
|
||||||
self.handler._parse_id_token = simple_async_mock(raises=Exception())
|
self.handler._parse_id_token = simple_async_mock(raises=Exception())
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_token")
|
self.assertRenderedError("invalid_token")
|
||||||
|
|
||||||
self.handler._auth_handler.complete_sso_login.reset_mock()
|
self.handler._auth_handler.complete_sso_login.reset_mock()
|
||||||
|
@ -444,10 +446,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# With userinfo fetching
|
# With userinfo fetching
|
||||||
self.handler._scopes = [] # do not ask the "openid" scope
|
self.handler._scopes = [] # do not ask the "openid" scope
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user_id, request, client_redirect_url,
|
user_id, request, client_redirect_url, {},
|
||||||
)
|
)
|
||||||
self.handler._exchange_code.assert_called_once_with(code)
|
self.handler._exchange_code.assert_called_once_with(code)
|
||||||
self.handler._parse_id_token.assert_not_called()
|
self.handler._parse_id_token.assert_not_called()
|
||||||
|
@ -459,17 +461,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# Handle userinfo fetching error
|
# Handle userinfo fetching error
|
||||||
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
|
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("fetch_error")
|
self.assertRenderedError("fetch_error")
|
||||||
|
|
||||||
# Handle code exchange failure
|
# Handle code exchange failure
|
||||||
self.handler._exchange_code = simple_async_mock(
|
self.handler._exchange_code = simple_async_mock(
|
||||||
raises=OidcError("invalid_request")
|
raises=OidcError("invalid_request")
|
||||||
)
|
)
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_request")
|
self.assertRenderedError("invalid_request")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_callback_session(self):
|
def test_callback_session(self):
|
||||||
"""The callback verifies the session presence and validity"""
|
"""The callback verifies the session presence and validity"""
|
||||||
self.handler._render_error = Mock(return_value=None)
|
self.handler._render_error = Mock(return_value=None)
|
||||||
|
@ -478,20 +479,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
# Missing cookie
|
# Missing cookie
|
||||||
request.args = {}
|
request.args = {}
|
||||||
request.getCookie.return_value = None
|
request.getCookie.return_value = None
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("missing_session", "No session cookie found")
|
self.assertRenderedError("missing_session", "No session cookie found")
|
||||||
|
|
||||||
# Missing session parameter
|
# Missing session parameter
|
||||||
request.args = {}
|
request.args = {}
|
||||||
request.getCookie.return_value = "session"
|
request.getCookie.return_value = "session"
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_request", "State parameter is missing")
|
self.assertRenderedError("invalid_request", "State parameter is missing")
|
||||||
|
|
||||||
# Invalid cookie
|
# Invalid cookie
|
||||||
request.args = {}
|
request.args = {}
|
||||||
request.args[b"state"] = [b"state"]
|
request.args[b"state"] = [b"state"]
|
||||||
request.getCookie.return_value = "session"
|
request.getCookie.return_value = "session"
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_session")
|
self.assertRenderedError("invalid_session")
|
||||||
|
|
||||||
# Mismatching session
|
# Mismatching session
|
||||||
|
@ -504,18 +505,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
request.args = {}
|
request.args = {}
|
||||||
request.args[b"state"] = [b"mismatching state"]
|
request.args[b"state"] = [b"mismatching state"]
|
||||||
request.getCookie.return_value = session
|
request.getCookie.return_value = session
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("mismatching_session")
|
self.assertRenderedError("mismatching_session")
|
||||||
|
|
||||||
# Valid session
|
# Valid session
|
||||||
request.args = {}
|
request.args = {}
|
||||||
request.args[b"state"] = [b"state"]
|
request.args[b"state"] = [b"state"]
|
||||||
request.getCookie.return_value = session
|
request.getCookie.return_value = session
|
||||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_request")
|
self.assertRenderedError("invalid_request")
|
||||||
|
|
||||||
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
|
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_exchange_code(self):
|
def test_exchange_code(self):
|
||||||
"""Code exchange behaves correctly and handles various error scenarios."""
|
"""Code exchange behaves correctly and handles various error scenarios."""
|
||||||
token = {"type": "bearer"}
|
token = {"type": "bearer"}
|
||||||
|
@ -524,7 +524,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
|
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
|
||||||
)
|
)
|
||||||
code = "code"
|
code = "code"
|
||||||
ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
|
ret = self.get_success(self.handler._exchange_code(code))
|
||||||
kwargs = self.http_client.request.call_args[1]
|
kwargs = self.http_client.request.call_args[1]
|
||||||
|
|
||||||
self.assertEqual(ret, token)
|
self.assertEqual(ret, token)
|
||||||
|
@ -546,10 +546,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
body=b'{"error": "foo", "error_description": "bar"}',
|
body=b'{"error": "foo", "error_description": "bar"}',
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
with self.assertRaises(OidcError) as exc:
|
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||||
yield defer.ensureDeferred(self.handler._exchange_code(code))
|
self.assertEqual(exc.value.error, "foo")
|
||||||
self.assertEqual(exc.exception.error, "foo")
|
self.assertEqual(exc.value.error_description, "bar")
|
||||||
self.assertEqual(exc.exception.error_description, "bar")
|
|
||||||
|
|
||||||
# Internal server error with no JSON body
|
# Internal server error with no JSON body
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = simple_async_mock(
|
||||||
|
@ -557,9 +556,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
|
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
with self.assertRaises(OidcError) as exc:
|
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||||
yield defer.ensureDeferred(self.handler._exchange_code(code))
|
self.assertEqual(exc.value.error, "server_error")
|
||||||
self.assertEqual(exc.exception.error, "server_error")
|
|
||||||
|
|
||||||
# Internal server error with JSON body
|
# Internal server error with JSON body
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = simple_async_mock(
|
||||||
|
@ -569,17 +567,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
body=b'{"error": "internal_server_error"}',
|
body=b'{"error": "internal_server_error"}',
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
with self.assertRaises(OidcError) as exc:
|
|
||||||
yield defer.ensureDeferred(self.handler._exchange_code(code))
|
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||||
self.assertEqual(exc.exception.error, "internal_server_error")
|
self.assertEqual(exc.value.error, "internal_server_error")
|
||||||
|
|
||||||
# 4xx error without "error" field
|
# 4xx error without "error" field
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = simple_async_mock(
|
||||||
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
|
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
|
||||||
)
|
)
|
||||||
with self.assertRaises(OidcError) as exc:
|
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||||
yield defer.ensureDeferred(self.handler._exchange_code(code))
|
self.assertEqual(exc.value.error, "server_error")
|
||||||
self.assertEqual(exc.exception.error, "server_error")
|
|
||||||
|
|
||||||
# 2xx error with "error" field
|
# 2xx error with "error" field
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = simple_async_mock(
|
||||||
|
@ -587,9 +584,62 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
|
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
with self.assertRaises(OidcError) as exc:
|
exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
||||||
yield defer.ensureDeferred(self.handler._exchange_code(code))
|
self.assertEqual(exc.value.error, "some_error")
|
||||||
self.assertEqual(exc.exception.error, "some_error")
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"oidc_config": {
|
||||||
|
"user_mapping_provider": {
|
||||||
|
"module": __name__ + ".TestMappingProviderExtra"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_extra_attributes(self):
|
||||||
|
"""
|
||||||
|
Login while using a mapping provider that implements get_extra_attributes.
|
||||||
|
"""
|
||||||
|
token = {
|
||||||
|
"type": "bearer",
|
||||||
|
"id_token": "id_token",
|
||||||
|
"access_token": "access_token",
|
||||||
|
}
|
||||||
|
userinfo = {
|
||||||
|
"sub": "foo",
|
||||||
|
"phone": "1234567",
|
||||||
|
}
|
||||||
|
user_id = "@foo:domain.org"
|
||||||
|
self.handler._exchange_code = simple_async_mock(return_value=token)
|
||||||
|
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||||
|
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
|
||||||
|
self.handler._auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
request = Mock(
|
||||||
|
spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
|
||||||
|
)
|
||||||
|
|
||||||
|
state = "state"
|
||||||
|
client_redirect_url = "http://client/redirect"
|
||||||
|
request.getCookie.return_value = self.handler._generate_oidc_session_token(
|
||||||
|
state=state,
|
||||||
|
nonce="nonce",
|
||||||
|
client_redirect_url=client_redirect_url,
|
||||||
|
ui_auth_session_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
request.args = {}
|
||||||
|
request.args[b"code"] = [b"code"]
|
||||||
|
request.args[b"state"] = [state.encode("utf-8")]
|
||||||
|
|
||||||
|
request.requestHeaders = Mock(spec=["getRawHeaders"])
|
||||||
|
request.requestHeaders.getRawHeaders.return_value = [b"Browser"]
|
||||||
|
request.getClientIP.return_value = "10.0.0.1"
|
||||||
|
|
||||||
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
|
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
|
user_id, request, client_redirect_url, {"phone": "1234567"},
|
||||||
|
)
|
||||||
|
|
||||||
def test_map_userinfo_to_user(self):
|
def test_map_userinfo_to_user(self):
|
||||||
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
||||||
|
|
Loading…
Reference in New Issue