Combine the CAS & SAML implementations for required attributes. (#9326)

This commit is contained in:
Patrick Cloke 2021-02-11 10:05:15 -05:00 committed by GitHub
parent 80d6dc9783
commit 6dade80048
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 245 additions and 77 deletions

1
changelog.d/9326.misc Normal file
View File

@ -0,0 +1 @@
Share the code for handling required attributes between the CAS and SAML handlers.

View File

@ -13,7 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, List
from synapse.config.sso import SsoAttributeRequirement
from ._base import Config from ._base import Config
from ._util import validate_config
class CasConfig(Config): class CasConfig(Config):
@ -38,12 +43,16 @@ class CasConfig(Config):
public_base_url + "_matrix/client/r0/login/cas/ticket" public_base_url + "_matrix/client/r0/login/cas/ticket"
) )
self.cas_displayname_attribute = cas_config.get("displayname_attribute") self.cas_displayname_attribute = cas_config.get("displayname_attribute")
self.cas_required_attributes = cas_config.get("required_attributes") or {} required_attributes = cas_config.get("required_attributes") or {}
self.cas_required_attributes = _parsed_required_attributes_def(
required_attributes
)
else: else:
self.cas_server_url = None self.cas_server_url = None
self.cas_service_url = None self.cas_service_url = None
self.cas_displayname_attribute = None self.cas_displayname_attribute = None
self.cas_required_attributes = {} self.cas_required_attributes = []
def generate_config_section(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\ return """\
@ -75,3 +84,22 @@ class CasConfig(Config):
# userGroup: "staff" # userGroup: "staff"
# department: None # department: None
""" """
# CAS uses a legacy required attributes mapping, not the one provided by
# SsoAttributeRequirement.
REQUIRED_ATTRIBUTES_SCHEMA = {
"type": "object",
"additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]},
}
def _parsed_required_attributes_def(
required_attributes: Any,
) -> List[SsoAttributeRequirement]:
validate_config(
REQUIRED_ATTRIBUTES_SCHEMA,
required_attributes,
config_path=("cas_config", "required_attributes"),
)
return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()]

View File

@ -17,8 +17,7 @@
import logging import logging
from typing import Any, List from typing import Any, List
import attr from synapse.config.sso import SsoAttributeRequirement
from synapse.python_dependencies import DependencyException, check_requirements from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module from synapse.util.module_loader import load_module, load_python_module
@ -396,32 +395,18 @@ class SAML2Config(Config):
} }
@attr.s(frozen=True)
class SamlAttributeRequirement:
"""Object describing a single requirement for SAML attributes."""
attribute = attr.ib(type=str)
value = attr.ib(type=str)
JSON_SCHEMA = {
"type": "object",
"properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
"required": ["attribute", "value"],
}
ATTRIBUTE_REQUIREMENTS_SCHEMA = { ATTRIBUTE_REQUIREMENTS_SCHEMA = {
"type": "array", "type": "array",
"items": SamlAttributeRequirement.JSON_SCHEMA, "items": SsoAttributeRequirement.JSON_SCHEMA,
} }
def _parse_attribute_requirements_def( def _parse_attribute_requirements_def(
attribute_requirements: Any, attribute_requirements: Any,
) -> List[SamlAttributeRequirement]: ) -> List[SsoAttributeRequirement]:
validate_config( validate_config(
ATTRIBUTE_REQUIREMENTS_SCHEMA, ATTRIBUTE_REQUIREMENTS_SCHEMA,
attribute_requirements, attribute_requirements,
config_path=["saml2_config", "attribute_requirements"], config_path=("saml2_config", "attribute_requirements"),
) )
return [SamlAttributeRequirement(**x) for x in attribute_requirements] return [SsoAttributeRequirement(**x) for x in attribute_requirements]

View File

@ -12,11 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict from typing import Any, Dict, Optional
import attr
from ._base import Config from ._base import Config
@attr.s(frozen=True)
class SsoAttributeRequirement:
"""Object describing a single requirement for SSO attributes."""
attribute = attr.ib(type=str)
# If a value is not given, than the attribute must simply exist.
value = attr.ib(type=Optional[str])
JSON_SCHEMA = {
"type": "object",
"properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
"required": ["attribute", "value"],
}
class SSOConfig(Config): class SSOConfig(Config):
"""SSO Configuration """SSO Configuration
""" """

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import urllib.parse import urllib.parse
from typing import TYPE_CHECKING, Dict, Optional from typing import TYPE_CHECKING, Dict, List, Optional
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
import attr import attr
@ -49,7 +49,7 @@ class CasError(Exception):
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class CasResponse: class CasResponse:
username = attr.ib(type=str) username = attr.ib(type=str)
attributes = attr.ib(type=Dict[str, Optional[str]]) attributes = attr.ib(type=Dict[str, List[Optional[str]]])
class CasHandler: class CasHandler:
@ -169,7 +169,7 @@ class CasHandler:
# Iterate through the nodes and pull out the user and any extra attributes. # Iterate through the nodes and pull out the user and any extra attributes.
user = None user = None
attributes = {} attributes = {} # type: Dict[str, List[Optional[str]]]
for child in root[0]: for child in root[0]:
if child.tag.endswith("user"): if child.tag.endswith("user"):
user = child.text user = child.text
@ -182,7 +182,7 @@ class CasHandler:
tag = attribute.tag tag = attribute.tag
if "}" in tag: if "}" in tag:
tag = tag.split("}")[1] tag = tag.split("}")[1]
attributes[tag] = attribute.text attributes.setdefault(tag, []).append(attribute.text)
# Ensure a user was found. # Ensure a user was found.
if user is None: if user is None:
@ -303,28 +303,9 @@ class CasHandler:
# Ensure that the attributes of the logged in user meet the required # Ensure that the attributes of the logged in user meet the required
# attributes. # attributes.
for required_attribute, required_value in self._cas_required_attributes.items(): if not self._sso_handler.check_required_attributes(
# If required attribute was not in CAS Response - Forbidden request, cas_response.attributes, self._cas_required_attributes
if required_attribute not in cas_response.attributes: ):
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return
# Also need to check value
if required_value is not None:
actual_value = cas_response.attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return return
# Call the mapper to register/login the user # Call the mapper to register/login the user
@ -372,9 +353,10 @@ class CasHandler:
if failures: if failures:
raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs") raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
# Arbitrarily use the first attribute found.
display_name = cas_response.attributes.get( display_name = cas_response.attributes.get(
self._cas_displayname_attribute, None self._cas_displayname_attribute, [None]
) )[0]
return UserAttributes(localpart=localpart, display_name=display_name) return UserAttributes(localpart=localpart, display_name=display_name)

View File

@ -23,7 +23,6 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
@ -239,11 +238,9 @@ class SamlHandler(BaseHandler):
# Ensure that the attributes of the logged in user meet the required # Ensure that the attributes of the logged in user meet the required
# attributes. # attributes.
for requirement in self._saml2_attribute_requirements: if not self._sso_handler.check_required_attributes(
if not _check_attribute_requirement(saml2_auth.ava, requirement): request, saml2_auth.ava, self._saml2_attribute_requirements
self._sso_handler.render_error( ):
request, "unauthorised", "You are not authorised to log in here."
)
return return
# Call the mapper to register/login the user # Call the mapper to register/login the user
@ -373,21 +370,6 @@ class SamlHandler(BaseHandler):
del self._outstanding_requests_dict[reqid] del self._outstanding_requests_dict[reqid]
def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
values = ava.get(req.attribute, [])
for v in values:
if v == req.value:
return True
logger.info(
"SAML2 attribute %s did not match required value '%s' (was '%s')",
req.attribute,
req.value,
values,
)
return False
DOT_REPLACE_PATTERN = re.compile( DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
) )

View File

@ -16,10 +16,12 @@ import abc
import logging import logging
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any,
Awaitable, Awaitable,
Callable, Callable,
Dict, Dict,
Iterable, Iterable,
List,
Mapping, Mapping,
Optional, Optional,
Set, Set,
@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.server import respond_with_html, respond_with_redirect
@ -893,6 +896,41 @@ class SsoHandler:
logger.info("Expiring mapping session %s", session_id) logger.info("Expiring mapping session %s", session_id)
del self._username_mapping_sessions[session_id] del self._username_mapping_sessions[session_id]
def check_required_attributes(
self,
request: SynapseRequest,
attributes: Mapping[str, List[Any]],
attribute_requirements: Iterable[SsoAttributeRequirement],
) -> bool:
"""
Confirm that the required attributes were present in the SSO response.
If all requirements are met, this will return True.
If any requirement is not met, then the request will be finalized by
showing an error page to the user and False will be returned.
Args:
request: The request to (potentially) respond to.
attributes: The attributes from the SSO IdP.
attribute_requirements: The requirements that attributes must meet.
Returns:
True if all requirements are met, False if any attribute fails to
meet the requirement.
"""
# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in attribute_requirements:
if not _check_attribute_requirement(attributes, requirement):
self.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return False
return True
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie """Extract the session ID from the cookie
@ -903,3 +941,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
if not session_id: if not session_id:
raise SynapseError(code=400, msg="missing session_id") raise SynapseError(code=400, msg="missing session_id")
return session_id.decode("ascii", errors="replace") return session_id.decode("ascii", errors="replace")
def _check_attribute_requirement(
attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
) -> bool:
"""Check if SSO attributes meet the proper requirements.
Args:
attributes: A mapping of attributes to an iterable of one or more values.
requirement: The configured requirement to check.
Returns:
True if the required attribute was found and had a proper value.
"""
if req.attribute not in attributes:
logger.info("SSO attribute missing: %s", req.attribute)
return False
# If the requirement is None, the attribute existing is enough.
if req.value is None:
return True
values = attributes[req.attribute]
if req.value in values:
return True
logger.info(
"SSO attribute %s did not match required value '%s' (was '%s')",
req.attribute,
req.value,
values,
)
return False

View File

@ -16,7 +16,7 @@ from mock import Mock
from synapse.handlers.cas_handler import CasResponse from synapse.handlers.cas_handler import CasResponse
from tests.test_utils import simple_async_mock from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase, override_config
# These are a few constants that are used as config parameters in the tests. # These are a few constants that are used as config parameters in the tests.
BASE_URL = "https://synapse/" BASE_URL = "https://synapse/"
@ -32,6 +32,10 @@ class CasHandlerTestCase(HomeserverTestCase):
"server_url": SERVER_URL, "server_url": SERVER_URL,
"service_url": BASE_URL, "service_url": BASE_URL,
} }
# Update this config with what's in the default config so that
# override_config works as expected.
cas_config.update(config.get("cas_config", {}))
config["cas_config"] = cas_config config["cas_config"] = cas_config
return config return config
@ -115,7 +119,51 @@ class CasHandlerTestCase(HomeserverTestCase):
"@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
) )
@override_config(
{
"cas_config": {
"required_attributes": {"userGroup": "staff", "department": None}
}
}
)
def test_required_attributes(self):
"""The required attributes must be met from the CAS response."""
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# The response doesn't have the proper userGroup or department.
cas_response = CasResponse("test_user", {})
request = _mock_request()
self.get_success(
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
)
auth_handler.complete_sso_login.assert_not_called()
# The response doesn't have any department.
cas_response = CasResponse("test_user", {"userGroup": "staff"})
request.reset_mock()
self.get_success(
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
)
auth_handler.complete_sso_login.assert_not_called()
# Add the proper attributes and it should succeed.
cas_response = CasResponse(
"test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
)
request.reset_mock()
self.get_success(
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
)
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=True
)
def _mock_request(): def _mock_request():
"""Returns a mock which will stand in as a SynapseRequest""" """Returns a mock which will stand in as a SynapseRequest"""
return Mock(spec=["getClientIP", "getHeader"]) return Mock(spec=["getClientIP", "getHeader", "_disconnected"])

