Combine the CAS & SAML implementations for required attributes. (#9326)
This commit is contained in:
parent
80d6dc9783
commit
6dade80048
|
@ -0,0 +1 @@
|
||||||
|
Share the code for handling required attributes between the CAS and SAML handlers.
|
|
@ -13,7 +13,12 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from synapse.config.sso import SsoAttributeRequirement
|
||||||
|
|
||||||
from ._base import Config
|
from ._base import Config
|
||||||
|
from ._util import validate_config
|
||||||
|
|
||||||
|
|
||||||
class CasConfig(Config):
|
class CasConfig(Config):
|
||||||
|
@ -38,12 +43,16 @@ class CasConfig(Config):
|
||||||
public_base_url + "_matrix/client/r0/login/cas/ticket"
|
public_base_url + "_matrix/client/r0/login/cas/ticket"
|
||||||
)
|
)
|
||||||
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
|
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
|
||||||
self.cas_required_attributes = cas_config.get("required_attributes") or {}
|
required_attributes = cas_config.get("required_attributes") or {}
|
||||||
|
self.cas_required_attributes = _parsed_required_attributes_def(
|
||||||
|
required_attributes
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.cas_server_url = None
|
self.cas_server_url = None
|
||||||
self.cas_service_url = None
|
self.cas_service_url = None
|
||||||
self.cas_displayname_attribute = None
|
self.cas_displayname_attribute = None
|
||||||
self.cas_required_attributes = {}
|
self.cas_required_attributes = []
|
||||||
|
|
||||||
def generate_config_section(self, config_dir_path, server_name, **kwargs):
|
def generate_config_section(self, config_dir_path, server_name, **kwargs):
|
||||||
return """\
|
return """\
|
||||||
|
@ -75,3 +84,22 @@ class CasConfig(Config):
|
||||||
# userGroup: "staff"
|
# userGroup: "staff"
|
||||||
# department: None
|
# department: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# CAS uses a legacy required attributes mapping, not the one provided by
|
||||||
|
# SsoAttributeRequirement.
|
||||||
|
REQUIRED_ATTRIBUTES_SCHEMA = {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parsed_required_attributes_def(
|
||||||
|
required_attributes: Any,
|
||||||
|
) -> List[SsoAttributeRequirement]:
|
||||||
|
validate_config(
|
||||||
|
REQUIRED_ATTRIBUTES_SCHEMA,
|
||||||
|
required_attributes,
|
||||||
|
config_path=("cas_config", "required_attributes"),
|
||||||
|
)
|
||||||
|
return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()]
|
||||||
|
|
|
@ -17,8 +17,7 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
import attr
|
from synapse.config.sso import SsoAttributeRequirement
|
||||||
|
|
||||||
from synapse.python_dependencies import DependencyException, check_requirements
|
from synapse.python_dependencies import DependencyException, check_requirements
|
||||||
from synapse.util.module_loader import load_module, load_python_module
|
from synapse.util.module_loader import load_module, load_python_module
|
||||||
|
|
||||||
|
@ -396,32 +395,18 @@ class SAML2Config(Config):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@attr.s(frozen=True)
|
|
||||||
class SamlAttributeRequirement:
|
|
||||||
"""Object describing a single requirement for SAML attributes."""
|
|
||||||
|
|
||||||
attribute = attr.ib(type=str)
|
|
||||||
value = attr.ib(type=str)
|
|
||||||
|
|
||||||
JSON_SCHEMA = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
|
|
||||||
"required": ["attribute", "value"],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
ATTRIBUTE_REQUIREMENTS_SCHEMA = {
|
ATTRIBUTE_REQUIREMENTS_SCHEMA = {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": SamlAttributeRequirement.JSON_SCHEMA,
|
"items": SsoAttributeRequirement.JSON_SCHEMA,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _parse_attribute_requirements_def(
|
def _parse_attribute_requirements_def(
|
||||||
attribute_requirements: Any,
|
attribute_requirements: Any,
|
||||||
) -> List[SamlAttributeRequirement]:
|
) -> List[SsoAttributeRequirement]:
|
||||||
validate_config(
|
validate_config(
|
||||||
ATTRIBUTE_REQUIREMENTS_SCHEMA,
|
ATTRIBUTE_REQUIREMENTS_SCHEMA,
|
||||||
attribute_requirements,
|
attribute_requirements,
|
||||||
config_path=["saml2_config", "attribute_requirements"],
|
config_path=("saml2_config", "attribute_requirements"),
|
||||||
)
|
)
|
||||||
return [SamlAttributeRequirement(**x) for x in attribute_requirements]
|
return [SsoAttributeRequirement(**x) for x in attribute_requirements]
|
||||||
|
|
|
@ -12,11 +12,28 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
from ._base import Config
|
from ._base import Config
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(frozen=True)
|
||||||
|
class SsoAttributeRequirement:
|
||||||
|
"""Object describing a single requirement for SSO attributes."""
|
||||||
|
|
||||||
|
attribute = attr.ib(type=str)
|
||||||
|
# If a value is not given, than the attribute must simply exist.
|
||||||
|
value = attr.ib(type=Optional[str])
|
||||||
|
|
||||||
|
JSON_SCHEMA = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
|
||||||
|
"required": ["attribute", "value"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class SSOConfig(Config):
|
class SSOConfig(Config):
|
||||||
"""SSO Configuration
|
"""SSO Configuration
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import TYPE_CHECKING, Dict, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
from xml.etree import ElementTree as ET
|
from xml.etree import ElementTree as ET
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -49,7 +49,7 @@ class CasError(Exception):
|
||||||
@attr.s(slots=True, frozen=True)
|
@attr.s(slots=True, frozen=True)
|
||||||
class CasResponse:
|
class CasResponse:
|
||||||
username = attr.ib(type=str)
|
username = attr.ib(type=str)
|
||||||
attributes = attr.ib(type=Dict[str, Optional[str]])
|
attributes = attr.ib(type=Dict[str, List[Optional[str]]])
|
||||||
|
|
||||||
|
|
||||||
class CasHandler:
|
class CasHandler:
|
||||||
|
@ -169,7 +169,7 @@ class CasHandler:
|
||||||
|
|
||||||
# Iterate through the nodes and pull out the user and any extra attributes.
|
# Iterate through the nodes and pull out the user and any extra attributes.
|
||||||
user = None
|
user = None
|
||||||
attributes = {}
|
attributes = {} # type: Dict[str, List[Optional[str]]]
|
||||||
for child in root[0]:
|
for child in root[0]:
|
||||||
if child.tag.endswith("user"):
|
if child.tag.endswith("user"):
|
||||||
user = child.text
|
user = child.text
|
||||||
|
@ -182,7 +182,7 @@ class CasHandler:
|
||||||
tag = attribute.tag
|
tag = attribute.tag
|
||||||
if "}" in tag:
|
if "}" in tag:
|
||||||
tag = tag.split("}")[1]
|
tag = tag.split("}")[1]
|
||||||
attributes[tag] = attribute.text
|
attributes.setdefault(tag, []).append(attribute.text)
|
||||||
|
|
||||||
# Ensure a user was found.
|
# Ensure a user was found.
|
||||||
if user is None:
|
if user is None:
|
||||||
|
@ -303,28 +303,9 @@ class CasHandler:
|
||||||
|
|
||||||
# Ensure that the attributes of the logged in user meet the required
|
# Ensure that the attributes of the logged in user meet the required
|
||||||
# attributes.
|
# attributes.
|
||||||
for required_attribute, required_value in self._cas_required_attributes.items():
|
if not self._sso_handler.check_required_attributes(
|
||||||
# If required attribute was not in CAS Response - Forbidden
|
request, cas_response.attributes, self._cas_required_attributes
|
||||||
if required_attribute not in cas_response.attributes:
|
):
|
||||||
self._sso_handler.render_error(
|
|
||||||
request,
|
|
||||||
"unauthorised",
|
|
||||||
"You are not authorised to log in here.",
|
|
||||||
401,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Also need to check value
|
|
||||||
if required_value is not None:
|
|
||||||
actual_value = cas_response.attributes[required_attribute]
|
|
||||||
# If required attribute value does not match expected - Forbidden
|
|
||||||
if required_value != actual_value:
|
|
||||||
self._sso_handler.render_error(
|
|
||||||
request,
|
|
||||||
"unauthorised",
|
|
||||||
"You are not authorised to log in here.",
|
|
||||||
401,
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Call the mapper to register/login the user
|
# Call the mapper to register/login the user
|
||||||
|
@ -372,9 +353,10 @@ class CasHandler:
|
||||||
if failures:
|
if failures:
|
||||||
raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
|
raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
|
||||||
|
|
||||||
|
# Arbitrarily use the first attribute found.
|
||||||
display_name = cas_response.attributes.get(
|
display_name = cas_response.attributes.get(
|
||||||
self._cas_displayname_attribute, None
|
self._cas_displayname_attribute, [None]
|
||||||
)
|
)[0]
|
||||||
|
|
||||||
return UserAttributes(localpart=localpart, display_name=display_name)
|
return UserAttributes(localpart=localpart, display_name=display_name)
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,6 @@ from saml2.client import Saml2Client
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.config.saml2_config import SamlAttributeRequirement
|
|
||||||
from synapse.handlers._base import BaseHandler
|
from synapse.handlers._base import BaseHandler
|
||||||
from synapse.handlers.sso import MappingException, UserAttributes
|
from synapse.handlers.sso import MappingException, UserAttributes
|
||||||
from synapse.http.servlet import parse_string
|
from synapse.http.servlet import parse_string
|
||||||
|
@ -239,11 +238,9 @@ class SamlHandler(BaseHandler):
|
||||||
|
|
||||||
# Ensure that the attributes of the logged in user meet the required
|
# Ensure that the attributes of the logged in user meet the required
|
||||||
# attributes.
|
# attributes.
|
||||||
for requirement in self._saml2_attribute_requirements:
|
if not self._sso_handler.check_required_attributes(
|
||||||
if not _check_attribute_requirement(saml2_auth.ava, requirement):
|
request, saml2_auth.ava, self._saml2_attribute_requirements
|
||||||
self._sso_handler.render_error(
|
):
|
||||||
request, "unauthorised", "You are not authorised to log in here."
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Call the mapper to register/login the user
|
# Call the mapper to register/login the user
|
||||||
|
@ -373,21 +370,6 @@ class SamlHandler(BaseHandler):
|
||||||
del self._outstanding_requests_dict[reqid]
|
del self._outstanding_requests_dict[reqid]
|
||||||
|
|
||||||
|
|
||||||
def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
|
|
||||||
values = ava.get(req.attribute, [])
|
|
||||||
for v in values:
|
|
||||||
if v == req.value:
|
|
||||||
return True
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"SAML2 attribute %s did not match required value '%s' (was '%s')",
|
|
||||||
req.attribute,
|
|
||||||
req.value,
|
|
||||||
values,
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
DOT_REPLACE_PATTERN = re.compile(
|
DOT_REPLACE_PATTERN = re.compile(
|
||||||
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
|
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,10 +16,12 @@ import abc
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
|
@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest
|
||||||
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
|
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
|
||||||
|
from synapse.config.sso import SsoAttributeRequirement
|
||||||
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
|
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
|
||||||
from synapse.http import get_request_user_agent
|
from synapse.http import get_request_user_agent
|
||||||
from synapse.http.server import respond_with_html, respond_with_redirect
|
from synapse.http.server import respond_with_html, respond_with_redirect
|
||||||
|
@ -893,6 +896,41 @@ class SsoHandler:
|
||||||
logger.info("Expiring mapping session %s", session_id)
|
logger.info("Expiring mapping session %s", session_id)
|
||||||
del self._username_mapping_sessions[session_id]
|
del self._username_mapping_sessions[session_id]
|
||||||
|
|
||||||
|
def check_required_attributes(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
attributes: Mapping[str, List[Any]],
|
||||||
|
attribute_requirements: Iterable[SsoAttributeRequirement],
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Confirm that the required attributes were present in the SSO response.
|
||||||
|
|
||||||
|
If all requirements are met, this will return True.
|
||||||
|
|
||||||
|
If any requirement is not met, then the request will be finalized by
|
||||||
|
showing an error page to the user and False will be returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The request to (potentially) respond to.
|
||||||
|
attributes: The attributes from the SSO IdP.
|
||||||
|
attribute_requirements: The requirements that attributes must meet.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if all requirements are met, False if any attribute fails to
|
||||||
|
meet the requirement.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Ensure that the attributes of the logged in user meet the required
|
||||||
|
# attributes.
|
||||||
|
for requirement in attribute_requirements:
|
||||||
|
if not _check_attribute_requirement(attributes, requirement):
|
||||||
|
self.render_error(
|
||||||
|
request, "unauthorised", "You are not authorised to log in here."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
|
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
|
||||||
"""Extract the session ID from the cookie
|
"""Extract the session ID from the cookie
|
||||||
|
@ -903,3 +941,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
|
||||||
if not session_id:
|
if not session_id:
|
||||||
raise SynapseError(code=400, msg="missing session_id")
|
raise SynapseError(code=400, msg="missing session_id")
|
||||||
return session_id.decode("ascii", errors="replace")
|
return session_id.decode("ascii", errors="replace")
|
||||||
|
|
||||||
|
|
||||||
|
def _check_attribute_requirement(
|
||||||
|
attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
|
||||||
|
) -> bool:
|
||||||
|
"""Check if SSO attributes meet the proper requirements.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attributes: A mapping of attributes to an iterable of one or more values.
|
||||||
|
requirement: The configured requirement to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the required attribute was found and had a proper value.
|
||||||
|
"""
|
||||||
|
if req.attribute not in attributes:
|
||||||
|
logger.info("SSO attribute missing: %s", req.attribute)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# If the requirement is None, the attribute existing is enough.
|
||||||
|
if req.value is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
values = attributes[req.attribute]
|
||||||
|
if req.value in values:
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"SSO attribute %s did not match required value '%s' (was '%s')",
|
||||||
|
req.attribute,
|
||||||
|
req.value,
|
||||||
|
values,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
|
@ -16,7 +16,7 @@ from mock import Mock
|
||||||
from synapse.handlers.cas_handler import CasResponse
|
from synapse.handlers.cas_handler import CasResponse
|
||||||
|
|
||||||
from tests.test_utils import simple_async_mock
|
from tests.test_utils import simple_async_mock
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
# These are a few constants that are used as config parameters in the tests.
|
# These are a few constants that are used as config parameters in the tests.
|
||||||
BASE_URL = "https://synapse/"
|
BASE_URL = "https://synapse/"
|
||||||
|
@ -32,6 +32,10 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
"server_url": SERVER_URL,
|
"server_url": SERVER_URL,
|
||||||
"service_url": BASE_URL,
|
"service_url": BASE_URL,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Update this config with what's in the default config so that
|
||||||
|
# override_config works as expected.
|
||||||
|
cas_config.update(config.get("cas_config", {}))
|
||||||
config["cas_config"] = cas_config
|
config["cas_config"] = cas_config
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
@ -115,7 +119,51 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
"@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
|
"@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"cas_config": {
|
||||||
|
"required_attributes": {"userGroup": "staff", "department": None}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_required_attributes(self):
|
||||||
|
"""The required attributes must be met from the CAS response."""
|
||||||
|
|
||||||
|
# stub out the auth handler
|
||||||
|
auth_handler = self.hs.get_auth_handler()
|
||||||
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
|
||||||
|
# The response doesn't have the proper userGroup or department.
|
||||||
|
cas_response = CasResponse("test_user", {})
|
||||||
|
request = _mock_request()
|
||||||
|
self.get_success(
|
||||||
|
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
|
||||||
|
)
|
||||||
|
auth_handler.complete_sso_login.assert_not_called()
|
||||||
|
|
||||||
|
# The response doesn't have any department.
|
||||||
|
cas_response = CasResponse("test_user", {"userGroup": "staff"})
|
||||||
|
request.reset_mock()
|
||||||
|
self.get_success(
|
||||||
|
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
|
||||||
|
)
|
||||||
|
auth_handler.complete_sso_login.assert_not_called()
|
||||||
|
|
||||||
|
# Add the proper attributes and it should succeed.
|
||||||
|
cas_response = CasResponse(
|
||||||
|
"test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
|
||||||
|
)
|
||||||
|
request.reset_mock()
|
||||||
|
self.get_success(
|
||||||
|
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
# check that the auth handler got called as expected
|
||||||
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
|
"@test_user:test", request, "redirect_uri", None, new_user=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _mock_request():
|
def _mock_request():
|
||||||
"""Returns a mock which will stand in as a SynapseRequest"""
|
"""Returns a mock which will stand in as a SynapseRequest"""
|
||||||
return Mock(spec=["getClientIP", "getHeader"])
|
return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
|
||||||
|
|
|
@ -259,7 +259,61 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
|
self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"saml2_config": {
|
||||||
|
"attribute_requirements": [
|
||||||
|
{"attribute": "userGroup", "value": "staff"},
|
||||||
|
{"attribute": "department", "value": "sales"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_attribute_requirements(self):
|
||||||
|
"""The required attributes must be met from the SAML response."""
|
||||||
|
|
||||||
|
# stub out the auth handler
|
||||||
|
auth_handler = self.hs.get_auth_handler()
|
||||||
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
|
||||||
|
# The response doesn't have the proper userGroup or department.
|
||||||
|
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
|
||||||
|
request = _mock_request()
|
||||||
|
self.get_success(
|
||||||
|
self.handler._handle_authn_response(request, saml_response, "redirect_uri")
|
||||||
|
)
|
||||||
|
auth_handler.complete_sso_login.assert_not_called()
|
||||||
|
|
||||||
|
# The response doesn't have the proper department.
|
||||||
|
saml_response = FakeAuthnResponse(
|
||||||
|
{"uid": "test_user", "username": "test_user", "userGroup": ["staff"]}
|
||||||
|
)
|
||||||
|
request = _mock_request()
|
||||||
|
self.get_success(
|
||||||
|
self.handler._handle_authn_response(request, saml_response, "redirect_uri")
|
||||||
|
)
|
||||||
|
auth_handler.complete_sso_login.assert_not_called()
|
||||||
|
|
||||||
|
# Add the proper attributes and it should succeed.
|
||||||
|
saml_response = FakeAuthnResponse(
|
||||||
|
{
|
||||||
|
"uid": "test_user",
|
||||||
|
"username": "test_user",
|
||||||
|
"userGroup": ["staff", "admin"],
|
||||||
|
"department": ["sales"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
request.reset_mock()
|
||||||
|
self.get_success(
|
||||||
|
self.handler._handle_authn_response(request, saml_response, "redirect_uri")
|
||||||
|
)
|
||||||
|
|
||||||
|
# check that the auth handler got called as expected
|
||||||
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
|
"@test_user:test", request, "redirect_uri", None, new_user=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _mock_request():
|
def _mock_request():
|
||||||
"""Returns a mock which will stand in as a SynapseRequest"""
|
"""Returns a mock which will stand in as a SynapseRequest"""
|
||||||
return Mock(spec=["getClientIP", "getHeader"])
|
return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
|
||||||
|
|
Loading…
Reference in New Issue