Properly report user-agent/IP during registration of SSO users. (#8784)
This also expands type-hints to the SSO and registration code. Refactors the CAS code to more closely match OIDC/SAML.
This commit is contained in:
parent
7127855741
commit
6fde6aa9c0
|
@ -0,0 +1 @@
|
||||||
|
Fix a bug introduced in v1.20.0 where the user-agent and IP address reported during user registration for CAS, OpenID Connect, and SAML were of the wrong form.
|
1
mypy.ini
1
mypy.ini
|
@ -37,6 +37,7 @@ files =
|
||||||
synapse/handlers/presence.py,
|
synapse/handlers/presence.py,
|
||||||
synapse/handlers/profile.py,
|
synapse/handlers/profile.py,
|
||||||
synapse/handlers/read_marker.py,
|
synapse/handlers/read_marker.py,
|
||||||
|
synapse/handlers/register.py,
|
||||||
synapse/handlers/room.py,
|
synapse/handlers/room.py,
|
||||||
synapse/handlers/room_member.py,
|
synapse/handlers/room_member.py,
|
||||||
synapse/handlers/room_member_worker.py,
|
synapse/handlers/room_member_worker.py,
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||||
from xml.etree import ElementTree as ET
|
from xml.etree import ElementTree as ET
|
||||||
|
|
||||||
from twisted.web.client import PartialDownloadError
|
from twisted.web.client import PartialDownloadError
|
||||||
|
@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,10 +34,10 @@ class CasHandler:
|
||||||
Utility class for to handle the response from a CAS SSO service.
|
Utility class for to handle the response from a CAS SSO service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hs (synapse.server.HomeServer)
|
hs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self._hostname = hs.hostname
|
self._hostname = hs.hostname
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
@ -200,27 +203,57 @@ class CasHandler:
|
||||||
args["session"] = session
|
args["session"] = session
|
||||||
username, user_display_name = await self._validate_ticket(ticket, args)
|
username, user_display_name = await self._validate_ticket(ticket, args)
|
||||||
|
|
||||||
localpart = map_username_to_mxid_localpart(username)
|
|
||||||
user_id = UserID(localpart, self._hostname).to_string()
|
|
||||||
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
|
||||||
|
|
||||||
if session:
|
|
||||||
await self._auth_handler.complete_sso_ui_auth(
|
|
||||||
registered_user_id, session, request,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
if not registered_user_id:
|
|
||||||
# Pull out the user-agent and IP from the request.
|
# Pull out the user-agent and IP from the request.
|
||||||
user_agent = request.get_user_agent("")
|
user_agent = request.get_user_agent("")
|
||||||
ip_address = self.hs.get_ip_from_request(request)
|
ip_address = self.hs.get_ip_from_request(request)
|
||||||
|
|
||||||
registered_user_id = await self._registration_handler.register_user(
|
# Get the matrix ID from the CAS username.
|
||||||
localpart=localpart,
|
user_id = await self._map_cas_user_to_matrix_user(
|
||||||
default_display_name=user_display_name,
|
username, user_display_name, user_agent, ip_address
|
||||||
user_agent_ips=(user_agent, ip_address),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._auth_handler.complete_sso_login(
|
if session:
|
||||||
registered_user_id, request, client_redirect_url
|
await self._auth_handler.complete_sso_ui_auth(
|
||||||
|
user_id, session, request,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# If this not a UI auth request than there must be a redirect URL.
|
||||||
|
assert client_redirect_url
|
||||||
|
|
||||||
|
await self._auth_handler.complete_sso_login(
|
||||||
|
user_id, request, client_redirect_url
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _map_cas_user_to_matrix_user(
|
||||||
|
self,
|
||||||
|
remote_user_id: str,
|
||||||
|
display_name: Optional[str],
|
||||||
|
user_agent: str,
|
||||||
|
ip_address: str,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Given a CAS username, retrieve the user ID for it and possibly register the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote_user_id: The username from the CAS response.
|
||||||
|
display_name: The display name from the CAS response.
|
||||||
|
user_agent: The user agent of the client making the request.
|
||||||
|
ip_address: The IP address of the client making the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The user ID associated with this response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
localpart = map_username_to_mxid_localpart(remote_user_id)
|
||||||
|
user_id = UserID(localpart, self._hostname).to_string()
|
||||||
|
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
||||||
|
|
||||||
|
# If the user does not exist, register it.
|
||||||
|
if not registered_user_id:
|
||||||
|
registered_user_id = await self._registration_handler.register_user(
|
||||||
|
localpart=localpart,
|
||||||
|
default_display_name=display_name,
|
||||||
|
user_agent_ips=[(user_agent, ip_address)],
|
||||||
|
)
|
||||||
|
|
||||||
|
return registered_user_id
|
||||||
|
|
|
@ -925,7 +925,7 @@ class OidcHandler(BaseHandler):
|
||||||
registered_user_id = await self._registration_handler.register_user(
|
registered_user_id = await self._registration_handler.register_user(
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
default_display_name=attributes["display_name"],
|
default_display_name=attributes["display_name"],
|
||||||
user_agent_ips=(user_agent, ip_address),
|
user_agent_ips=[(user_agent, ip_address)],
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.store.record_user_external_id(
|
await self.store.record_user_external_id(
|
||||||
|
|
|
@ -15,10 +15,12 @@
|
||||||
|
|
||||||
"""Contains functions for registering clients."""
|
"""Contains functions for registering clients."""
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
from synapse import types
|
from synapse import types
|
||||||
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
|
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
|
||||||
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
|
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
|
||||||
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.config.server import is_threepid_reserved
|
from synapse.config.server import is_threepid_reserved
|
||||||
from synapse.http.servlet import assert_params_in_dict
|
from synapse.http.servlet import assert_params_in_dict
|
||||||
from synapse.replication.http.login import RegisterDeviceReplicationServlet
|
from synapse.replication.http.login import RegisterDeviceReplicationServlet
|
||||||
|
@ -32,16 +34,14 @@ from synapse.types import RoomAlias, UserID, create_requester
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RegistrationHandler(BaseHandler):
|
class RegistrationHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
"""
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hs (synapse.server.HomeServer):
|
|
||||||
"""
|
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
@ -71,7 +71,10 @@ class RegistrationHandler(BaseHandler):
|
||||||
self.session_lifetime = hs.config.session_lifetime
|
self.session_lifetime = hs.config.session_lifetime
|
||||||
|
|
||||||
async def check_username(
|
async def check_username(
|
||||||
self, localpart, guest_access_token=None, assigned_user_id=None
|
self,
|
||||||
|
localpart: str,
|
||||||
|
guest_access_token: Optional[str] = None,
|
||||||
|
assigned_user_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
if types.contains_invalid_mxid_characters(localpart):
|
if types.contains_invalid_mxid_characters(localpart):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
@ -140,39 +143,45 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
async def register_user(
|
async def register_user(
|
||||||
self,
|
self,
|
||||||
localpart=None,
|
localpart: Optional[str] = None,
|
||||||
password_hash=None,
|
password_hash: Optional[str] = None,
|
||||||
guest_access_token=None,
|
guest_access_token: Optional[str] = None,
|
||||||
make_guest=False,
|
make_guest: bool = False,
|
||||||
admin=False,
|
admin: bool = False,
|
||||||
threepid=None,
|
threepid: Optional[dict] = None,
|
||||||
user_type=None,
|
user_type: Optional[str] = None,
|
||||||
default_display_name=None,
|
default_display_name: Optional[str] = None,
|
||||||
address=None,
|
address: Optional[str] = None,
|
||||||
bind_emails=[],
|
bind_emails: List[str] = [],
|
||||||
by_admin=False,
|
by_admin: bool = False,
|
||||||
user_agent_ips=None,
|
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
|
||||||
):
|
) -> str:
|
||||||
"""Registers a new client on the server.
|
"""Registers a new client on the server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
localpart: The local part of the user ID to register. If None,
|
localpart: The local part of the user ID to register. If None,
|
||||||
one will be generated.
|
one will be generated.
|
||||||
password_hash (str|None): The hashed password to assign to this user so they can
|
password_hash: The hashed password to assign to this user so they can
|
||||||
login again. This can be None which means they cannot login again
|
login again. This can be None which means they cannot login again
|
||||||
via a password (e.g. the user is an application service user).
|
via a password (e.g. the user is an application service user).
|
||||||
user_type (str|None): type of user. One of the values from
|
guest_access_token: The access token used when this was a guest
|
||||||
|
account.
|
||||||
|
make_guest: True if the the new user should be guest,
|
||||||
|
false to add a regular user account.
|
||||||
|
admin: True if the user should be registered as a server admin.
|
||||||
|
threepid: The threepid used for registering, if any.
|
||||||
|
user_type: type of user. One of the values from
|
||||||
api.constants.UserTypes, or None for a normal user.
|
api.constants.UserTypes, or None for a normal user.
|
||||||
default_display_name (unicode|None): if set, the new user's displayname
|
default_display_name: if set, the new user's displayname
|
||||||
will be set to this. Defaults to 'localpart'.
|
will be set to this. Defaults to 'localpart'.
|
||||||
address (str|None): the IP address used to perform the registration.
|
address: the IP address used to perform the registration.
|
||||||
bind_emails (List[str]): list of emails to bind to this account.
|
bind_emails: list of emails to bind to this account.
|
||||||
by_admin (bool): True if this registration is being made via the
|
by_admin: True if this registration is being made via the
|
||||||
admin api, otherwise False.
|
admin api, otherwise False.
|
||||||
user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
|
user_agent_ips: Tuples of IP addresses and user-agents used
|
||||||
during the registration process.
|
during the registration process.
|
||||||
Returns:
|
Returns:
|
||||||
str: user_id
|
The registere user_id.
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if there was a problem registering.
|
SynapseError if there was a problem registering.
|
||||||
"""
|
"""
|
||||||
|
@ -236,8 +245,10 @@ class RegistrationHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
# autogen a sequential user ID
|
# autogen a sequential user ID
|
||||||
fail_count = 0
|
fail_count = 0
|
||||||
user = None
|
# If a default display name is not given, generate one.
|
||||||
while not user:
|
generate_display_name = default_display_name is None
|
||||||
|
# This breaks on successful registration *or* errors after 10 failures.
|
||||||
|
while True:
|
||||||
# Fail after being unable to find a suitable ID a few times
|
# Fail after being unable to find a suitable ID a few times
|
||||||
if fail_count > 10:
|
if fail_count > 10:
|
||||||
raise SynapseError(500, "Unable to find a suitable guest user ID")
|
raise SynapseError(500, "Unable to find a suitable guest user ID")
|
||||||
|
@ -246,7 +257,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
self.check_user_id_not_appservice_exclusive(user_id)
|
self.check_user_id_not_appservice_exclusive(user_id)
|
||||||
if default_display_name is None:
|
if generate_display_name:
|
||||||
default_display_name = localpart
|
default_display_name = localpart
|
||||||
try:
|
try:
|
||||||
await self.register_with_store(
|
await self.register_with_store(
|
||||||
|
@ -262,8 +273,6 @@ class RegistrationHandler(BaseHandler):
|
||||||
break
|
break
|
||||||
except SynapseError:
|
except SynapseError:
|
||||||
# if user id is taken, just generate another
|
# if user id is taken, just generate another
|
||||||
user = None
|
|
||||||
user_id = None
|
|
||||||
fail_count += 1
|
fail_count += 1
|
||||||
|
|
||||||
if not self.hs.config.user_consent_at_registration:
|
if not self.hs.config.user_consent_at_registration:
|
||||||
|
@ -295,7 +304,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
async def _create_and_join_rooms(self, user_id: str):
|
async def _create_and_join_rooms(self, user_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Create the auto-join rooms and join or invite the user to them.
|
Create the auto-join rooms and join or invite the user to them.
|
||||||
|
|
||||||
|
@ -379,7 +388,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to join new user to %r: %r", r, e)
|
logger.error("Failed to join new user to %r: %r", r, e)
|
||||||
|
|
||||||
async def _join_rooms(self, user_id: str):
|
async def _join_rooms(self, user_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Join or invite the user to the auto-join rooms.
|
Join or invite the user to the auto-join rooms.
|
||||||
|
|
||||||
|
@ -425,6 +434,9 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
# Send the invite, if necessary.
|
# Send the invite, if necessary.
|
||||||
if requires_invite:
|
if requires_invite:
|
||||||
|
# If an invite is required, there must be a auto-join user ID.
|
||||||
|
assert self.hs.config.registration.auto_join_user_id
|
||||||
|
|
||||||
await room_member_handler.update_membership(
|
await room_member_handler.update_membership(
|
||||||
requester=create_requester(
|
requester=create_requester(
|
||||||
self.hs.config.registration.auto_join_user_id,
|
self.hs.config.registration.auto_join_user_id,
|
||||||
|
@ -456,7 +468,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to join new user to %r: %r", r, e)
|
logger.error("Failed to join new user to %r: %r", r, e)
|
||||||
|
|
||||||
async def _auto_join_rooms(self, user_id: str):
|
async def _auto_join_rooms(self, user_id: str) -> None:
|
||||||
"""Automatically joins users to auto join rooms - creating the room in the first place
|
"""Automatically joins users to auto join rooms - creating the room in the first place
|
||||||
if the user is the first to be created.
|
if the user is the first to be created.
|
||||||
|
|
||||||
|
@ -479,16 +491,16 @@ class RegistrationHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
await self._join_rooms(user_id)
|
await self._join_rooms(user_id)
|
||||||
|
|
||||||
async def post_consent_actions(self, user_id):
|
async def post_consent_actions(self, user_id: str) -> None:
|
||||||
"""A series of registration actions that can only be carried out once consent
|
"""A series of registration actions that can only be carried out once consent
|
||||||
has been granted
|
has been granted
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user to join
|
user_id: The user to join
|
||||||
"""
|
"""
|
||||||
await self._auto_join_rooms(user_id)
|
await self._auto_join_rooms(user_id)
|
||||||
|
|
||||||
async def appservice_register(self, user_localpart, as_token):
|
async def appservice_register(self, user_localpart: str, as_token: str) -> str:
|
||||||
user = UserID(user_localpart, self.hs.hostname)
|
user = UserID(user_localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
service = self.store.get_app_service_by_token(as_token)
|
service = self.store.get_app_service_by_token(as_token)
|
||||||
|
@ -513,7 +525,9 @@ class RegistrationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
|
def check_user_id_not_appservice_exclusive(
|
||||||
|
self, user_id: str, allowed_appservice: Optional[ApplicationService] = None
|
||||||
|
) -> None:
|
||||||
# don't allow people to register the server notices mxid
|
# don't allow people to register the server notices mxid
|
||||||
if self._server_notices_mxid is not None:
|
if self._server_notices_mxid is not None:
|
||||||
if user_id == self._server_notices_mxid:
|
if user_id == self._server_notices_mxid:
|
||||||
|
@ -537,12 +551,12 @@ class RegistrationHandler(BaseHandler):
|
||||||
errcode=Codes.EXCLUSIVE,
|
errcode=Codes.EXCLUSIVE,
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_registration_ratelimit(self, address):
|
def check_registration_ratelimit(self, address: Optional[str]) -> None:
|
||||||
"""A simple helper method to check whether the registration rate limit has been hit
|
"""A simple helper method to check whether the registration rate limit has been hit
|
||||||
for a given IP address
|
for a given IP address
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
address (str|None): the IP address used to perform the registration. If this is
|
address: the IP address used to perform the registration. If this is
|
||||||
None, no ratelimiting will be performed.
|
None, no ratelimiting will be performed.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -553,42 +567,39 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
self.ratelimiter.ratelimit(address)
|
self.ratelimiter.ratelimit(address)
|
||||||
|
|
||||||
def register_with_store(
|
async def register_with_store(
|
||||||
self,
|
self,
|
||||||
user_id,
|
user_id: str,
|
||||||
password_hash=None,
|
password_hash: Optional[str] = None,
|
||||||
was_guest=False,
|
was_guest: bool = False,
|
||||||
make_guest=False,
|
make_guest: bool = False,
|
||||||
appservice_id=None,
|
appservice_id: Optional[str] = None,
|
||||||
create_profile_with_displayname=None,
|
create_profile_with_displayname: Optional[str] = None,
|
||||||
admin=False,
|
admin: bool = False,
|
||||||
user_type=None,
|
user_type: Optional[str] = None,
|
||||||
address=None,
|
address: Optional[str] = None,
|
||||||
shadow_banned=False,
|
shadow_banned: bool = False,
|
||||||
):
|
) -> None:
|
||||||
"""Register user in the datastore.
|
"""Register user in the datastore.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The desired user ID to register.
|
user_id: The desired user ID to register.
|
||||||
password_hash (str|None): Optional. The password hash for this user.
|
password_hash: Optional. The password hash for this user.
|
||||||
was_guest (bool): Optional. Whether this is a guest account being
|
was_guest: Optional. Whether this is a guest account being
|
||||||
upgraded to a non-guest account.
|
upgraded to a non-guest account.
|
||||||
make_guest (boolean): True if the the new user should be guest,
|
make_guest: True if the the new user should be guest,
|
||||||
false to add a regular user account.
|
false to add a regular user account.
|
||||||
appservice_id (str|None): The ID of the appservice registering the user.
|
appservice_id: The ID of the appservice registering the user.
|
||||||
create_profile_with_displayname (unicode|None): Optionally create a
|
create_profile_with_displayname: Optionally create a
|
||||||
profile for the user, setting their displayname to the given value
|
profile for the user, setting their displayname to the given value
|
||||||
admin (boolean): is an admin user?
|
admin: is an admin user?
|
||||||
user_type (str|None): type of user. One of the values from
|
user_type: type of user. One of the values from
|
||||||
api.constants.UserTypes, or None for a normal user.
|
api.constants.UserTypes, or None for a normal user.
|
||||||
address (str|None): the IP address used to perform the registration.
|
address: the IP address used to perform the registration.
|
||||||
shadow_banned (bool): Whether to shadow-ban the user
|
shadow_banned: Whether to shadow-ban the user
|
||||||
|
|
||||||
Returns:
|
|
||||||
Awaitable
|
|
||||||
"""
|
"""
|
||||||
if self.hs.config.worker_app:
|
if self.hs.config.worker_app:
|
||||||
return self._register_client(
|
await self._register_client(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
was_guest=was_guest,
|
was_guest=was_guest,
|
||||||
|
@ -601,7 +612,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
shadow_banned=shadow_banned,
|
shadow_banned=shadow_banned,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.store.register_user(
|
await self.store.register_user(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
was_guest=was_guest,
|
was_guest=was_guest,
|
||||||
|
@ -614,22 +625,24 @@ class RegistrationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def register_device(
|
async def register_device(
|
||||||
self, user_id, device_id, initial_display_name, is_guest=False
|
self,
|
||||||
):
|
user_id: str,
|
||||||
|
device_id: Optional[str],
|
||||||
|
initial_display_name: Optional[str],
|
||||||
|
is_guest: bool = False,
|
||||||
|
) -> Tuple[str, str]:
|
||||||
"""Register a device for a user and generate an access token.
|
"""Register a device for a user and generate an access token.
|
||||||
|
|
||||||
The access token will be limited by the homeserver's session_lifetime config.
|
The access token will be limited by the homeserver's session_lifetime config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): full canonical @user:id
|
user_id: full canonical @user:id
|
||||||
device_id (str|None): The device ID to check, or None to generate
|
device_id: The device ID to check, or None to generate a new one.
|
||||||
a new one.
|
initial_display_name: An optional display name for the device.
|
||||||
initial_display_name (str|None): An optional display name for the
|
is_guest: Whether this is a guest account
|
||||||
device.
|
|
||||||
is_guest (bool): Whether this is a guest account
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[str, str]: Tuple of device ID and access token
|
Tuple of device ID and access token
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.hs.config.worker_app:
|
if self.hs.config.worker_app:
|
||||||
|
@ -649,7 +662,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
valid_until_ms = self.clock.time_msec() + self.session_lifetime
|
valid_until_ms = self.clock.time_msec() + self.session_lifetime
|
||||||
|
|
||||||
device_id = await self.device_handler.check_device_registered(
|
registered_device_id = await self.device_handler.check_device_registered(
|
||||||
user_id, device_id, initial_display_name
|
user_id, device_id, initial_display_name
|
||||||
)
|
)
|
||||||
if is_guest:
|
if is_guest:
|
||||||
|
@ -659,20 +672,21 @@ class RegistrationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
access_token = await self._auth_handler.get_access_token_for_user_id(
|
access_token = await self._auth_handler.get_access_token_for_user_id(
|
||||||
user_id, device_id=device_id, valid_until_ms=valid_until_ms
|
user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
|
||||||
)
|
)
|
||||||
|
|
||||||
return (device_id, access_token)
|
return (registered_device_id, access_token)
|
||||||
|
|
||||||
async def post_registration_actions(self, user_id, auth_result, access_token):
|
async def post_registration_actions(
|
||||||
|
self, user_id: str, auth_result: dict, access_token: Optional[str]
|
||||||
|
) -> None:
|
||||||
"""A user has completed registration
|
"""A user has completed registration
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user ID that consented
|
user_id: The user ID that consented
|
||||||
auth_result (dict): The authenticated credentials of the newly
|
auth_result: The authenticated credentials of the newly registered user.
|
||||||
registered user.
|
access_token: The access token of the newly logged in device, or
|
||||||
access_token (str|None): The access token of the newly logged in
|
None if `inhibit_login` enabled.
|
||||||
device, or None if `inhibit_login` enabled.
|
|
||||||
"""
|
"""
|
||||||
if self.hs.config.worker_app:
|
if self.hs.config.worker_app:
|
||||||
await self._post_registration_client(
|
await self._post_registration_client(
|
||||||
|
@ -698,19 +712,20 @@ class RegistrationHandler(BaseHandler):
|
||||||
if auth_result and LoginType.TERMS in auth_result:
|
if auth_result and LoginType.TERMS in auth_result:
|
||||||
await self._on_user_consented(user_id, self.hs.config.user_consent_version)
|
await self._on_user_consented(user_id, self.hs.config.user_consent_version)
|
||||||
|
|
||||||
async def _on_user_consented(self, user_id, consent_version):
|
async def _on_user_consented(self, user_id: str, consent_version: str) -> None:
|
||||||
"""A user consented to the terms on registration
|
"""A user consented to the terms on registration
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user ID that consented.
|
user_id: The user ID that consented.
|
||||||
consent_version (str): version of the policy the user has
|
consent_version: version of the policy the user has consented to.
|
||||||
consented to.
|
|
||||||
"""
|
"""
|
||||||
logger.info("%s has consented to the privacy policy", user_id)
|
logger.info("%s has consented to the privacy policy", user_id)
|
||||||
await self.store.user_set_consent_version(user_id, consent_version)
|
await self.store.user_set_consent_version(user_id, consent_version)
|
||||||
await self.post_consent_actions(user_id)
|
await self.post_consent_actions(user_id)
|
||||||
|
|
||||||
async def _register_email_threepid(self, user_id, threepid, token):
|
async def _register_email_threepid(
|
||||||
|
self, user_id: str, threepid: dict, token: Optional[str]
|
||||||
|
) -> None:
|
||||||
"""Add an email address as a 3pid identifier
|
"""Add an email address as a 3pid identifier
|
||||||
|
|
||||||
Also adds an email pusher for the email address, if configured in the
|
Also adds an email pusher for the email address, if configured in the
|
||||||
|
@ -719,10 +734,9 @@ class RegistrationHandler(BaseHandler):
|
||||||
Must be called on master.
|
Must be called on master.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): id of user
|
user_id: id of user
|
||||||
threepid (object): m.login.email.identity auth response
|
threepid: m.login.email.identity auth response
|
||||||
token (str|None): access_token for the user, or None if not logged
|
token: access_token for the user, or None if not logged in.
|
||||||
in.
|
|
||||||
"""
|
"""
|
||||||
reqd = ("medium", "address", "validated_at")
|
reqd = ("medium", "address", "validated_at")
|
||||||
if any(x not in threepid for x in reqd):
|
if any(x not in threepid for x in reqd):
|
||||||
|
@ -748,6 +762,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
# up when the access token is saved, but that's quite an
|
# up when the access token is saved, but that's quite an
|
||||||
# invasive change I'd rather do separately.
|
# invasive change I'd rather do separately.
|
||||||
user_tuple = await self.store.get_user_by_access_token(token)
|
user_tuple = await self.store.get_user_by_access_token(token)
|
||||||
|
# The token better still exist.
|
||||||
|
assert user_tuple
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
await self.pusher_pool.add_pusher(
|
await self.pusher_pool.add_pusher(
|
||||||
|
@ -762,14 +778,14 @@ class RegistrationHandler(BaseHandler):
|
||||||
data={},
|
data={},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _register_msisdn_threepid(self, user_id, threepid):
|
async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None:
|
||||||
"""Add a phone number as a 3pid identifier
|
"""Add a phone number as a 3pid identifier
|
||||||
|
|
||||||
Must be called on master.
|
Must be called on master.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): id of user
|
user_id: id of user
|
||||||
threepid (object): m.login.msisdn auth response
|
threepid: m.login.msisdn auth response
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
|
assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
|
||||||
|
|
|
@ -39,7 +39,7 @@ from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.iterutils import chunk_seq
|
from synapse.util.iterutils import chunk_seq
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import synapse.server
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ class Saml2SessionData:
|
||||||
|
|
||||||
|
|
||||||
class SamlHandler(BaseHandler):
|
class SamlHandler(BaseHandler):
|
||||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
||||||
self._saml_idp_entityid = hs.config.saml2_idp_entityid
|
self._saml_idp_entityid = hs.config.saml2_idp_entityid
|
||||||
|
@ -330,7 +330,7 @@ class SamlHandler(BaseHandler):
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
default_display_name=displayname,
|
default_display_name=displayname,
|
||||||
bind_emails=emails,
|
bind_emails=emails,
|
||||||
user_agent_ips=(user_agent, ip_address),
|
user_agent_ips=[(user_agent, ip_address)],
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.store.record_user_external_id(
|
await self.store.record_user_external_id(
|
||||||
|
|
Loading…
Reference in New Issue