SSO: redirect to public URL before setting cookies (#9436)

... otherwise, we don't get the cookie back.
This commit is contained in:
Richard van der Hoff 2021-02-26 14:02:06 +00:00 committed by GitHub
parent e53f11bd62
commit 15090de850
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 130 additions and 28 deletions

1
changelog.d/9436.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug in single sign-on which could cause a "No session cookie found" error.

View File

@ -14,8 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re import re
from typing import Union
from twisted.internet import task from twisted.internet import address, task
from twisted.web.client import FileBodyProducer from twisted.web.client import FileBodyProducer
from twisted.web.iweb import IRequest from twisted.web.iweb import IRequest
@ -53,6 +54,40 @@ class QuieterFileBodyProducer(FileBodyProducer):
pass pass
def get_request_uri(request: IRequest) -> bytes:
"""Return the full URI that was requested by the client"""
return b"%s://%s%s" % (
b"https" if request.isSecure() else b"http",
_get_requested_host(request),
# despite its name, "request.uri" is only the path and query-string.
request.uri,
)
def _get_requested_host(request: IRequest) -> bytes:
hostname = request.getHeader(b"host")
if hostname:
return hostname
# no Host header, use the address/port that the request arrived on
host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address]
hostname = host.host.encode("ascii")
if request.isSecure() and host.port == 443:
# default port for https
return hostname
if not request.isSecure() and host.port == 80:
# default port for http
return hostname
return b"%s:%i" % (
hostname,
host.port,
)
def get_request_user_agent(request: IRequest, default: str = "") -> str: def get_request_user_agent(request: IRequest, default: str = "") -> str:
"""Return the last User-Agent header, or the given default.""" """Return the last User-Agent header, or the given default."""
# There could be raw utf-8 bytes in the User-Agent header. # There could be raw utf-8 bytes in the User-Agent header.

View File

@ -20,6 +20,7 @@ from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.handlers.sso import SsoIdentityProvider from synapse.handlers.sso import SsoIdentityProvider
from synapse.http import get_request_uri
from synapse.http.server import HttpServer, finish_request from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -354,6 +355,7 @@ class SsoRedirectServlet(RestServlet):
hs.get_oidc_handler() hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
self._msc2858_enabled = hs.config.experimental.msc2858_enabled self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self._public_baseurl = hs.config.public_baseurl
def register(self, http_server: HttpServer) -> None: def register(self, http_server: HttpServer) -> None:
super().register(http_server) super().register(http_server)
@ -373,6 +375,32 @@ class SsoRedirectServlet(RestServlet):
async def on_GET( async def on_GET(
self, request: SynapseRequest, idp_id: Optional[str] = None self, request: SynapseRequest, idp_id: Optional[str] = None
) -> None: ) -> None:
if not self._public_baseurl:
raise SynapseError(400, "SSO requires a valid public_baseurl")
# if this isn't the expected hostname, redirect to the right one, so that we
# get our cookies back.
requested_uri = get_request_uri(request)
baseurl_bytes = self._public_baseurl.encode("utf-8")
if not requested_uri.startswith(baseurl_bytes):
# swap out the incorrect base URL for the right one.
#
# The idea here is to redirect from
# https://foo.bar/whatever/_matrix/...
# to
# https://public.baseurl/_matrix/...
#
i = requested_uri.index(b"/_matrix")
new_uri = baseurl_bytes[:-1] + requested_uri[i:]
logger.info(
"Requested URI %s is not canonical: redirecting to %s",
requested_uri.decode("utf-8", errors="replace"),
new_uri.decode("utf-8", errors="replace"),
)
request.redirect(new_uri)
finish_request(request)
return
client_redirect_url = parse_string( client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding=None request, "redirectUrl", required=True, encoding=None
) )

View File

