Merge pull request #6069 from matrix-org/rav/fix_attribute_mapping

Fix a bug with saml attribute maps.
This commit is contained in:
Richard van der Hoff 2019-09-24 15:07:26 +01:00 committed by GitHub
commit bb82be9851
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 7 deletions

1
changelog.d/6069.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug which caused SAML attribute maps to be overridden by defaults.

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd # Copyright 2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,11 +13,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.python_dependencies import DependencyException, check_requirements from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_python_module
from ._base import Config, ConfigError from ._base import Config, ConfigError
def _dict_merge(merge_dict, into_dict):
"""Do a deep merge of two dicts
Recursively merges `merge_dict` into `into_dict`:
* For keys where both `merge_dict` and `into_dict` have a dict value, the values
are recursively merged
* For all other keys, the values in `into_dict` (if any) are overwritten with
the value from `merge_dict`.
Args:
merge_dict (dict): dict to merge
into_dict (dict): target dict
"""
for k, v in merge_dict.items():
if k not in into_dict:
into_dict[k] = v
continue
current_val = into_dict[k]
if isinstance(v, dict) and isinstance(current_val, dict):
_dict_merge(v, current_val)
continue
# otherwise we just overwrite
into_dict[k] = v
class SAML2Config(Config): class SAML2Config(Config):
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.saml2_enabled = False self.saml2_enabled = False
@ -36,15 +67,20 @@ class SAML2Config(Config):
self.saml2_enabled = True self.saml2_enabled = True
import saml2.config saml2_config_dict = self._default_saml_config_dict()
_dict_merge(
self.saml2_sp_config = saml2.config.SPConfig() merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
self.saml2_sp_config.load(self._default_saml_config_dict()) )
self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
config_path = saml2_config.get("config_path", None) config_path = saml2_config.get("config_path", None)
if config_path is not None: if config_path is not None:
self.saml2_sp_config.load_file(config_path) mod = load_python_module(config_path)
_dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict)
import saml2.config
self.saml2_sp_config = saml2.config.SPConfig()
self.saml2_sp_config.load(saml2_config_dict)
# session lifetime: in milliseconds # session lifetime: in milliseconds
self.saml2_session_lifetime = self.parse_duration( self.saml2_session_lifetime = self.parse_duration(

View File

@ -14,12 +14,13 @@
# limitations under the License. # limitations under the License.
import importlib import importlib
import importlib.util
from synapse.config._base import ConfigError from synapse.config._base import ConfigError
def load_module(provider): def load_module(provider):
""" Loads a module with its config """ Loads a synapse module with its config
Take a dict with keys 'module' (the module name) and 'config' Take a dict with keys 'module' (the module name) and 'config'
(the config dict). (the config dict).
@ -38,3 +39,20 @@ def load_module(provider):
raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e)) raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
return provider_class, provider_config return provider_class, provider_config
def load_python_module(location: str):
"""Load a python module, and return a reference to its global namespace
Args:
location (str): path to the module
Returns:
python module object
"""
spec = importlib.util.spec_from_file_location(location, location)
if spec is None:
raise Exception("Unable to load module at %s" % (location,))
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod