Properly report user-agent/IP during registration of SSO users. (#8784)

This also expands type-hints to the SSO and registration code.

Refactors the CAS code to more closely match OIDC/SAML.
This commit is contained in:
Patrick Cloke 2020-11-23 13:28:03 -05:00 committed by GitHub
parent 7127855741
commit 6fde6aa9c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 173 additions and 122 deletions

1
changelog.d/8784.misc Normal file
View File

@ -0,0 +1 @@
Fix a bug introduced in v1.20.0 where the user-agent and IP address reported during user registration for CAS, OpenID Connect, and SAML were of the wrong form.

View File

@ -37,6 +37,7 @@ files =
synapse/handlers/presence.py, synapse/handlers/presence.py,
synapse/handlers/profile.py, synapse/handlers/profile.py,
synapse/handlers/read_marker.py, synapse/handlers/read_marker.py,
synapse/handlers/register.py,
synapse/handlers/room.py, synapse/handlers/room.py,
synapse/handlers/room_member.py, synapse/handlers/room_member.py,
synapse/handlers/room_member_worker.py, synapse/handlers/room_member_worker.py,

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import urllib import urllib
from typing import Dict, Optional, Tuple from typing import TYPE_CHECKING, Dict, Optional, Tuple
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart from synapse.types import UserID, map_username_to_mxid_localpart
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,10 +34,10 @@ class CasHandler:
Utility class for to handle the response from a CAS SSO service. Utility class for to handle the response from a CAS SSO service.
Args: Args:
hs (synapse.server.HomeServer) hs
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self._hostname = hs.hostname self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
@ -200,27 +203,57 @@ class CasHandler:
args["session"] = session args["session"] = session
username, user_display_name = await self._validate_ticket(ticket, args) username, user_display_name = await self._validate_ticket(ticket, args)
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if session:
await self._auth_handler.complete_sso_ui_auth(
registered_user_id, session, request,
)
else:
if not registered_user_id:
# Pull out the user-agent and IP from the request. # Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("") user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request) ip_address = self.hs.get_ip_from_request(request)
registered_user_id = await self._registration_handler.register_user( # Get the matrix ID from the CAS username.
localpart=localpart, user_id = await self._map_cas_user_to_matrix_user(
default_display_name=user_display_name, username, user_display_name, user_agent, ip_address
user_agent_ips=(user_agent, ip_address),
) )
await self._auth_handler.complete_sso_login( if session:
registered_user_id, request, client_redirect_url await self._auth_handler.complete_sso_ui_auth(
user_id, session, request,
) )
else:
# If this not a UI auth request than there must be a redirect URL.
assert client_redirect_url
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url
)
async def _map_cas_user_to_matrix_user(
self,
remote_user_id: str,
display_name: Optional[str],
user_agent: str,
ip_address: str,
) -> str:
"""
Given a CAS username, retrieve the user ID for it and possibly register the user.
Args:
remote_user_id: The username from the CAS response.
display_name: The display name from the CAS response.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns:
The user ID associated with this response.
"""
localpart = map_username_to_mxid_localpart(remote_user_id)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
# If the user does not exist, register it.
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=display_name,
user_agent_ips=[(user_agent, ip_address)],
)
return registered_user_id

View File