@ -15,7 +15,7 @@
import time import time
import urllib.parse import urllib.parse
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlencode from urllib.parse import urlencode
from mock import Mock from mock import Mock
@ -47,8 +47,14 @@ except ImportError:
HAS_JWT = False HAS_JWT = False
# public_base_url used in some tests # synapse server name: used to populate public_baseurl in some tests
BASE_URL = "https://synapse/" SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
# public_baseurl for some tests. It uses an http:// scheme because
# FakeChannel.isSecure() returns False, so synapse will see the requested uri as
# http://..., so using http in the public_baseurl stops Synapse trying to redirect to
# https://....
BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
# CAS server used in some tests # CAS server used in some tests
CAS_SERVER = "https://fake.test" CAS_SERVER = "https://fake.test"
@ -480,11 +486,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_multi_sso_redirect(self): def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker""" """/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker # first hit the redirect url, which should redirect to our idp picker
channel = self.make_request( channel = self._make_sso_redirect_request(False, None)
"GET",
"/_matrix/client/r0/login/sso/redirect?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
uri = channel.headers.getRawHeaders("Location")[0] uri = channel.headers.getRawHeaders("Location")[0]
@ -628,34 +630,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_client_idp_redirect_msc2858_disabled(self): def test_client_idp_redirect_msc2858_disabled(self):
"""If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
channel = self.make_request( channel = self._make_sso_redirect_request(True, "oidc")
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
@override_config({"experimental_features": {"msc2858_enabled": True}}) @override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_unknown(self): def test_client_idp_redirect_to_unknown(self):
"""If the client tries to pick an unknown IdP, return a 404""" """If the client tries to pick an unknown IdP, return a 404"""
channel = self.make_request( channel = self._make_sso_redirect_request(True, "xxx")
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
@override_config({"experimental_features": {"msc2858_enabled": True}}) @override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_oidc(self): def test_client_idp_redirect_to_oidc(self):
"""If the client pick a known IdP, redirect to it""" """If the client pick a known IdP, redirect to it"""
channel = self.make_request( channel = self._make_sso_redirect_request(True, "oidc")
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0] oidc_uri = channel.headers.getRawHeaders("Location")[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
@ -663,6 +652,30 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# it should redirect us to the auth page of the OIDC server # it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
def _make_sso_redirect_request(
self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
):
"""Send a request to /_matrix/client/r0/login/sso/redirect
... or the unstable equivalent
... possibly specifying an IDP provider
"""
endpoint = (
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect"
if unstable_endpoint
else "/_matrix/client/r0/login/sso/redirect"
)
if idp_prov is not None:
endpoint += "/" + idp_prov
endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
return self.make_request(
"GET",
endpoint,
custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)],
)
@staticmethod @staticmethod
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
prefix = key + " = " prefix = key + " = "

View File

@ -542,13 +542,30 @@ class RestHelper:
if client_redirect_url: if client_redirect_url:
params["redirectUrl"] = client_redirect_url params["redirectUrl"] = client_redirect_url
# hit the redirect url (which will issue a cookie and state) # hit the redirect url (which should redirect back to the redirect url. This
# is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy.
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.hs.get_reactor(),
self.site, self.site,
"GET", "GET",
"/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
) )
assert channel.code == 302
# hit the redirect url again with the right Host header, which should now issue
# a cookie and redirect to the SSO provider.
location = channel.headers.getRawHeaders("Location")[0]
parts = urllib.parse.urlsplit(location)
channel = make_request(
self.hs.get_reactor(),
self.site,
"GET",
urllib.parse.urlunsplit(("", "") + parts[2:]),
custom_headers=[
("Host", parts[1]),
],
)
assert channel.code == 302 assert channel.code == 302
channel.extract_cookies(cookies) channel.extract_cookies(cookies)

View File

@ -161,7 +161,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
def default_config(self): def default_config(self):
config = super().default_config() config = super().default_config()
config["public_baseurl"] = "https://synapse.test"
# public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
# False, so synapse will see the requested uri as http://..., so using http in
# the public_baseurl stops Synapse trying to redirect to https.
config["public_baseurl"] = "http://synapse.test"
if HAS_OIDC: if HAS_OIDC:
# we enable OIDC as a way of testing SSO flows # we enable OIDC as a way of testing SSO flows

View File

@ -124,7 +124,11 @@ class FakeChannel:
return address.IPv4Address("TCP", self._ip, 3423) return address.IPv4Address("TCP", self._ip, 3423)
def getHost(self): def getHost(self):
return None # this is called by Request.__init__ to configure Request.host.
return address.IPv4Address("TCP", "127.0.0.1", 8888)
def isSecure(self):
return False
@property @property
def transport(self): def transport(self):