View File

@ -259,7 +259,61 @@ class SamlHandlerTestCase(HomeserverTestCase):
) )
self.assertEqual(e.value.location, b"https://custom-saml-redirect/") self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
@override_config(
{
"saml2_config": {
"attribute_requirements": [
{"attribute": "userGroup", "value": "staff"},
{"attribute": "department", "value": "sales"},
],
},
}
)
def test_attribute_requirements(self):
"""The required attributes must be met from the SAML response."""
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# The response doesn't have the proper userGroup or department.
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
request = _mock_request()
self.get_success(
self.handler._handle_authn_response(request, saml_response, "redirect_uri")
)
auth_handler.complete_sso_login.assert_not_called()
# The response doesn't have the proper department.
saml_response = FakeAuthnResponse(
{"uid": "test_user", "username": "test_user", "userGroup": ["staff"]}
)
request = _mock_request()
self.get_success(
self.handler._handle_authn_response(request, saml_response, "redirect_uri")
)
auth_handler.complete_sso_login.assert_not_called()
# Add the proper attributes and it should succeed.
saml_response = FakeAuthnResponse(
{
"uid": "test_user",
"username": "test_user",
"userGroup": ["staff", "admin"],
"department": ["sales"],
}
)
request.reset_mock()
self.get_success(
self.handler._handle_authn_response(request, saml_response, "redirect_uri")
)
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=True
)
def _mock_request(): def _mock_request():
"""Returns a mock which will stand in as a SynapseRequest""" """Returns a mock which will stand in as a SynapseRequest"""
return Mock(spec=["getClientIP", "getHeader"]) return Mock(spec=["getClientIP", "getHeader", "_disconnected"])