Split out a separate endpoint to complete SSO registration (#9262)

There are going to be a couple of paths to get to the final step of SSO reg, and I want the URL in the browser to consistent. So, let's move the final step onto a separate path, which we redirect to.
This commit is contained in:
Richard van der Hoff 2021-02-01 13:15:51 +00:00 committed by GitHub
parent a083aea396
commit f78d07bf00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 145 additions and 26 deletions

1
changelog.d/9262.feature Normal file
View File

@ -0,0 +1 @@
Improve the user experience of setting up an account via single-sign on.

View File

@ -62,6 +62,7 @@ from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_idp import PickIdpResource
from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.rest.synapse.client.pick_username import pick_username_resource
from synapse.rest.synapse.client.sso_register import SsoRegisterResource
from synapse.rest.well_known import WellKnownResource from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore from synapse.storage import DataStore
@ -192,6 +193,7 @@ class SynapseHomeServer(HomeServer):
"/_synapse/admin": AdminRestResource(self), "/_synapse/admin": AdminRestResource(self),
"/_synapse/client/pick_username": pick_username_resource(self), "/_synapse/client/pick_username": pick_username_resource(self),
"/_synapse/client/pick_idp": PickIdpResource(self), "/_synapse/client/pick_idp": PickIdpResource(self),
"/_synapse/client/sso_register": SsoRegisterResource(self),
} }
) )

View File

@ -21,12 +21,13 @@ import attr
from typing_extensions import NoReturn, Protocol from typing_extensions import NoReturn, Protocol
from twisted.web.http import Request from twisted.web.http import Request
from twisted.web.iweb import IRequest
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html from synapse.http.server import respond_with_html, respond_with_redirect
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -141,6 +142,9 @@ class UsernameMappingSession:
# expiry time for the session, in milliseconds # expiry time for the session, in milliseconds
expiry_time_ms = attr.ib(type=int) expiry_time_ms = attr.ib(type=int)
# choices made by the user
chosen_localpart = attr.ib(type=Optional[str], default=None)
# the HTTP cookie used to track the mapping session id # the HTTP cookie used to track the mapping session id
USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session" USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
@ -647,6 +651,25 @@ class SsoHandler:
) )
respond_with_html(request, 200, html) respond_with_html(request, 200, html)
def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
"""Look up the given username mapping session
If it is not found, raises a SynapseError with an http code of 400
Args:
session_id: session to look up
Returns:
active mapping session
Raises:
SynapseError if the session is not found/has expired
"""
self._expire_old_sessions()
session = self._username_mapping_sessions.get(session_id)
if session:
return session
logger.info("Couldn't find session id %s", session_id)
raise SynapseError(400, "unknown session")
async def check_username_availability( async def check_username_availability(
self, localpart: str, session_id: str, self, localpart: str, session_id: str,
) -> bool: ) -> bool:
@ -663,12 +686,7 @@ class SsoHandler:
# make sure that there is a valid mapping session, to stop people dictionary- # make sure that there is a valid mapping session, to stop people dictionary-
# scanning for accounts # scanning for accounts
self.get_mapping_session(session_id)
self._expire_old_sessions()
session = self._username_mapping_sessions.get(session_id)
if not session:
logger.info("Couldn't find session id %s", session_id)
raise SynapseError(400, "unknown session")
logger.info( logger.info(
"[session %s] Checking for availability of username %s", "[session %s] Checking for availability of username %s",
@ -696,16 +714,33 @@ class SsoHandler:
localpart: localpart requested by the user localpart: localpart requested by the user
session_id: ID of the username mapping session, extracted from a cookie session_id: ID of the username mapping session, extracted from a cookie
""" """
self._expire_old_sessions() session = self.get_mapping_session(session_id)
session = self._username_mapping_sessions.get(session_id)
if not session:
logger.info("Couldn't find session id %s", session_id)
raise SynapseError(400, "unknown session")
logger.info("[session %s] Registering localpart %s", session_id, localpart) # update the session with the user's choices
session.chosen_localpart = localpart
# we're done; now we can register the user
respond_with_redirect(request, b"/_synapse/client/sso_register")
async def register_sso_user(self, request: Request, session_id: str) -> None:
"""Called once we have all the info we need to register a new user.
Does so and serves an HTTP response
Args:
request: HTTP request
session_id: ID of the username mapping session, extracted from a cookie
"""
session = self.get_mapping_session(session_id)
logger.info(
"[session %s] Registering localpart %s",
session_id,
session.chosen_localpart,
)
attributes = UserAttributes( attributes = UserAttributes(
localpart=localpart, localpart=session.chosen_localpart,
display_name=session.display_name, display_name=session.display_name,
emails=session.emails, emails=session.emails,
) )
@ -720,7 +755,12 @@ class SsoHandler:
request.getClientIP(), request.getClientIP(),
) )
logger.info("[session %s] Registered userid %s", session_id, user_id) logger.info(
"[session %s] Registered userid %s with attributes %s",
session_id,
user_id,
attributes,
)
# delete the mapping session and the cookie # delete the mapping session and the cookie
del self._username_mapping_sessions[session_id] del self._username_mapping_sessions[session_id]
@ -751,3 +791,14 @@ class SsoHandler:
for session_id in to_expire: for session_id in to_expire:
logger.info("Expiring mapping session %s", session_id) logger.info("Expiring mapping session %s", session_id)
del self._username_mapping_sessions[session_id] del self._username_mapping_sessions[session_id]
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie
Raises a SynapseError if the cookie isn't found
"""
session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
if not session_id:
raise SynapseError(code=400, msg="missing session_id")
return session_id.decode("ascii", errors="replace")

View File

@ -761,6 +761,13 @@ def set_clickjacking_protection_headers(request: Request):
request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';")
def respond_with_redirect(request: Request, url: bytes) -> None:
"""Write a 302 response to the request, if it is still alive."""
logger.debug("Redirect to %s", url.decode("utf-8"))
request.redirect(url)
finish_request(request)
def finish_request(request: Request): def finish_request(request: Request):
""" Finish writing the response to the request. """ Finish writing the response to the request.

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import pkg_resources import pkg_resources
@ -20,8 +21,7 @@ from twisted.web.http import Request
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.static import File from twisted.web.static import File
from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME
from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -61,12 +61,10 @@ class AvailabilityCheckResource(DirectServeJsonResource):
async def _async_render_GET(self, request: Request): async def _async_render_GET(self, request: Request):
localpart = parse_string(request, "username", required=True) localpart = parse_string(request, "username", required=True)
session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) session_id = get_username_mapping_session_cookie_from_request(request)
if not session_id:
raise SynapseError(code=400, msg="missing session_id")
is_available = await self._sso_handler.check_username_availability( is_available = await self._sso_handler.check_username_availability(
localpart, session_id.decode("ascii", errors="replace") localpart, session_id
) )
return 200, {"available": is_available} return 200, {"available": is_available}
@ -79,10 +77,8 @@ class SubmitResource(DirectServeHtmlResource):
async def _async_render_POST(self, request: SynapseRequest): async def _async_render_POST(self, request: SynapseRequest):
localpart = parse_string(request, "username", required=True) localpart = parse_string(request, "username", required=True)
session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) session_id = get_username_mapping_session_cookie_from_request(request)
if not session_id:
raise SynapseError(code=400, msg="missing session_id")
await self._sso_handler.handle_submit_username_request( await self._sso_handler.handle_submit_username_request(
request, localpart, session_id.decode("ascii", errors="replace") request, localpart, session_id
) )

View File

@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
from synapse.http.server import DirectServeHtmlResource
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class SsoRegisterResource(DirectServeHtmlResource):
"""A resource which completes SSO registration
This resource gets mounted at /_synapse/client/sso_register, and is shown
after we collect username and/or consent for a new SSO user. It (finally) registers
the user, and confirms redirect to the client
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self._sso_handler = hs.get_sso_handler()
async def _async_render_GET(self, request: Request) -> None:
try:
session_id = get_username_mapping_session_cookie_from_request(request)
except SynapseError as e:
logger.warning("Error fetching session cookie: %s", e)
self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
return
await self._sso_handler.register_sso_user(request, session_id)

