Pass the module_api into the SamlMappingProvider

... for consistency with other modules, and because we'll need it sooner or
later and it will be a pain to introduce later.
This commit is contained in:
Richard van der Hoff 2020-01-11 11:48:43 +00:00
parent d2906fe666
commit 47e63cc67a
1 changed files with 5 additions and 2 deletions

View File

@ -24,6 +24,7 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.module_api import ModuleApi
from synapse.rest.client.v1.login import SSOAuthHandler from synapse.rest.client.v1.login import SSOAuthHandler
from synapse.types import ( from synapse.types import (
UserID, UserID,
@ -59,7 +60,8 @@ class SamlHandler:
# plugin to do custom mapping from saml response to mxid # plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
hs.config.saml2_user_mapping_provider_config hs.config.saml2_user_mapping_provider_config,
ModuleApi(hs, hs.get_auth_handler()),
) )
# identifier for the external_ids table # identifier for the external_ids table
@ -265,11 +267,12 @@ class SamlConfig(object):
class DefaultSamlMappingProvider(object): class DefaultSamlMappingProvider(object):
__version__ = "0.0.1" __version__ = "0.0.1"
def __init__(self, parsed_config: SamlConfig): def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
"""The default SAML user mapping provider """The default SAML user mapping provider
Args: Args:
parsed_config: Module configuration parsed_config: Module configuration
module_api: module api proxy
""" """
self._mxid_source_attribute = parsed_config.mxid_source_attribute self._mxid_source_attribute = parsed_config.mxid_source_attribute
self._mxid_mapper = parsed_config.mxid_mapper self._mxid_mapper = parsed_config.mxid_mapper