Add soft logout possiility for OIDC ackchannel_logout

This commit is contained in:
Mohamed Hachem OUERTANI 2023-07-21 17:59:25 +04:00 committed by Yadd
parent 614efc488b
commit 671fc2d81b
7 changed files with 311 additions and 7 deletions

View File

@ -0,0 +1 @@
Allow setting OIDC backchannel logout to be a soft-logout in OIDC provider configuration via `backchannel_logout_is_soft` which defaults to false. Contributed by @hachem2001.

View File

@ -3335,6 +3335,12 @@ Options for each entry include:
You might want to disable this if the `subject_claim` returned by the mapping provider is not `sub`.
* `backchannel_logout_is_soft`: by default all OIDC Back-Channel Logouts correspond to hard logouts on
the server side. This may not leave users the ability to recover their encryption keys before being logged-out.
This can be set to `true` to treat all OIDC Back-Channel logouts as soft-logouts,
allowing users to reconnect to the same device if necessary to recover their keys.
Defaults to `false`.
It is possible to configure Synapse to only allow logins if certain attributes
match particular values in the OIDC userinfo. The requirements can be listed under
`attribute_requirements` as shown here:

View File

@ -126,6 +126,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"skip_verification": {"type": "boolean"},
"backchannel_logout_enabled": {"type": "boolean"},
"backchannel_logout_ignore_sub": {"type": "boolean"},
"backchannel_logout_is_soft": {"type": "boolean"},
"user_profile_method": {
"type": "string",
"enum": ["auto", "userinfo_endpoint"],
@ -301,6 +302,7 @@ def _parse_oidc_config_dict(
backchannel_logout_ignore_sub=oidc_config.get(
"backchannel_logout_ignore_sub", False
),
backchannel_logout_is_soft=oidc_config.get("backchannel_logout_is_soft", False),
skip_verification=oidc_config.get("skip_verification", False),
user_profile_method=oidc_config.get("user_profile_method", "auto"),
allow_existing_users=oidc_config.get("allow_existing_users", False),
@ -388,6 +390,9 @@ class OidcProviderConfig:
# Whether Synapse should ignore the `sub` claim in backchannel logouts or not.
backchannel_logout_ignore_sub: bool
# Whether Synapse should consider backchannel logouts as soft-logouts. Default false
backchannel_logout_is_soft: bool
# Whether to skip metadata verification
skip_verification: bool

View File

@ -1340,19 +1340,44 @@ class OidcProvider:
self.idp_id, sub
)
# Invalidate any running user-mapping sessions, in-flight login tokens and
# active devices
await self._sso_handler.revoke_sessions_for_provider_session_id(
auth_provider_id=self.idp_id,
auth_provider_session_id=sid,
expected_user_id=expected_user_id,
)
# Back-Channel Logout can be set to only soft-logout users in the config, hence
# this check. The aim of the config is not to surprise the user with a sudden
# hard logout deleting his devices and keys in the process, thus not allowing
# him to set a recovery method/recover keys/... as a result of this back-channel
# logout.
if self._config.backchannel_logout_is_soft:
await self._handle_backchannel_soft_logout(request, sid, expected_user_id)
else:
# Invalidate any running user-mapping sessions, in-flight login tokens and
# active devices
await self._sso_handler.revoke_sessions_for_provider_session_id(
auth_provider_id=self.idp_id,
auth_provider_session_id=sid,
expected_user_id=expected_user_id,
)
request.setResponseCode(200)
request.setHeader(b"Cache-Control", b"no-cache, no-store")
request.setHeader(b"Pragma", b"no-cache")
finish_request(request)
async def _handle_backchannel_soft_logout(
self, request: SynapseRequest, sid: str, expected_user_id: Optional[str] = None
) -> None:
"""Helper function called when handling an incoming request to
/_synapse/client/oidc/backchannel_logout
ONLY when OIDC is set with parameter backchannel_logout_is_soft:true
Makes a soft_logout on all the user's tokens. (Does not delete devices)
"""
# Invalidate any running user-mapping sessions, in-flight login tokens and
# active devices
await self._sso_handler.invalidate_sessions_for_provider_session_id(
auth_provider_id=self.idp_id,
auth_provider_session_id=sid,
expected_user_id=expected_user_id,
)
class LogoutToken(JWTClaims): # type: ignore[misc]
"""

View File

@ -1232,6 +1232,91 @@ class SsoHandler:
)
await self._device_handler.delete_devices(user_id, [device_id])
async def invalidate_sessions_for_provider_session_id(
self,
auth_provider_id: str,
auth_provider_session_id: str,
expected_user_id: Optional[str] = None,
) -> None:
"""Invalidates all access tokens and in-flight login tokens tied to a provider
session.
This causes them to be soft-logged out.
Can only be called from the main process.
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
auth_provider_session_id: The session ID from the provider to logout
expected_user_id: The user we're expecting to logout. If set, it will ignore
sessions belonging to other users and log an error.
"""
# It is expected that this is the main process.
assert isinstance(
self._device_handler, DeviceHandler
), "invalidating SSO sessions can only be called on the main process"
# Invalidate any running user-mapping sessions
to_delete = []
for session_id, session in self._username_mapping_sessions.items():
if (
session.auth_provider_id == auth_provider_id
and session.auth_provider_session_id == auth_provider_session_id
):
to_delete.append(session_id)
for session_id in to_delete:
logger.info("Revoking mapping session %s", session_id)
del self._username_mapping_sessions[session_id]
# Invalidate any in-flight login tokens
await self._store.invalidate_login_tokens_by_session_id(
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
# Fetch any device(s) in the store associated with the session ID.
devices = await self._store.get_devices_by_auth_provider_session_id(
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
# Iterate over the list of devices and soft-log them out one by one.
for device in devices:
user_id = device["user_id"]
device_id = device["device_id"]
# If the user_id associated with that device/session is not the one we got
# out of the `sub` claim, skip that device and show log an error.
if expected_user_id is not None and user_id != expected_user_id:
logger.error(
"Received a (soft) logout notification from SSO provider "
f"{auth_provider_id!r} for the user {expected_user_id!r}, but with "
f"a session ID ({auth_provider_session_id!r}) which belongs to "
f"{user_id!r}. This may happen when the SSO provider user mapper "
"uses something else than the standard attribute as mapping ID. "
"For OIDC providers, set `backchannel_logout_ignore_sub` to `true` "
"in the provider config if that is the case."
)
continue
logger.info(
"Soft-logging out %r (device %r) via SSO (%r) soft-logout notification (session %r).",
user_id,
device_id,
auth_provider_id,
auth_provider_session_id,
)
# Invalidate all tokens of user_id associated with device_id
await self._store.user_set_account_tokens_validity(
user_id,
validity_until_ms=self._clock.time_msec(),
except_token_id=None,
device_id=device_id,
)
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie

View File

@ -2616,6 +2616,67 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
async def user_set_account_tokens_validity(
self,
user_id: str,
validity_until_ms: int = 0,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
Set access tokens' validi_until_ms belonging to a user
Sets the same value for all of the concerned tokens
Args:
user_id: ID of user the tokens belong to
validity_until_ms: New validity_until value for all considered tokens
except_token_id: access_tokens ID which should *not* be updated
device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be updated
Returns:
A tuple of (token, token id, device id) for each of the updated tokens
"""
assert validity_until_ms >= 0
def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values: List[Union[str, int]] = [v for _, v in items]
values.copy()
if except_token_id:
where_clause += " AND id != ?"
values.append(except_token_id)
txn.execute(
"SELECT token, id, device_id FROM access_tokens WHERE %s"
% where_clause,
values,
)
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
for token, token_id, _ in tokens_and_devices:
self.db_pool.simple_update_txn(
txn,
table="access_tokens",
keyvalues={"id": token_id},
updatevalues={"valid_until_ms": validity_until_ms},
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (token,)
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return tokens_and_devices
return await self.db_pool.runInteraction("user_set_account_tokens_validity", f)
async def delete_access_token(self, access_token: str) -> None:
def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn(

View File

@ -1276,6 +1276,59 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
@override_config(
{
"oidc_providers": [
oidc_config(
id="oidc",
with_localpart_template=True,
backchannel_logout_enabled=True,
backchannel_logout_is_soft=True,
)
]
}
)
def test_simple_logout_is_soft(self) -> None:
"""
Soft-logout on back-channel option being enabled,
receiving a logout token should soft-logout the user
"""
fake_oidc_server = self.helper.fake_oidc_server()
user = "john"
login_resp, first_grant = self.helper.login_via_oidc(
fake_oidc_server, user, with_sid=True
)
first_access_token: str = login_resp["access_token"]
self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
login_resp, second_grant = self.helper.login_via_oidc(
fake_oidc_server, user, with_sid=True
)
second_access_token: str = login_resp["access_token"]
self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
self.assertNotEqual(first_grant.sid, second_grant.sid)
self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
# Soft-logging out of the first session
logout_token = fake_oidc_server.generate_logout_token(first_grant)
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
first_attempt_json_body = self.helper.whoami(
first_access_token, expect_code=HTTPStatus.UNAUTHORIZED
)
self.assertEqual(first_attempt_json_body["soft_logout"], True)
self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
# Soft-logging out of the second session
logout_token = fake_oidc_server.generate_logout_token(second_grant)
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
# Soft-logout option does not change behaviour during mapping or login
@override_config(
{
"oidc_providers": [
@ -1500,3 +1553,71 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
@override_config(
{
"oidc_providers": [
oidc_config(
"first",
issuer="https://first-issuer.com/",
with_localpart_template=True,
backchannel_logout_enabled=True,
backchannel_logout_is_soft=True,
),
oidc_config(
"second",
issuer="https://second-issuer.com/",
with_localpart_template=True,
backchannel_logout_enabled=True,
backchannel_logout_is_soft=True,
),
]
}
)
def test_multiple_providers_is_soft(self) -> None:
"""
It should be able to distinguish login tokens from two different IdPs
"""
first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/")
second_server = self.helper.fake_oidc_server(
issuer="https://second-issuer.com/"
)
user = "john"
login_resp, first_grant = self.helper.login_via_oidc(
first_server, user, with_sid=True, idp_id="oidc-first"
)
first_access_token: str = login_resp["access_token"]
self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
login_resp, second_grant = self.helper.login_via_oidc(
second_server, user, with_sid=True, idp_id="oidc-second"
)
second_access_token: str = login_resp["access_token"]
self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
# `sid` in the fake providers are generated by a counter, so the first grant of
# each provider should give the same SID
self.assertEqual(first_grant.sid, second_grant.sid)
self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
# Logging out of the first session
logout_token = first_server.generate_logout_token(first_grant)
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
first_attempt_json_body = self.helper.whoami(
first_access_token, expect_code=HTTPStatus.UNAUTHORIZED
)
self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
self.assertEqual(first_attempt_json_body["soft_logout"], True)
# Logging out of the second session
logout_token = second_server.generate_logout_token(second_grant)
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
second_attempt_json_body = self.helper.whoami(
second_access_token, expect_code=HTTPStatus.UNAUTHORIZED
)
self.assertEqual(second_attempt_json_body["soft_logout"], True)