View File

@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_idp import PickIdpResource
from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.rest.synapse.client.pick_username import pick_username_resource
from synapse.rest.synapse.client.sso_register import SsoRegisterResource
from synapse.types import create_requester from synapse.types import create_requester
from tests import unittest from tests import unittest
@ -1215,6 +1216,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
d = super().create_resource_dict() d = super().create_resource_dict()
d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
d["/_synapse/client/sso_register"] = SsoRegisterResource(self.hs)
d["/_synapse/oidc"] = OIDCResource(self.hs) d["/_synapse/oidc"] = OIDCResource(self.hs)
return d return d
@ -1253,7 +1255,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
# Now, submit a username to the username picker, which should serve a redirect # Now, submit a username to the username picker, which should serve a redirect
# back to the client # to the completion page
submit_path = picker_url + "/submit" submit_path = picker_url + "/submit"
content = urlencode({b"username": b"bobby"}).encode("utf8") content = urlencode({b"username": b"bobby"}).encode("utf8")
chan = self.make_request( chan = self.make_request(
@ -1270,6 +1272,16 @@ class UsernamePickerTestCase(HomeserverTestCase):
) )
self.assertEqual(chan.code, 302, chan.result) self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location") location_headers = chan.headers.getRawHeaders("Location")
# send a request to the completion page, which should 302 to the client redirectUrl
chan = self.make_request(
"GET",
path=location_headers[0],
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
# ensure that the returned location matches the requested redirect URL # ensure that the returned location matches the requested redirect URL
path, query = location_headers[0].split("?", 1) path, query = location_headers[0].split("?", 1)
self.assertEqual(path, "https://x") self.assertEqual(path, "https://x")