Added display_name_claim in jwt_config which sets the user's display name upon registration (#17708)

This commit is contained in:
Nathan 2024-10-09 14:21:08 +02:00 committed by GitHub
parent 60aebdb27e
commit 05576f0b4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 50 additions and 6 deletions

View File

@ -0,0 +1 @@
Added the `display_name_claim` option to the JWT configuration. This option allows specifying the claim key that contains the user's display name in the JWT payload.

View File

@ -3722,6 +3722,8 @@ Additional sub-options for this setting include:
Required if `enabled` is set to true.
* `subject_claim`: Name of the claim containing a unique identifier for the user.
Optional, defaults to `sub`.
* `display_name_claim`: Name of the claim containing the display name for the user. Optional.
If provided, the display name will be set to the value of this claim upon first login.
* `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the
"iss" claim will be required and validated for all JSON web tokens.
* `audiences`: A list of audiences to validate the "aud" claim against. Optional.
@ -3736,6 +3738,7 @@ jwt_config:
secret: "provided-by-your-issuer"
algorithm: "provided-by-your-issuer"
subject_claim: "name_of_claim"
display_name_claim: "name_of_claim"
issuer: "provided-by-your-issuer"
audiences:
- "provided-by-your-issuer"

View File

@ -38,6 +38,7 @@ class JWTConfig(Config):
self.jwt_algorithm = jwt_config["algorithm"]
self.jwt_subject_claim = jwt_config.get("subject_claim", "sub")
self.jwt_display_name_claim = jwt_config.get("display_name_claim")
# The issuer and audiences are optional, if provided, it is asserted
# that the claims exist on the JWT.
@ -49,5 +50,6 @@ class JWTConfig(Config):
self.jwt_secret = None
self.jwt_algorithm = None
self.jwt_subject_claim = None
self.jwt_display_name_claim = None
self.jwt_issuer = None
self.jwt_audiences = None

View File

@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple
from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
@ -36,11 +36,12 @@ class JwtHandler:
self.jwt_secret = hs.config.jwt.jwt_secret
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
self.jwt_display_name_claim = hs.config.jwt.jwt_display_name_claim
self.jwt_algorithm = hs.config.jwt.jwt_algorithm
self.jwt_issuer = hs.config.jwt.jwt_issuer
self.jwt_audiences = hs.config.jwt.jwt_audiences
def validate_login(self, login_submission: JsonDict) -> str:
def validate_login(self, login_submission: JsonDict) -> Tuple[str, Optional[str]]:
"""
Authenticates the user for the /login API
@ -49,7 +50,8 @@ class JwtHandler:
(including 'type' and other relevant fields)
Returns:
The user ID that is logging in.
A tuple of (user_id, display_name) of the user that is logging in.
If the JWT does not contain a display name, the second element of the tuple will be None.
Raises:
LoginError if there was an authentication problem.
@ -109,4 +111,10 @@ class JwtHandler:
if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
return UserID(user, self.hs.hostname).to_string()
default_display_name = None
if self.jwt_display_name_claim:
display_name_claim = claims.get(self.jwt_display_name_claim)
if display_name_claim is not None:
default_display_name = display_name_claim
return UserID(user, self.hs.hostname).to_string(), default_display_name

View File

@ -363,6 +363,7 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
default_display_name: Optional[str] = None,
ratelimit: bool = True,
auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False,
@ -410,7 +411,8 @@ class LoginRestServlet(RestServlet):
canonical_uid = await self.auth_handler.check_user_exists(user_id)
if not canonical_uid:
canonical_uid = await self.registration_handler.register_user(
localpart=UserID.from_string(user_id).localpart
localpart=UserID.from_string(user_id).localpart,
default_display_name=default_display_name,
)
user_id = canonical_uid
@ -546,11 +548,14 @@ class LoginRestServlet(RestServlet):
Returns:
The body of the JSON response.
"""
user_id = self.hs.get_jwt_handler().validate_login(login_submission)
user_id, default_display_name = self.hs.get_jwt_handler().validate_login(
login_submission
)
return await self._complete_login(
user_id,
login_submission,
create_non_existent_users=True,
default_display_name=default_display_name,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)

View File

@ -1047,6 +1047,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
profile.register_servlets,
]
jwt_secret = "secret"
@ -1202,6 +1203,30 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
@override_config(
{"jwt_config": {**base_config, "display_name_claim": "display_name"}}
)
def test_login_custom_display_name(self) -> None:
"""Test setting a custom display name."""
localpart = "pinkie"
user_id = f"@{localpart}:test"
display_name = "Pinkie Pie"
# Perform the login, specifying a custom display name.
channel = self.jwt_login({"sub": localpart, "display_name": display_name})
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], user_id)
# Fetch the user's display name and check that it was set correctly.
access_token = channel.json_body["access_token"]
channel = self.make_request(
"GET",
f"/_matrix/client/v3/profile/{user_id}/displayname",
access_token=access_token,
)
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["displayname"], display_name)
def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)