Add support for stable MSC2858 API (#9617)
The stable format uses different brand identifiers, so we need to support two identifiers for each IdP.
This commit is contained in:
parent
5b5bc188cf
commit
dd69110d95
|
@ -0,0 +1 @@
|
||||||
|
Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)).
|
|
@ -226,7 +226,7 @@ Synapse config:
|
||||||
oidc_providers:
|
oidc_providers:
|
||||||
- idp_id: github
|
- idp_id: github
|
||||||
idp_name: Github
|
idp_name: Github
|
||||||
idp_brand: "org.matrix.github" # optional: styling hint for clients
|
idp_brand: "github" # optional: styling hint for clients
|
||||||
discover: false
|
discover: false
|
||||||
issuer: "https://github.com/"
|
issuer: "https://github.com/"
|
||||||
client_id: "your-client-id" # TO BE FILLED
|
client_id: "your-client-id" # TO BE FILLED
|
||||||
|
@ -252,7 +252,7 @@ oidc_providers:
|
||||||
oidc_providers:
|
oidc_providers:
|
||||||
- idp_id: google
|
- idp_id: google
|
||||||
idp_name: Google
|
idp_name: Google
|
||||||
idp_brand: "org.matrix.google" # optional: styling hint for clients
|
idp_brand: "google" # optional: styling hint for clients
|
||||||
issuer: "https://accounts.google.com/"
|
issuer: "https://accounts.google.com/"
|
||||||
client_id: "your-client-id" # TO BE FILLED
|
client_id: "your-client-id" # TO BE FILLED
|
||||||
client_secret: "your-client-secret" # TO BE FILLED
|
client_secret: "your-client-secret" # TO BE FILLED
|
||||||
|
@ -299,7 +299,7 @@ Synapse config:
|
||||||
oidc_providers:
|
oidc_providers:
|
||||||
- idp_id: gitlab
|
- idp_id: gitlab
|
||||||
idp_name: Gitlab
|
idp_name: Gitlab
|
||||||
idp_brand: "org.matrix.gitlab" # optional: styling hint for clients
|
idp_brand: "gitlab" # optional: styling hint for clients
|
||||||
issuer: "https://gitlab.com/"
|
issuer: "https://gitlab.com/"
|
||||||
client_id: "your-client-id" # TO BE FILLED
|
client_id: "your-client-id" # TO BE FILLED
|
||||||
client_secret: "your-client-secret" # TO BE FILLED
|
client_secret: "your-client-secret" # TO BE FILLED
|
||||||
|
@ -334,7 +334,7 @@ Synapse config:
|
||||||
```yaml
|
```yaml
|
||||||
- idp_id: facebook
|
- idp_id: facebook
|
||||||
idp_name: Facebook
|
idp_name: Facebook
|
||||||
idp_brand: "org.matrix.facebook" # optional: styling hint for clients
|
idp_brand: "facebook" # optional: styling hint for clients
|
||||||
discover: false
|
discover: false
|
||||||
issuer: "https://facebook.com"
|
issuer: "https://facebook.com"
|
||||||
client_id: "your-client-id" # TO BE FILLED
|
client_id: "your-client-id" # TO BE FILLED
|
||||||
|
|
|
@ -1919,7 +1919,7 @@ oidc_providers:
|
||||||
#
|
#
|
||||||
#- idp_id: github
|
#- idp_id: github
|
||||||
# idp_name: Github
|
# idp_name: Github
|
||||||
# idp_brand: org.matrix.github
|
# idp_brand: github
|
||||||
# discover: false
|
# discover: false
|
||||||
# issuer: "https://github.com/"
|
# issuer: "https://github.com/"
|
||||||
# client_id: "your-client-id" # TO BE FILLED
|
# client_id: "your-client-id" # TO BE FILLED
|
||||||
|
|
|
@ -237,7 +237,7 @@ class OIDCConfig(Config):
|
||||||
#
|
#
|
||||||
#- idp_id: github
|
#- idp_id: github
|
||||||
# idp_name: Github
|
# idp_name: Github
|
||||||
# idp_brand: org.matrix.github
|
# idp_brand: github
|
||||||
# discover: false
|
# discover: false
|
||||||
# issuer: "https://github.com/"
|
# issuer: "https://github.com/"
|
||||||
# client_id: "your-client-id" # TO BE FILLED
|
# client_id: "your-client-id" # TO BE FILLED
|
||||||
|
@ -272,7 +272,12 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
|
||||||
"idp_icon": {"type": "string"},
|
"idp_icon": {"type": "string"},
|
||||||
"idp_brand": {
|
"idp_brand": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
# MSC2758-style namespaced identifier
|
"minLength": 1,
|
||||||
|
"maxLength": 255,
|
||||||
|
"pattern": "^[a-z][a-z0-9_.-]*$",
|
||||||
|
},
|
||||||
|
"idp_unstable_brand": {
|
||||||
|
"type": "string",
|
||||||
"minLength": 1,
|
"minLength": 1,
|
||||||
"maxLength": 255,
|
"maxLength": 255,
|
||||||
"pattern": "^[a-z][a-z0-9_.-]*$",
|
"pattern": "^[a-z][a-z0-9_.-]*$",
|
||||||
|
@ -466,6 +471,7 @@ def _parse_oidc_config_dict(
|
||||||
idp_name=oidc_config.get("idp_name", "OIDC"),
|
idp_name=oidc_config.get("idp_name", "OIDC"),
|
||||||
idp_icon=idp_icon,
|
idp_icon=idp_icon,
|
||||||
idp_brand=oidc_config.get("idp_brand"),
|
idp_brand=oidc_config.get("idp_brand"),
|
||||||
|
unstable_idp_brand=oidc_config.get("unstable_idp_brand"),
|
||||||
discover=oidc_config.get("discover", True),
|
discover=oidc_config.get("discover", True),
|
||||||
issuer=oidc_config["issuer"],
|
issuer=oidc_config["issuer"],
|
||||||
client_id=oidc_config["client_id"],
|
client_id=oidc_config["client_id"],
|
||||||
|
@ -512,6 +518,9 @@ class OidcProviderConfig:
|
||||||
# Optional brand identifier for this IdP.
|
# Optional brand identifier for this IdP.
|
||||||
idp_brand = attr.ib(type=Optional[str])
|
idp_brand = attr.ib(type=Optional[str])
|
||||||
|
|
||||||
|
# Optional brand identifier for the unstable API (see MSC2858).
|
||||||
|
unstable_idp_brand = attr.ib(type=Optional[str])
|
||||||
|
|
||||||
# whether the OIDC discovery mechanism is used to discover endpoints
|
# whether the OIDC discovery mechanism is used to discover endpoints
|
||||||
discover = attr.ib(type=bool)
|
discover = attr.ib(type=bool)
|
||||||
|
|
||||||
|
|
|
@ -83,6 +83,7 @@ class CasHandler:
|
||||||
# the SsoIdentityProvider protocol type.
|
# the SsoIdentityProvider protocol type.
|
||||||
self.idp_icon = None
|
self.idp_icon = None
|
||||||
self.idp_brand = None
|
self.idp_brand = None
|
||||||
|
self.unstable_idp_brand = None
|
||||||
|
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
|
|
||||||
|
|
|
@ -330,6 +330,9 @@ class OidcProvider:
|
||||||
# optional brand identifier for this auth provider
|
# optional brand identifier for this auth provider
|
||||||
self.idp_brand = provider.idp_brand
|
self.idp_brand = provider.idp_brand
|
||||||
|
|
||||||
|
# Optional brand identifier for the unstable API (see MSC2858).
|
||||||
|
self.unstable_idp_brand = provider.unstable_idp_brand
|
||||||
|
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
|
|
||||||
self._sso_handler.register_identity_provider(self)
|
self._sso_handler.register_identity_provider(self)
|
||||||
|
|
|
@ -81,6 +81,7 @@ class SamlHandler(BaseHandler):
|
||||||
# the SsoIdentityProvider protocol type.
|
# the SsoIdentityProvider protocol type.
|
||||||
self.idp_icon = None
|
self.idp_icon = None
|
||||||
self.idp_brand = None
|
self.idp_brand = None
|
||||||
|
self.unstable_idp_brand = None
|
||||||
|
|
||||||
# a map from saml session id to Saml2SessionData object
|
# a map from saml session id to Saml2SessionData object
|
||||||
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
||||||
|
|
|
@ -98,6 +98,11 @@ class SsoIdentityProvider(Protocol):
|
||||||
"""Optional branding identifier"""
|
"""Optional branding identifier"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unstable_idp_brand(self) -> Optional[str]:
|
||||||
|
"""Optional brand identifier for the unstable API (see MSC2858)."""
|
||||||
|
return None
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def handle_redirect_request(
|
async def handle_redirect_request(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -14,10 +14,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
|
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
|
||||||
|
|
||||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
|
from synapse.api.urls import CLIENT_API_PREFIX
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.handlers.sso import SsoIdentityProvider
|
from synapse.handlers.sso import SsoIdentityProvider
|
||||||
from synapse.http import get_request_uri
|
from synapse.http import get_request_uri
|
||||||
|
@ -94,11 +96,21 @@ class LoginRestServlet(RestServlet):
|
||||||
flows.append({"type": LoginRestServlet.CAS_TYPE})
|
flows.append({"type": LoginRestServlet.CAS_TYPE})
|
||||||
|
|
||||||
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
|
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
|
||||||
sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
|
sso_flow = {
|
||||||
|
"type": LoginRestServlet.SSO_TYPE,
|
||||||
|
"identity_providers": [
|
||||||
|
_get_auth_flow_dict_for_idp(
|
||||||
|
idp,
|
||||||
|
)
|
||||||
|
for idp in self._sso_handler.get_identity_providers().values()
|
||||||
|
],
|
||||||
|
} # type: JsonDict
|
||||||
|
|
||||||
if self._msc2858_enabled:
|
if self._msc2858_enabled:
|
||||||
|
# backwards-compatibility support for clients which don't
|
||||||
|
# support the stable API yet
|
||||||
sso_flow["org.matrix.msc2858.identity_providers"] = [
|
sso_flow["org.matrix.msc2858.identity_providers"] = [
|
||||||
_get_auth_flow_dict_for_idp(idp)
|
_get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
|
||||||
for idp in self._sso_handler.get_identity_providers().values()
|
for idp in self._sso_handler.get_identity_providers().values()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -331,22 +343,38 @@ class LoginRestServlet(RestServlet):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
|
def _get_auth_flow_dict_for_idp(
|
||||||
|
idp: SsoIdentityProvider, use_unstable_brands: bool = False
|
||||||
|
) -> JsonDict:
|
||||||
"""Return an entry for the login flow dict
|
"""Return an entry for the login flow dict
|
||||||
|
|
||||||
Returns an entry suitable for inclusion in "identity_providers" in the
|
Returns an entry suitable for inclusion in "identity_providers" in the
|
||||||
response to GET /_matrix/client/r0/login
|
response to GET /_matrix/client/r0/login
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idp: the identity provider to describe
|
||||||
|
use_unstable_brands: whether we should use brand identifiers suitable
|
||||||
|
for the unstable API
|
||||||
"""
|
"""
|
||||||
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
|
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
|
||||||
if idp.idp_icon:
|
if idp.idp_icon:
|
||||||
e["icon"] = idp.idp_icon
|
e["icon"] = idp.idp_icon
|
||||||
if idp.idp_brand:
|
if idp.idp_brand:
|
||||||
e["brand"] = idp.idp_brand
|
e["brand"] = idp.idp_brand
|
||||||
|
# use the stable brand identifier if the unstable identifier isn't defined.
|
||||||
|
if use_unstable_brands and idp.unstable_idp_brand:
|
||||||
|
e["brand"] = idp.unstable_idp_brand
|
||||||
return e
|
return e
|
||||||
|
|
||||||
|
|
||||||
class SsoRedirectServlet(RestServlet):
|
class SsoRedirectServlet(RestServlet):
|
||||||
PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
|
PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
|
||||||
|
re.compile(
|
||||||
|
"^"
|
||||||
|
+ CLIENT_API_PREFIX
|
||||||
|
+ "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
# make sure that the relevant handlers are instantiated, so that they
|
# make sure that the relevant handlers are instantiated, so that they
|
||||||
|
@ -364,7 +392,8 @@ class SsoRedirectServlet(RestServlet):
|
||||||
def register(self, http_server: HttpServer) -> None:
|
def register(self, http_server: HttpServer) -> None:
|
||||||
super().register(http_server)
|
super().register(http_server)
|
||||||
if self._msc2858_enabled:
|
if self._msc2858_enabled:
|
||||||
# expose additional endpoint for MSC2858 support
|
# expose additional endpoint for MSC2858 support: backwards-compat support
|
||||||
|
# for clients which don't yet support the stable endpoints.
|
||||||
http_server.register_paths(
|
http_server.register_paths(
|
||||||
"GET",
|
"GET",
|
||||||
client_patterns(
|
client_patterns(
|
||||||
|
|
|
@ -437,14 +437,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
||||||
expected_flows = [
|
expected_flow_types = [
|
||||||
{"type": "m.login.cas"},
|
"m.login.cas",
|
||||||
{"type": "m.login.sso"},
|
"m.login.sso",
|
||||||
{"type": "m.login.token"},
|
"m.login.token",
|
||||||
{"type": "m.login.password"},
|
"m.login.password",
|
||||||
] + ADDITIONAL_LOGIN_FLOWS
|
] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS]
|
||||||
|
|
||||||
self.assertCountEqual(channel.json_body["flows"], expected_flows)
|
self.assertCountEqual(
|
||||||
|
[f["type"] for f in channel.json_body["flows"]], expected_flow_types
|
||||||
|
)
|
||||||
|
|
||||||
@override_config({"experimental_features": {"msc2858_enabled": True}})
|
@override_config({"experimental_features": {"msc2858_enabled": True}})
|
||||||
def test_get_msc2858_login_flows(self):
|
def test_get_msc2858_login_flows(self):
|
||||||
|
@ -636,22 +638,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
|
||||||
def test_client_idp_redirect_msc2858_disabled(self):
|
|
||||||
"""If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
|
|
||||||
channel = self._make_sso_redirect_request(True, "oidc")
|
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
|
|
||||||
|
|
||||||
@override_config({"experimental_features": {"msc2858_enabled": True}})
|
|
||||||
def test_client_idp_redirect_to_unknown(self):
|
def test_client_idp_redirect_to_unknown(self):
|
||||||
"""If the client tries to pick an unknown IdP, return a 404"""
|
"""If the client tries to pick an unknown IdP, return a 404"""
|
||||||
channel = self._make_sso_redirect_request(True, "xxx")
|
channel = self._make_sso_redirect_request(False, "xxx")
|
||||||
self.assertEqual(channel.code, 404, channel.result)
|
self.assertEqual(channel.code, 404, channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||||
|
|
||||||
@override_config({"experimental_features": {"msc2858_enabled": True}})
|
|
||||||
def test_client_idp_redirect_to_oidc(self):
|
def test_client_idp_redirect_to_oidc(self):
|
||||||
"""If the client pick a known IdP, redirect to it"""
|
"""If the client pick a known IdP, redirect to it"""
|
||||||
|
channel = self._make_sso_redirect_request(False, "oidc")
|
||||||
|
self.assertEqual(channel.code, 302, channel.result)
|
||||||
|
oidc_uri = channel.headers.getRawHeaders("Location")[0]
|
||||||
|
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
|
||||||
|
|
||||||
|
# it should redirect us to the auth page of the OIDC server
|
||||||
|
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
|
||||||
|
|
||||||
|
@override_config({"experimental_features": {"msc2858_enabled": True}})
|
||||||
|
def test_client_msc2858_redirect_to_oidc(self):
|
||||||
|
"""Test the unstable API"""
|
||||||
channel = self._make_sso_redirect_request(True, "oidc")
|
channel = self._make_sso_redirect_request(True, "oidc")
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, 302, channel.result)
|
||||||
oidc_uri = channel.headers.getRawHeaders("Location")[0]
|
oidc_uri = channel.headers.getRawHeaders("Location")[0]
|
||||||
|
@ -660,6 +665,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
# it should redirect us to the auth page of the OIDC server
|
# it should redirect us to the auth page of the OIDC server
|
||||||
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
|
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
|
||||||
|
|
||||||
|
def test_client_idp_redirect_msc2858_disabled(self):
|
||||||
|
"""If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400"""
|
||||||
|
channel = self._make_sso_redirect_request(True, "oidc")
|
||||||
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
|
||||||
|
|
||||||
def _make_sso_redirect_request(
|
def _make_sso_redirect_request(
|
||||||
self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
|
self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
|
||||||
):
|
):
|
||||||
|
|
Loading…
Reference in New Issue