@ -925,7 +925,7 @@ class OidcHandler(BaseHandler):
registered_user_id = await self._registration_handler.register_user( registered_user_id = await self._registration_handler.register_user(
localpart=localpart, localpart=localpart,
default_display_name=attributes["display_name"], default_display_name=attributes["display_name"],
user_agent_ips=(user_agent, ip_address), user_agent_ips=[(user_agent, ip_address)],
) )
await self.store.record_user_external_id( await self.store.record_user_external_id(

View File

@ -15,10 +15,12 @@
"""Contains functions for registering clients.""" """Contains functions for registering clients."""
import logging import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse import types from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.login import RegisterDeviceReplicationServlet
@ -32,16 +34,14 @@ from synapse.types import RoomAlias, UserID, create_requester
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RegistrationHandler(BaseHandler): class RegistrationHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer):
"""
super().__init__(hs) super().__init__(hs)
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -71,7 +71,10 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime self.session_lifetime = hs.config.session_lifetime
async def check_username( async def check_username(
self, localpart, guest_access_token=None, assigned_user_id=None self,
localpart: str,
guest_access_token: Optional[str] = None,
assigned_user_id: Optional[str] = None,
): ):
if types.contains_invalid_mxid_characters(localpart): if types.contains_invalid_mxid_characters(localpart):
raise SynapseError( raise SynapseError(
@ -140,39 +143,45 @@ class RegistrationHandler(BaseHandler):
async def register_user( async def register_user(
self, self,
localpart=None, localpart: Optional[str] = None,
password_hash=None, password_hash: Optional[str] = None,
guest_access_token=None, guest_access_token: Optional[str] = None,
make_guest=False, make_guest: bool = False,
admin=False, admin: bool = False,
threepid=None, threepid: Optional[dict] = None,
user_type=None, user_type: Optional[str] = None,
default_display_name=None, default_display_name: Optional[str] = None,
address=None, address: Optional[str] = None,
bind_emails=[], bind_emails: List[str] = [],
by_admin=False, by_admin: bool = False,
user_agent_ips=None, user_agent_ips: Optional[List[Tuple[str, str]]] = None,
): ) -> str:
"""Registers a new client on the server. """Registers a new client on the server.
Args: Args:
localpart: The local part of the user ID to register. If None, localpart: The local part of the user ID to register. If None,
one will be generated. one will be generated.
password_hash (str|None): The hashed password to assign to this user so they can password_hash: The hashed password to assign to this user so they can
login again. This can be None which means they cannot login again login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user). via a password (e.g. the user is an application service user).
user_type (str|None): type of user. One of the values from guest_access_token: The access token used when this was a guest
account.
make_guest: True if the the new user should be guest,
false to add a regular user account.
admin: True if the user should be registered as a server admin.
threepid: The threepid used for registering, if any.
user_type: type of user. One of the values from
api.constants.UserTypes, or None for a normal user. api.constants.UserTypes, or None for a normal user.
default_display_name (unicode|None): if set, the new user's displayname default_display_name: if set, the new user's displayname
will be set to this. Defaults to 'localpart'. will be set to this. Defaults to 'localpart'.
address (str|None): the IP address used to perform the registration. address: the IP address used to perform the registration.
bind_emails (List[str]): list of emails to bind to this account. bind_emails: list of emails to bind to this account.
by_admin (bool): True if this registration is being made via the by_admin: True if this registration is being made via the
admin api, otherwise False. admin api, otherwise False.
user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process. during the registration process.
Returns: Returns:
str: user_id The registere user_id.
Raises: Raises:
SynapseError if there was a problem registering. SynapseError if there was a problem registering.
""" """
@ -236,8 +245,10 @@ class RegistrationHandler(BaseHandler):
else: else:
# autogen a sequential user ID # autogen a sequential user ID
fail_count = 0 fail_count = 0
user = None # If a default display name is not given, generate one.
while not user: generate_display_name = default_display_name is None
# This breaks on successful registration *or* errors after 10 failures.
while True:
# Fail after being unable to find a suitable ID a few times # Fail after being unable to find a suitable ID a few times
if fail_count > 10: if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID") raise SynapseError(500, "Unable to find a suitable guest user ID")
@ -246,7 +257,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id) self.check_user_id_not_appservice_exclusive(user_id)
if default_display_name is None: if generate_display_name:
default_display_name = localpart default_display_name = localpart
try: try:
await self.register_with_store( await self.register_with_store(
@ -262,8 +273,6 @@ class RegistrationHandler(BaseHandler):
break break
except SynapseError: except SynapseError:
# if user id is taken, just generate another # if user id is taken, just generate another
user = None
user_id = None
fail_count += 1 fail_count += 1
if not self.hs.config.user_consent_at_registration: if not self.hs.config.user_consent_at_registration:
@ -295,7 +304,7 @@ class RegistrationHandler(BaseHandler):
return user_id return user_id
async def _create_and_join_rooms(self, user_id: str): async def _create_and_join_rooms(self, user_id: str) -> None:
""" """
Create the auto-join rooms and join or invite the user to them. Create the auto-join rooms and join or invite the user to them.
@ -379,7 +388,7 @@ class RegistrationHandler(BaseHandler):
except Exception as e: except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e) logger.error("Failed to join new user to %r: %r", r, e)
async def _join_rooms(self, user_id: str): async def _join_rooms(self, user_id: str) -> None:
""" """
Join or invite the user to the auto-join rooms. Join or invite the user to the auto-join rooms.
@ -425,6 +434,9 @@ class RegistrationHandler(BaseHandler):
# Send the invite, if necessary. # Send the invite, if necessary.
if requires_invite: if requires_invite:
# If an invite is required, there must be a auto-join user ID.
assert self.hs.config.registration.auto_join_user_id
await room_member_handler.update_membership( await room_member_handler.update_membership(
requester=create_requester( requester=create_requester(
self.hs.config.registration.auto_join_user_id, self.hs.config.registration.auto_join_user_id,
@ -456,7 +468,7 @@ class RegistrationHandler(BaseHandler):
except Exception as e: except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e) logger.error("Failed to join new user to %r: %r", r, e)
async def _auto_join_rooms(self, user_id: str): async def _auto_join_rooms(self, user_id: str) -> None:
"""Automatically joins users to auto join rooms - creating the room in the first place """Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created. if the user is the first to be created.
@ -479,16 +491,16 @@ class RegistrationHandler(BaseHandler):
else: else:
await self._join_rooms(user_id) await self._join_rooms(user_id)
async def post_consent_actions(self, user_id): async def post_consent_actions(self, user_id: str) -> None:
"""A series of registration actions that can only be carried out once consent """A series of registration actions that can only be carried out once consent
has been granted has been granted
Args: Args:
user_id (str): The user to join user_id: The user to join
""" """
await self._auto_join_rooms(user_id) await self._auto_join_rooms(user_id)
async def appservice_register(self, user_localpart, as_token): async def appservice_register(self, user_localpart: str, as_token: str) -> str:
user = UserID(user_localpart, self.hs.hostname) user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token) service = self.store.get_app_service_by_token(as_token)
@ -513,7 +525,9 @@ class RegistrationHandler(BaseHandler):
) )
return user_id return user_id
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): def check_user_id_not_appservice_exclusive(
self, user_id: str, allowed_appservice: Optional[ApplicationService] = None
) -> None:
# don't allow people to register the server notices mxid # don't allow people to register the server notices mxid
if self._server_notices_mxid is not None: if self._server_notices_mxid is not None:
if user_id == self._server_notices_mxid: if user_id == self._server_notices_mxid:
@ -537,12 +551,12 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE, errcode=Codes.EXCLUSIVE,
) )
def check_registration_ratelimit(self, address): def check_registration_ratelimit(self, address: Optional[str]) -> None:
"""A simple helper method to check whether the registration rate limit has been hit """A simple helper method to check whether the registration rate limit has been hit
for a given IP address for a given IP address
Args: Args:
address (str|None): the IP address used to perform the registration. If this is address: the IP address used to perform the registration. If this is
None, no ratelimiting will be performed. None, no ratelimiting will be performed.
Raises: Raises:
@ -553,42 +567,39 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter.ratelimit(address) self.ratelimiter.ratelimit(address)
def register_with_store( async def register_with_store(
self, self,
user_id, user_id: str,
password_hash=None, password_hash: Optional[str] = None,
was_guest=False, was_guest: bool = False,
make_guest=False, make_guest: bool = False,
appservice_id=None, appservice_id: Optional[str] = None,
create_profile_with_displayname=None, create_profile_with_displayname: Optional[str] = None,
admin=False, admin: bool = False,
user_type=None, user_type: Optional[str] = None,
address=None, address: Optional[str] = None,
shadow_banned=False, shadow_banned: bool = False,
): ) -> None:
"""Register user in the datastore. """Register user in the datastore.
Args: Args:
user_id (str): The desired user ID to register. user_id: The desired user ID to register.
password_hash (str|None): Optional. The password hash for this user. password_hash: Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being was_guest: Optional. Whether this is a guest account being
upgraded to a non-guest account. upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest, make_guest: True if the the new user should be guest,
false to add a regular user account. false to add a regular user account.
appservice_id (str|None): The ID of the appservice registering the user. appservice_id: The ID of the appservice registering the user.
create_profile_with_displayname (unicode|None): Optionally create a create_profile_with_displayname: Optionally create a
profile for the user, setting their displayname to the given value profile for the user, setting their displayname to the given value
admin (boolean): is an admin user? admin: is an admin user?
user_type (str|None): type of user. One of the values from user_type: type of user. One of the values from
api.constants.UserTypes, or None for a normal user. api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the registration. address: the IP address used to perform the registration.
shadow_banned (bool): Whether to shadow-ban the user shadow_banned: Whether to shadow-ban the user
Returns:
Awaitable
""" """
if self.hs.config.worker_app: if self.hs.config.worker_app:
return self._register_client( await self._register_client(
user_id=user_id, user_id=user_id,
password_hash=password_hash, password_hash=password_hash,
was_guest=was_guest, was_guest=was_guest,
@ -601,7 +612,7 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned, shadow_banned=shadow_banned,
) )
else: else:
return self.store.register_user( await self.store.register_user(
user_id=user_id, user_id=user_id,
password_hash=password_hash, password_hash=password_hash,
was_guest=was_guest, was_guest=was_guest,
@ -614,22 +625,24 @@ class RegistrationHandler(BaseHandler):
) )
async def register_device( async def register_device(
self, user_id, device_id, initial_display_name, is_guest=False self,
): user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
is_guest: bool = False,
) -> Tuple[str, str]:
"""Register a device for a user and generate an access token. """Register a device for a user and generate an access token.
The access token will be limited by the homeserver's session_lifetime config. The access token will be limited by the homeserver's session_lifetime config.
Args: Args:
user_id (str): full canonical @user:id user_id: full canonical @user:id
device_id (str|None): The device ID to check, or None to generate device_id: The device ID to check, or None to generate a new one.
a new one. initial_display_name: An optional display name for the device.
initial_display_name (str|None): An optional display name for the is_guest: Whether this is a guest account
device.
is_guest (bool): Whether this is a guest account
Returns: Returns:
tuple[str, str]: Tuple of device ID and access token Tuple of device ID and access token
""" """
if self.hs.config.worker_app: if self.hs.config.worker_app:
@ -649,7 +662,7 @@ class RegistrationHandler(BaseHandler):
) )
valid_until_ms = self.clock.time_msec() + self.session_lifetime valid_until_ms = self.clock.time_msec() + self.session_lifetime
device_id = await self.device_handler.check_device_registered( registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )
if is_guest: if is_guest:
@ -659,20 +672,21 @@ class RegistrationHandler(BaseHandler):
) )
else: else:
access_token = await self._auth_handler.get_access_token_for_user_id( access_token = await self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id, valid_until_ms=valid_until_ms user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
) )
return (device_id, access_token) return (registered_device_id, access_token)
async def post_registration_actions(self, user_id, auth_result, access_token): async def post_registration_actions(
self, user_id: str, auth_result: dict, access_token: Optional[str]
) -> None:
"""A user has completed registration """A user has completed registration
Args: Args:
user_id (str): The user ID that consented user_id: The user ID that consented
auth_result (dict): The authenticated credentials of the newly auth_result: The authenticated credentials of the newly registered user.
registered user. access_token: The access token of the newly logged in device, or
access_token (str|None): The access token of the newly logged in None if `inhibit_login` enabled.
device, or None if `inhibit_login` enabled.
""" """
if self.hs.config.worker_app: if self.hs.config.worker_app:
await self._post_registration_client( await self._post_registration_client(
@ -698,19 +712,20 @@ class RegistrationHandler(BaseHandler):
if auth_result and LoginType.TERMS in auth_result: if auth_result and LoginType.TERMS in auth_result:
await self._on_user_consented(user_id, self.hs.config.user_consent_version) await self._on_user_consented(user_id, self.hs.config.user_consent_version)
async def _on_user_consented(self, user_id, consent_version): async def _on_user_consented(self, user_id: str, consent_version: str) -> None:
"""A user consented to the terms on registration """A user consented to the terms on registration
Args: Args:
user_id (str): The user ID that consented. user_id: The user ID that consented.
consent_version (str): version of the policy the user has consent_version: version of the policy the user has consented to.
consented to.
""" """
logger.info("%s has consented to the privacy policy", user_id) logger.info("%s has consented to the privacy policy", user_id)
await self.store.user_set_consent_version(user_id, consent_version) await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id) await self.post_consent_actions(user_id)
async def _register_email_threepid(self, user_id, threepid, token): async def _register_email_threepid(
self, user_id: str, threepid: dict, token: Optional[str]
) -> None:
"""Add an email address as a 3pid identifier """Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the Also adds an email pusher for the email address, if configured in the
@ -719,10 +734,9 @@ class RegistrationHandler(BaseHandler):
Must be called on master. Must be called on master.
Args: Args:
user_id (str): id of user user_id: id of user
threepid (object): m.login.email.identity auth response threepid: m.login.email.identity auth response
token (str|None): access_token for the user, or None if not logged token: access_token for the user, or None if not logged in.
in.
""" """
reqd = ("medium", "address", "validated_at") reqd = ("medium", "address", "validated_at")
if any(x not in threepid for x in reqd): if any(x not in threepid for x in reqd):
@ -748,6 +762,8 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an # up when the access token is saved, but that's quite an
# invasive change I'd rather do separately. # invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token) user_tuple = await self.store.get_user_by_access_token(token)
# The token better still exist.
assert user_tuple
token_id = user_tuple.token_id token_id = user_tuple.token_id
await self.pusher_pool.add_pusher( await self.pusher_pool.add_pusher(
@ -762,14 +778,14 @@ class RegistrationHandler(BaseHandler):
data={}, data={},
) )
async def _register_msisdn_threepid(self, user_id, threepid): async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None:
"""Add a phone number as a 3pid identifier """Add a phone number as a 3pid identifier
Must be called on master. Must be called on master.
Args: Args:
user_id (str): id of user user_id: id of user
threepid (object): m.login.msisdn auth response threepid: m.login.msisdn auth response
""" """
try: try:
assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) assert_params_in_dict(threepid, ["medium", "address", "validated_at"])

View File

@ -39,7 +39,7 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING: if TYPE_CHECKING:
import synapse.server from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class Saml2SessionData:
class SamlHandler(BaseHandler): class SamlHandler(BaseHandler):
def __init__(self, hs: "synapse.server.HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2_idp_entityid self._saml_idp_entityid = hs.config.saml2_idp_entityid
@ -330,7 +330,7 @@ class SamlHandler(BaseHandler):
localpart=localpart, localpart=localpart,
default_display_name=displayname, default_display_name=displayname,
bind_emails=emails, bind_emails=emails,
user_agent_ips=(user_agent, ip_address), user_agent_ips=[(user_agent, ip_address)],
) )
await self.store.record_user_external_id( await self.store.record_user_external_id(