Combine the CAS & SAML implementations for required attributes. (#9326)

This commit is contained in:
Patrick Cloke 2021-02-11 10:05:15 -05:00 committed by GitHub
parent 80d6dc9783
commit 6dade80048
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 245 additions and 77 deletions

1
changelog.d/9326.misc Normal file
View File

@ -0,0 +1 @@
Share the code for handling required attributes between the CAS and SAML handlers.

View File

@ -13,7 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
from synapse.config.sso import SsoAttributeRequirement
from ._base import Config
from ._util import validate_config
class CasConfig(Config):
@ -38,12 +43,16 @@ class CasConfig(Config):
public_base_url + "_matrix/client/r0/login/cas/ticket"
)
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
self.cas_required_attributes = cas_config.get("required_attributes") or {}
required_attributes = cas_config.get("required_attributes") or {}
self.cas_required_attributes = _parsed_required_attributes_def(
required_attributes
)
else:
self.cas_server_url = None
self.cas_service_url = None
self.cas_displayname_attribute = None
self.cas_required_attributes = {}
self.cas_required_attributes = []
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
@ -75,3 +84,22 @@ class CasConfig(Config):
# userGroup: "staff"
# department: None
"""
# CAS uses a legacy required attributes mapping, not the one provided by
# SsoAttributeRequirement.
REQUIRED_ATTRIBUTES_SCHEMA = {
"type": "object",
"additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]},
}
def _parsed_required_attributes_def(
required_attributes: Any,
) -> List[SsoAttributeRequirement]:
validate_config(
REQUIRED_ATTRIBUTES_SCHEMA,
required_attributes,
config_path=("cas_config", "required_attributes"),
)
return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()]

View File

@ -17,8 +17,7 @@
import logging
from typing import Any, List
import attr
from synapse.config.sso import SsoAttributeRequirement
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module
@ -396,32 +395,18 @@ class SAML2Config(Config):
}
@attr.s(frozen=True)
class SamlAttributeRequirement:
"""Object describing a single requirement for SAML attributes."""
attribute = attr.ib(type=str)
value = attr.ib(type=str)
JSON_SCHEMA = {
"type": "object",
"properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
"required": ["attribute", "value"],
}
ATTRIBUTE_REQUIREMENTS_SCHEMA = {
"type": "array",
"items": SamlAttributeRequirement.JSON_SCHEMA,
"items": SsoAttributeRequirement.JSON_SCHEMA,
}
def _parse_attribute_requirements_def(
attribute_requirements: Any,
) -> List[SamlAttributeRequirement]:
) -> List[SsoAttributeRequirement]:
validate_config(
ATTRIBUTE_REQUIREMENTS_SCHEMA,
attribute_requirements,
config_path=["saml2_config", "attribute_requirements"],
config_path=("saml2_config", "attribute_requirements"),
)
return [SamlAttributeRequirement(**x) for x in attribute_requirements]
return [SsoAttributeRequirement(**x) for x in attribute_requirements]

View File

@ -12,11 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from typing import Any, Dict, Optional
import attr
from ._base import Config
@attr.s(frozen=True)
class SsoAttributeRequirement:
"""Object describing a single requirement for SSO attributes."""
attribute = attr.ib(type=str)
# If a value is not given, than the attribute must simply exist.
value = attr.ib(type=Optional[str])
JSON_SCHEMA = {
"type": "object",
"properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
"required": ["attribute", "value"],
}
class SSOConfig(Config):
"""SSO Configuration
"""

View File

@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict, List, Optional
from xml.etree import ElementTree as ET
import attr
@ -49,7 +49,7 @@ class CasError(Exception):
@attr.s(slots=True, frozen=True)
class CasResponse:
username = attr.ib(type=str)
attributes = attr.ib(type=Dict[str, Optional[str]])
attributes = attr.ib(type=Dict[str, List[Optional[str]]])
class CasHandler:
@ -169,7 +169,7 @@ class CasHandler:
# Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes = {}
attributes = {} # type: Dict[str, List[Optional[str]]]
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
@ -182,7 +182,7 @@ class CasHandler:
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
attributes.setdefault(tag, []).append(attribute.text)
# Ensure a user was found.
if user is None:
@ -303,28 +303,9 @@ class CasHandler:
# Ensure that the attributes of the logged in user meet the required
# attributes.
for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in cas_response.attributes:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
return
# Also need to check value
if required_value is not None:
actual_value = cas_response.attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
401,
)
if not self._sso_handler.check_required_attributes(
request, cas_response.attributes, self._cas_required_attributes
):
return
# Call the mapper to register/login the user
@ -372,9 +353,10 @@ class CasHandler:
if failures:
raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
# Arbitrarily use the first attribute found.
display_name = cas_response.attributes.get(
self._cas_displayname_attribute, None
)
self._cas_displayname_attribute, [None]
)[0]
return UserAttributes(localpart=localpart, display_name=display_name)

View File

@ -23,7 +23,6 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
@ -239,11 +238,9 @@ class SamlHandler(BaseHandler):
# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in self._saml2_attribute_requirements:
if not _check_attribute_requirement(saml2_auth.ava, requirement):
self._sso_handler.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
if not self._sso_handler.check_required_attributes(
request, saml2_auth.ava, self._saml2_attribute_requirements
):
return
# Call the mapper to register/login the user
@ -373,21 +370,6 @@ class SamlHandler(BaseHandler):
del self._outstanding_requests_dict[reqid]
def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
values = ava.get(req.attribute, [])
for v in values:
if v == req.value:
return True
logger.info(
"SAML2 attribute %s did not match required value '%s' (was '%s')",
req.attribute,
req.value,
values,
)
return False
DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)

View File

@ -16,10 +16,12 @@ import abc
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html, respond_with_redirect
@ -893,6 +896,41 @@ class SsoHandler:
logger.info("Expiring mapping session %s", session_id)
del self._username_mapping_sessions[session_id]
def check_required_attributes(
self,
request: SynapseRequest,
attributes: Mapping[str, List[Any]],
attribute_requirements: Iterable[SsoAttributeRequirement],
) -> bool:
"""
Confirm that the required attributes were present in the SSO response.
If all requirements are met, this will return True.
If any requirement is not met, then the request will be finalized by
showing an error page to the user and False will be returned.
Args:
request: The request to (potentially) respond to.
attributes: The attributes from the SSO IdP.
attribute_requirements: The requirements that attributes must meet.
Returns:
True if all requirements are met, False if any attribute fails to
meet the requirement.
"""
# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in attribute_requirements:
if not _check_attribute_requirement(attributes, requirement):
self.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return False
return True
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie
@ -903,3 +941,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
if not session_id:
raise SynapseError(code=400, msg="missing session_id")
return session_id.decode("ascii", errors="replace")
def _check_attribute_requirement(
attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
) -> bool:
"""Check if SSO attributes meet the proper requirements.
Args:
attributes: A mapping of attributes to an iterable of one or more values.
requirement: The configured requirement to check.
Returns:
True if the required attribute was found and had a proper value.
"""
if req.attribute not in attributes:
logger.info("SSO attribute missing: %s", req.attribute)
return False
# If the requirement is None, the attribute existing is enough.
if req.value is None:
return True
values = attributes[req.attribute]
if req.value in values:
return True
logger.info(
"SSO attribute %s did not match required value '%s' (was '%s')",
req.attribute,
req.value,
values,
)
return False

View File

@ -16,7 +16,7 @@ from mock import Mock
from synapse.handlers.cas_handler import CasResponse
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase
from tests.unittest import HomeserverTestCase, override_config
# These are a few constants that are used as config parameters in the tests.
BASE_URL = "https://synapse/"
@ -32,6 +32,10 @@ class CasHandlerTestCase(HomeserverTestCase):
"server_url": SERVER_URL,
"service_url": BASE_URL,
}
# Update this config with what's in the default config so that
# override_config works as expected.
cas_config.update(config.get("cas_config", {}))
config["cas_config"] = cas_config
return config
@ -115,7 +119,51 @@ class CasHandlerTestCase(HomeserverTestCase):
"@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
)
@override_config(
{
"cas_config": {
"required_attributes": {"userGroup": "staff", "department": None}
}
}
)
def test_required_attributes(self):
"""The required attributes must be met from the CAS response."""
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# The response doesn't have the proper userGroup or department.
cas_response = CasResponse("test_user", {})
request = _mock_request()
self.get_success(
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
)
auth_handler.complete_sso_login.assert_not_called()
# The response doesn't have any department.
cas_response = CasResponse("test_user", {"userGroup": "staff"})
request.reset_mock()
self.get_success(
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
)
auth_handler.complete_sso_login.assert_not_called()
# Add the proper attributes and it should succeed.
cas_response = CasResponse(
"test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
)
request.reset_mock()
self.get_success(
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
)
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=True
)
def _mock_request():
"""Returns a mock which will stand in as a SynapseRequest"""
return Mock(spec=["getClientIP", "getHeader"])
return Mock(spec=["getClientIP", "getHeader", "_disconnected"])

View File

@ -259,7 +259,61 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
@override_config(
{
"saml2_config": {
"attribute_requirements": [
{"attribute": "userGroup", "value": "staff"},
{"attribute": "department", "value": "sales"},
],
},
}
)
def test_attribute_requirements(self):
"""The required attributes must be met from the SAML response."""
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# The response doesn't have the proper userGroup or department.
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
request = _mock_request()
self.get_success(
self.handler._handle_authn_response(request, saml_response, "redirect_uri")
)
auth_handler.complete_sso_login.assert_not_called()
# The response doesn't have the proper department.
saml_response = FakeAuthnResponse(
{"uid": "test_user", "username": "test_user", "userGroup": ["staff"]}
)
request = _mock_request()
self.get_success(
self.handler._handle_authn_response(request, saml_response, "redirect_uri")
)
auth_handler.complete_sso_login.assert_not_called()
# Add the proper attributes and it should succeed.
saml_response = FakeAuthnResponse(
{
"uid": "test_user",
"username": "test_user",
"userGroup": ["staff", "admin"],
"department": ["sales"],
}
)
request.reset_mock()
self.get_success(
self.handler._handle_authn_response(request, saml_response, "redirect_uri")
)
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=True
)
def _mock_request():
"""Returns a mock which will stand in as a SynapseRequest"""
return Mock(spec=["getClientIP", "getHeader"])
return Mock(spec=["getClientIP", "getHeader", "_disconnected"])