Initialise user displayname from SAML2 data (#4272)
When we register a new user from SAML2 data, initialise their displayname correctly.
This commit is contained in:
parent
35e13477cf
commit
30da50a5b8
|
@ -0,0 +1 @@
|
||||||
|
SAML2 authentication: Initialise user display name from SAML2 data
|
|
@ -126,6 +126,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
make_guest=False,
|
make_guest=False,
|
||||||
admin=False,
|
admin=False,
|
||||||
threepid=None,
|
threepid=None,
|
||||||
|
default_display_name=None,
|
||||||
):
|
):
|
||||||
"""Registers a new client on the server.
|
"""Registers a new client on the server.
|
||||||
|
|
||||||
|
@ -140,6 +141,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
since it offers no means of associating a device_id with the
|
since it offers no means of associating a device_id with the
|
||||||
access_token. Instead you should call auth_handler.issue_access_token
|
access_token. Instead you should call auth_handler.issue_access_token
|
||||||
after registration.
|
after registration.
|
||||||
|
default_display_name (unicode|None): if set, the new user's displayname
|
||||||
|
will be set to this. Defaults to 'localpart'.
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (user_id, access_token).
|
A tuple of (user_id, access_token).
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -169,6 +172,13 @@ class RegistrationHandler(BaseHandler):
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
|
if was_guest:
|
||||||
|
# If the user was a guest then they already have a profile
|
||||||
|
default_display_name = None
|
||||||
|
|
||||||
|
elif default_display_name is None:
|
||||||
|
default_display_name = localpart
|
||||||
|
|
||||||
token = None
|
token = None
|
||||||
if generate_token:
|
if generate_token:
|
||||||
token = self.macaroon_gen.generate_access_token(user_id)
|
token = self.macaroon_gen.generate_access_token(user_id)
|
||||||
|
@ -178,10 +188,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
was_guest=was_guest,
|
was_guest=was_guest,
|
||||||
make_guest=make_guest,
|
make_guest=make_guest,
|
||||||
create_profile_with_localpart=(
|
create_profile_with_displayname=default_display_name,
|
||||||
# If the user was a guest then they already have a profile
|
|
||||||
None if was_guest else user.localpart
|
|
||||||
),
|
|
||||||
admin=admin,
|
admin=admin,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -203,13 +210,15 @@ class RegistrationHandler(BaseHandler):
|
||||||
yield self.check_user_id_not_appservice_exclusive(user_id)
|
yield self.check_user_id_not_appservice_exclusive(user_id)
|
||||||
if generate_token:
|
if generate_token:
|
||||||
token = self.macaroon_gen.generate_access_token(user_id)
|
token = self.macaroon_gen.generate_access_token(user_id)
|
||||||
|
if default_display_name is None:
|
||||||
|
default_display_name = localpart
|
||||||
try:
|
try:
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
make_guest=make_guest,
|
make_guest=make_guest,
|
||||||
create_profile_with_localpart=user.localpart,
|
create_profile_with_displayname=default_display_name,
|
||||||
)
|
)
|
||||||
except SynapseError:
|
except SynapseError:
|
||||||
# if user id is taken, just generate another
|
# if user id is taken, just generate another
|
||||||
|
@ -300,7 +309,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password_hash="",
|
password_hash="",
|
||||||
appservice_id=service_id,
|
appservice_id=service_id,
|
||||||
create_profile_with_localpart=user.localpart,
|
create_profile_with_displayname=user.localpart,
|
||||||
)
|
)
|
||||||
defer.returnValue(user_id)
|
defer.returnValue(user_id)
|
||||||
|
|
||||||
|
@ -478,7 +487,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
create_profile_with_localpart=user.localpart,
|
create_profile_with_displayname=user.localpart,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield self._auth_handler.delete_access_tokens_for_user(user_id)
|
yield self._auth_handler.delete_access_tokens_for_user(user_id)
|
||||||
|
|
|
@ -451,6 +451,7 @@ class SSOAuthHandler(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_successful_auth(
|
def on_successful_auth(
|
||||||
self, username, request, client_redirect_url,
|
self, username, request, client_redirect_url,
|
||||||
|
user_display_name=None,
|
||||||
):
|
):
|
||||||
"""Called once the user has successfully authenticated with the SSO.
|
"""Called once the user has successfully authenticated with the SSO.
|
||||||
|
|
||||||
|
@ -467,6 +468,9 @@ class SSOAuthHandler(object):
|
||||||
client_redirect_url (unicode): the redirect_url the client gave us when
|
client_redirect_url (unicode): the redirect_url the client gave us when
|
||||||
it first started the process.
|
it first started the process.
|
||||||
|
|
||||||
|
user_display_name (unicode|None): if set, and we have to register a new user,
|
||||||
|
we will set their displayname to this.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[none]: Completes once we have handled the request.
|
Deferred[none]: Completes once we have handled the request.
|
||||||
"""
|
"""
|
||||||
|
@ -478,6 +482,7 @@ class SSOAuthHandler(object):
|
||||||
yield self._registration_handler.register(
|
yield self._registration_handler.register(
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
generate_token=False,
|
generate_token=False,
|
||||||
|
default_display_name=user_display_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -66,6 +66,9 @@ class SAML2ResponseResource(Resource):
|
||||||
raise CodeMessageException(400, "uid not in SAML2 response")
|
raise CodeMessageException(400, "uid not in SAML2 response")
|
||||||
|
|
||||||
username = saml2_auth.ava["uid"][0]
|
username = saml2_auth.ava["uid"][0]
|
||||||
|
|
||||||
|
displayName = saml2_auth.ava.get("displayName", [None])[0]
|
||||||
return self._sso_auth_handler.on_successful_auth(
|
return self._sso_auth_handler.on_successful_auth(
|
||||||
username, request, relay_state,
|
username, request, relay_state,
|
||||||
|
user_display_name=displayName,
|
||||||
)
|
)
|
||||||
|
|
|
@ -22,6 +22,7 @@ from twisted.internet import defer
|
||||||
from synapse.api.errors import Codes, StoreError
|
from synapse.api.errors import Codes, StoreError
|
||||||
from synapse.storage import background_updates
|
from synapse.storage import background_updates
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
from synapse.types import UserID
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
|
|
||||||
|
|
||||||
|
@ -167,7 +168,7 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||||
|
|
||||||
def register(self, user_id, token=None, password_hash=None,
|
def register(self, user_id, token=None, password_hash=None,
|
||||||
was_guest=False, make_guest=False, appservice_id=None,
|
was_guest=False, make_guest=False, appservice_id=None,
|
||||||
create_profile_with_localpart=None, admin=False):
|
create_profile_with_displayname=None, admin=False):
|
||||||
"""Attempts to register an account.
|
"""Attempts to register an account.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -181,8 +182,8 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||||
make_guest (boolean): True if the the new user should be guest,
|
make_guest (boolean): True if the the new user should be guest,
|
||||||
false to add a regular user account.
|
false to add a regular user account.
|
||||||
appservice_id (str): The ID of the appservice registering the user.
|
appservice_id (str): The ID of the appservice registering the user.
|
||||||
create_profile_with_localpart (str): Optionally create a profile for
|
create_profile_with_displayname (unicode): Optionally create a profile for
|
||||||
the given localpart.
|
the user, setting their displayname to the given value
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if the user_id could not be registered.
|
StoreError if the user_id could not be registered.
|
||||||
"""
|
"""
|
||||||
|
@ -195,7 +196,7 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||||
was_guest,
|
was_guest,
|
||||||
make_guest,
|
make_guest,
|
||||||
appservice_id,
|
appservice_id,
|
||||||
create_profile_with_localpart,
|
create_profile_with_displayname,
|
||||||
admin
|
admin
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -208,9 +209,11 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||||
was_guest,
|
was_guest,
|
||||||
make_guest,
|
make_guest,
|
||||||
appservice_id,
|
appservice_id,
|
||||||
create_profile_with_localpart,
|
create_profile_with_displayname,
|
||||||
admin,
|
admin,
|
||||||
):
|
):
|
||||||
|
user_id_obj = UserID.from_string(user_id)
|
||||||
|
|
||||||
now = int(self.clock.time())
|
now = int(self.clock.time())
|
||||||
|
|
||||||
next_id = self._access_tokens_id_gen.get_next()
|
next_id = self._access_tokens_id_gen.get_next()
|
||||||
|
@ -273,12 +276,15 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||||
(next_id, user_id, token,)
|
(next_id, user_id, token,)
|
||||||
)
|
)
|
||||||
|
|
||||||
if create_profile_with_localpart:
|
if create_profile_with_displayname:
|
||||||
# set a default displayname serverside to avoid ugly race
|
# set a default displayname serverside to avoid ugly race
|
||||||
# between auto-joins and clients trying to set displaynames
|
# between auto-joins and clients trying to set displaynames
|
||||||
|
#
|
||||||
|
# *obviously* the 'profiles' table uses localpart for user_id
|
||||||
|
# while everything else uses the full mxid.
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
|
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
|
||||||
(create_profile_with_localpart, create_profile_with_localpart)
|
(user_id_obj.localpart, create_profile_with_displayname)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._invalidate_cache_and_stream(
|
self._invalidate_cache_and_stream(
|
||||||
|
|
|
@ -149,7 +149,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
def test_populate_monthly_users_is_guest(self):
|
def test_populate_monthly_users_is_guest(self):
|
||||||
# Test that guest users are not added to mau list
|
# Test that guest users are not added to mau list
|
||||||
user_id = "user_id"
|
user_id = "@user_id:host"
|
||||||
self.store.register(
|
self.store.register(
|
||||||
user_id=user_id, token="123", password_hash=None, make_guest=True
|
user_id=user_id, token="123", password_hash=None, make_guest=True
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue